Warp v1.10.0

Warp v1.10 expands JAX integration with automatic differentiation support and multi-device jax.pmap() compatibility. The tile programming model has been enhanced with axis-specific reductions, component-level indexing, and convenience functions for creating tiles.

Performance has been significantly improved in several areas: BVH operations now support in-place rebuilding for CUDA graphs and configurable leaf sizes, built-in function calls from Python are up to 70× faster, and additional sparse matrix and FEM operations can now be captured in CUDA graphs.

Additional usability improvements include negative indexing and slicing for arrays, atomic bitwise operations, and new built-ins such as error functions and type casting.

Important: This release removes the warp.sim module (deprecated since v1.8), which has been superseded by the Newton physics engine. See the Announcements section below for migration guidance and other upcoming changes.

For a complete list of changes, see the full changelog.

New features

JAX automatic differentiation (experimental)

Warp now offers experimental automatic differentiation support for JAX, allowing Warp kernels to participate in JAX differentiation workflows. This feature was contributed by @mehdiataei and builds on earlier work by @jaro-sevcik. It enables computing gradients through Warp kernels with jax.grad() by passing enable_backward=True to jax_kernel().

Key capabilities include:

  • Single and multiple output kernels: Compute gradients for kernels with one or more output arrays
  • Static input auto-detection: Scalar inputs are automatically treated as static (non-differentiable) arguments
  • Vector and matrix arrays: Arrays of composite types like wp.vec2 or wp.mat22 are fully supported
  • Multi-device execution: Compatible with jax.pmap() for distributed forward and backward passes across multiple GPUs
import jax
import warp as wp

from warp.jax_experimental import jax_kernel

@wp.kernel
def my_kernel(a: wp.array(dtype=float), out: wp.array(dtype=float)):
    i = wp.tid()
    out[i] = a[i] ** 2.0

# Enable automatic differentiation
jax_func = jax_kernel(my_kernel, num_outputs=1, enable_backward=True)

# Compute gradients through the kernel
input_array = jax.numpy.array([1.0, 2.0, 3.0], dtype=jax.numpy.float32)
grad_fn = jax.grad(lambda a: jax.numpy.sum(jax_func(a)[0]))
gradient = grad_fn(input_array)  # gradient: [2*a[0], 2*a[1], ...]

This feature is experimental and has some current limitations. See the JAX Automatic Differentiation documentation for complete examples, usage details, and limitations.

Multi-device JAX support with jax.pmap()

Warp now properly supports jax.pmap() and jax.shard_map() for multi-device parallel execution, thanks to fixes contributed by @chaserileyroberts. Previously, device-targeting issues prevented Warp callables from working correctly within these JAX primitives: JAX would invoke callbacks from multiple threads targeting different devices, but Warp would always execute on the default device. The fix extracts device ordinals from the XLA FFI call and synchronizes concurrent callbacks, so each callback runs on the device that owns its data, enabling efficient data-parallel workflows across multiple GPUs.
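A minimal sketch of data-parallel execution, reusing the jax_func wrapper from the example above and assuming one input shard per local device:

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
# Shape (n_dev, 1024): one shard of 1024 elements per device
sharded = jnp.arange(n_dev * 1024, dtype=jnp.float32).reshape(n_dev, 1024)

# Each per-device callback now executes on the device that owns its shard
squared = jax.pmap(lambda a: jax_func(a)[0])(sharded)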

In-place BVH rebuilding with CUDA graph support

A new wp.Bvh.rebuild() method enables rebuilding BVH hierarchies in-place without allocating new memory. This complements the existing refit() method and is particularly useful when primitive distributions change significantly.

CUDA graph capture: Unlike creating a new BVH, rebuild() reuses existing buffers, making it safe to capture in CUDA graphs. Previously captured graphs that include queries on the BVH remain valid after rebuilding, enabling high-performance repeated updates without graph re-capture overhead.

Construction algorithms: On CUDA devices, in-place rebuild supports "lbvh" only. On CPU, "sah" and "median" are supported. Defaults are chosen automatically based on the device.
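A minimal sketch of the intended update loop, using placeholder AABB data and the existing wp.ScopedCapture()/wp.capture_launch() graph APIs:

import numpy as np
import warp as wp

n = 1024
lowers = wp.array(np.random.rand(n, 3).astype(np.float32), dtype=wp.vec3)
uppers = wp.array((np.random.rand(n, 3) + 1.0).astype(np.float32), dtype=wp.vec3)

bvh = wp.Bvh(lowers, uppers)

with wp.ScopedCapture() as capture:
    # ... kernels that update lowers/uppers in place ...
    # In-place rebuild reuses existing buffers, so it is graph-capturable
    bvh.rebuild()
    # ... kernels that query bvh.id remain valid after the rebuild ...

# Replaying the graph rebuilds and queries without re-capture overhead
wp.capture_launch(capture.graph)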

Tile programming enhancements

The tile programming model has been enhanced with new capabilities to make tile-based computations more expressive and convenient:

Axis-specific reductions

The tile-reduction functions wp.tile_reduce() and wp.tile_sum() now support an optional axis parameter, enabling reductions along a specific dimension of a tile rather than reducing the entire tile to a single value. This enhancement brings NumPy-like axis semantics to tile operations.

import numpy as np
import warp as wp

@wp.kernel
def tile_reduce_axis(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
    a = wp.tile_load(x, shape=(4, 8), storage="shared")
    # Sum along axis 0, reducing shape from (4, 8) to (8,)
    b = wp.tile_sum(a, axis=0)
    wp.tile_store(y, b)


x = wp.array(np.arange(32).reshape(4, 8), dtype=float)
# x = [[ 0.  1.  2.  3.  4.  5.  6.  7.]
#      [ 8.  9. 10. 11. 12. 13. 14. 15.]
#      [16. 17. 18. 19. 20. 21. 22. 23.]
#      [24. 25. 26. 27. 28. 29. 30. 31.]]
y = wp.zeros(8, dtype=float)

wp.launch_tiled(tile_reduce_axis, dim=(1,), inputs=[x], outputs=[y], block_dim=32)
# y = [48. 52. 56. 60. 64. 68. 72. 76.]  (column sums)

Component-level indexing

Tiles of composite types (vectors, matrices, quaternions) now support component-level indexing and assignment. You can directly index into individual components using extended indexing syntax:

  • Vector components: tile[i][1] extracts the second component of a vector at position i
  • Matrix elements: tile[i][1, 1] accesses the element at row 1, column 1 of a matrix at position i

This provides more convenient and expressive syntax for working with structured data in tiles.
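A minimal sketch of the syntax inside a tile kernel, assuming a 1D array of wp.vec3 values (launched with wp.launch_tiled()):

import warp as wp

@wp.kernel
def vec_tile_components(points: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
    t = wp.tile_load(points, shape=16)
    # Read the y component of the vector at tile position 3
    y3 = t[3][1]
    # Assign a single component in place
    t[3][1] = 2.0 * y3
    out[0] = t[3][1]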

Creating tiles filled with a constant value

The new wp.tile_full() function provides a convenient way to create tiles initialized with a constant value, similar to NumPy's np.full():

@wp.kernel
def make_full_tile():
    # Create an 8x8 tile filled with 3.14 (tile functions run in kernel scope)
    tile = wp.tile_full(shape=(8, 8), value=3.14, dtype=float)

New example

The new example_tile_mcgp.py example demonstrates tile-based Monte Carlo methods by implementing a walk-on-spheres algorithm for solving Laplace's equation on volumetric domains.

Performance improvements

Built-in function calls from Python

Calling Warp built-in functions from Python scope (e.g., wp.normalize(), wp.transform_identity(), matrix arithmetic like mat * mat) is now significantly faster thanks to optimizations in overload resolution. Previously, each function call would iterate through all overloads, attempt argument binding, and pack parameters into C types until finding a match. Now, Warp caches the resolved overload and parameter packing strategy based on argument types using @functools.lru_cache, eliminating redundant resolution overhead on subsequent calls.

In microbenchmarks, repeated wp.mat44 multiplication at Python scope is up to 70× faster (~570 μs → ~8 μs), while operations like wp.transform_identity() see 3-4× speedups (~100 μs → ~30 μs). The magnitude of improvement varies by operation complexity, with greater gains for operations requiring more expensive overload resolution.

Breaking change: As part of this optimization, support for passing lists, tuples, and other non-Warp array arguments to built-in functions has been removed. Calls like wp.normalize([1.0, 2.0, 3.0]) must now be written as wp.normalize(wp.vec3(1.0, 2.0, 3.0)). This simplifies the function call path and removes expensive sequence-flattening logic that was incompatible with efficient caching.
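For example:

import warp as wp

# Before v1.10 (no longer supported):
# n = wp.normalize([1.0, 2.0, 3.0])

# From v1.10, construct a Warp vector explicitly:
n = wp.normalize(wp.vec3(1.0, 2.0, 3.0))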

Configurable BVH leaf size

wp.Bvh and wp.Mesh now expose tunable leaf_size and bvh_leaf_size parameters, respectively, allowing users to control the number of primitives stored in each leaf node for performance optimization. The optimal leaf size depends on the query workload:

  • Intersection queries (ray casting, AABB overlap): Smaller leaf sizes (e.g., 1) are generally optimal, reducing unnecessary primitive checks
  • Closest point queries: Larger leaf sizes (e.g., 4-8) can improve performance by checking more primitives together and reducing traversal overhead
  • Mixed workloads: Moderate values (e.g., 4) provide a balanced trade-off

Behavior change: The default leaf_size for wp.Bvh has changed from 4 (hardcoded) to 1, optimizing for intersection queries which are more common. wp.Mesh retains a default bvh_leaf_size of 4 as a compromise between intersection and closest-point query performance. Users performing primarily closest-point queries may benefit from explicitly setting larger leaf sizes.
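A sketch of the new parameters (the geometry here is placeholder data):

import numpy as np
import warp as wp

# BVH over placeholder AABBs, tuned for closest-point queries
lowers = wp.array(np.zeros((8, 3), dtype=np.float32), dtype=wp.vec3)
uppers = wp.array(np.ones((8, 3), dtype=np.float32), dtype=wp.vec3)
bvh = wp.Bvh(lowers, uppers, leaf_size=8)

# Mesh over a single triangle, with the default compromise value made explicit
points = wp.array(np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32), dtype=wp.vec3)
indices = wp.array(np.array([0, 1, 2], dtype=np.int32), dtype=wp.int32)
mesh = wp.Mesh(points=points, indices=indices, bvh_leaf_size=4)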

Sparse matrix operations with CUDA graphs

Sparse matrix operations in warp.sparse can now be captured in CUDA graphs for allocation-free execution. Operations like bsr_axpy(), bsr_assign(), and bsr_set_transpose() preserve matrix topology when using masked=True, while bsr_mm() adds a new max_new_nnz parameter that allows specifying an upper bound on new non-zero blocks for flexible graph capture when sparsity patterns vary within known bounds.
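A sketch of the capture pattern, assuming A, B, and C are pre-built warp.sparse.BsrMatrix objects with compatible shapes and pre-allocated topology:

import warp as wp
from warp.sparse import bsr_axpy, bsr_mm

with wp.ScopedCapture() as capture:
    # masked=True preserves C's existing topology, so no allocation occurs
    bsr_axpy(A, C, masked=True)
    # max_new_nnz bounds the number of new non-zero blocks, so the product
    # can be captured even when the sparsity pattern varies within limits
    bsr_mm(A, B, C, max_new_nnz=4096)

wp.capture_launch(capture.graph)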

FEM operations with CUDA graphs

Building warp.fem geometry and function space partitions can now be captured in CUDA graphs by specifying upper bounds on partition sizes: max_cell_count and max_side_count for ExplicitGeometryPartition, and max_node_count for make_space_partition(). Additionally, building fields and restrictions is now synchronization-free by default.
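A minimal sketch, assuming space is an existing warp.fem function space; only the documented max_node_count bound is shown, other arguments are omitted:

import warp as wp
import warp.fem as fem

with wp.ScopedCapture() as capture:
    # The pre-declared upper bound keeps partition construction allocation-free
    partition = fem.make_space_partition(space, max_node_count=4096)

wp.capture_launch(capture.graph)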

Language enhancements

Array indexing and slicing improvements

Warp arrays now support negative indexing and improved slicing behavior, making array manipulation more intuitive and consistent with NumPy conventions.

Negative indexing: Access elements from the end of an array using negative indices:

@wp.kernel
def use_negative_indexing(arr: wp.array(dtype=float)):
    last = arr[-1]  # Last element
    second_last = arr[-2]  # Second-to-last element

Enhanced array slicing: Arrays now support more flexible slicing operations within kernels, including stride-based access patterns. This works with both regular arrays and tile operations:

import numpy as np
import warp as wp

@wp.kernel
def tile_load_strided(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # Load every other element from a 16x16 region into an 8x8 tile
    tile = wp.tile_load(input[::2, ::2], shape=(8, 8))
    wp.tile_store(output, tile)

input = wp.array(np.arange(256).reshape(16, 16), dtype=float)
output = wp.zeros((8, 8), dtype=float)
wp.launch_tiled(tile_load_strided, dim=(1,), inputs=[input, output], block_dim=32)

# output contains every other element from input:
# [[  0.   2.   4.   6.   8.  10.  12.  14.]
#  [ 32.  34.  36.  38.  40.  42.  44.  46.]
#  [ 64.  66.  68.  70.  72.  74.  76.  78.]
#  ...
#  [224. 226. 228. 230. 232. 234. 236. 238.]]

New built-in functions

  • Error functions: Added wp.erf(), wp.erfc(), wp.erfinv(), and wp.erfcinv() for error function computations
  • Type casting: Added wp.cast() to reinterpret values as different types while preserving bit patterns (e.g., reinterpreting float bits as int)
  • Atomic bitwise operations: Added wp.atomic_and(), wp.atomic_or(), and wp.atomic_xor() for thread-safe bitwise operations on integers, contributed by @j3soon
  • Sparse matrix utilities: Added wp.sparse.bsr_row_index() and wp.sparse.bsr_block_index() as kernel-level functions to efficiently determine which row a given block belongs to without manually searching through the compressed offset array
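A quick sketch exercising two of the new built-ins inside a kernel (array contents are placeholders):

import warp as wp

@wp.kernel
def builtin_demo(x: wp.array(dtype=float), flags: wp.array(dtype=wp.int32)):
    i = wp.tid()
    # Gauss error function of each element
    x[i] = wp.erf(x[i])
    # Thread-safe bitwise OR on a shared int32 flag word
    wp.atomic_or(flags, 0, 1)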

Bug fixes

AArch64 CPU execution with tiles

Fixed segmentation faults when running tile-based kernels on AArch64 CPUs, affecting platforms including NVIDIA Jetson (Thor, Orin), DGX Spark, Grace Hopper, and Grace Blackwell systems. The fix uses stack memory allocation instead of static memory to work around limitations in LLVM's JIT compiler.

This change is enabled by default on all CPU architectures and can be disabled if needed via wp.config.enable_tiles_in_stack_memory = False. If you encounter issues that are resolved by disabling this setting, please report them on our GitHub Issues page.
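For example, to opt out before initializing Warp:

import warp as wp

# Fall back to the previous static-memory path for tile kernels on CPU
wp.config.enable_tiles_in_stack_memory = False
wp.init()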

Note: This primarily affects CPU execution of tile operations, which is less common in Warp workflows but useful for debugging or scenarios in which GPU memory transfer overhead outweighs compute benefits.

Native library version verification

Warp now performs runtime version checking to detect mismatches between the Python package and native libraries (e.g., warp.dll, warp.so). This helps diagnose issues in which multiple Warp installations on the same system may cause the wrong native libraries to be loaded. When a mismatch is detected, a warning is issued but execution continues. If you see such warnings, ensure you're loading Warp from the expected installation location and that your environment doesn't have conflicting Warp versions.

Announcements

Removal of warp.sim module

The warp.sim module has been removed in this release. This module was formally deprecated in Warp v1.8 (July 2025) and has been superseded by the Newton physics engine, an independent package managed as a Linux Foundation project with a redesigned API focused on robotics and robot learning.

Migration: Users relying on warp.sim should migrate to Newton. For guidance on transitioning from warp.sim to Newton, please consult the Newton migration guide. The original deprecation announcement and community discussion can be found in GitHub Discussion #735.

Questions and discussions about Newton should be directed to the Newton Discussions section. Existing issues in the Warp repository concerning warp.sim will be closed.

JAX FFI is now the default

The default implementation of jax_kernel() is now based on JAX's Foreign Function Interface (FFI), which is required for JAX version 0.8 and newer. Most users should not need to change their code, as the FFI-based version has been available since Warp 1.7 and provides better performance through CUDA graph capture. The previous custom call implementation is still available as wp.jax_experimental.custom_call.jax_kernel() for users on older JAX versions, but it is deprecated and will not work with JAX version 0.8 or later.

Internal code reorganization: _src folder

As part of ongoing efforts to clarify Warp's public API surface, internal implementation code has been reorganized into a warp._src subpackage. This change helps distinguish between public APIs that users should rely on versus internal implementation details that may change without notice.

What this means for users:

  • No immediate breaking changes: All existing imports continue to work. Modules like warp.context, warp.types, and warp.fem remain accessible at their current paths through compatibility shims.
  • Visible in stack traces: You may see warp._src paths in error messages and stack traces (e.g., warp._src.context instead of warp.context).
  • Future direction: In upcoming releases, we plan to define and formalize the public API surface. Once established, public modules will be updated to re-export all designated public symbols, and then compatibility shims will be removed. Code that imports from internal modules will need to be updated to use public APIs or explicitly import from warp._src.* (acknowledging the use of private APIs).

This reorganization is the first step in a multi-phase effort to establish a stable public API. If you encounter any issues introduced by this reorganization, please report them on our GitHub Issues page.

Upcoming removals

The following features will be removed in v1.11 (planned for January 2026):

  • Constructing matrices from row vectors: The ability to construct matrices by passing row vectors to the matrix constructor (e.g., wp.mat22(wp.vec2(1, 2), wp.vec2(3, 4))) will be removed. Use wp.matrix_from_rows() or wp.matrix_from_cols() instead; see the snippet after this list. This deprecation was originally announced in v1.9 with removal planned for v1.10, but has been extended one release cycle: kernel-scope usage has emitted deprecation warnings since v1.9, while Python-scope usage was discovered to lack proper warnings. Starting in v1.10, both contexts emit deprecation warnings.
  • graph_compatible parameter in jax_callable(): The boolean graph_compatible parameter has been deprecated in favor of the new graph_mode parameter which accepts GraphMode enum values (GraphMode.JAX, GraphMode.WARP, or GraphMode.NONE).
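For the matrix-constructor removal, the migration is a one-line change (a sketch; both replacement functions are existing Warp built-ins):

import warp as wp

# Deprecated (removal planned for v1.11):
# m = wp.mat22(wp.vec2(1.0, 2.0), wp.vec2(3.0, 4.0))

# Preferred replacements, explicit about row vs. column orientation:
m_rows = wp.matrix_from_rows(wp.vec2(1.0, 2.0), wp.vec2(3.0, 4.0))
m_cols = wp.matrix_from_cols(wp.vec2(1.0, 2.0), wp.vec2(3.0, 4.0))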

Platform support

  • Python 3.14: Warp now supports Python 3.14, expanding compatibility beyond the previous maximum of Python 3.13.
  • Intel-based macOS (x86_64): Support for Intel Macs has been removed in this release. Apple Silicon Macs (ARM64) continue to be fully supported with CPU execution. Users on Intel-based Macs can continue using Warp 1.9.x or earlier versions.
  • Python 3.8: We plan to drop support for Python 3.8 (end-of-life since 2024-10-07) starting with the next minor release (#1019).

Acknowledgments

We also thank the following contributors from outside the core Warp development team:

  • @j3soon for adding support for atomic bitwise operations
  • @chaserileyroberts for fixing issues with JAX interop on multiple devices
  • @mehdiataei and @jaro-sevcik for adding support for JAX automatic differentiation
  • @liblaf for improving type annotations for struct() and overload() decorators
  • @manuelkNVDA for adding support for multi-process compilation of the core library
  • @boomanaiden154 for contributing fixes to handle upcoming removals in LLVM