{"id":113051,"date":"2026-03-12T10:30:00","date_gmt":"2026-03-12T17:30:00","guid":{"rendered":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer.nvidia.com\/blog\/?p=113051"},"modified":"2026-04-02T11:35:39","modified_gmt":"2026-04-02T18:35:39","slug":"build-accelerated-differentiable-computational-physics-code-for-ai-with-nvidia-warp","status":"publish","type":"post","link":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer.nvidia.com\/blog\/build-accelerated-differentiable-computational-physics-code-for-ai-with-nvidia-warp\/","title":{"rendered":"Build Accelerated, Differentiable Computational Physics Code for AI with NVIDIA Warp"},"content":{"rendered":"\n<p><a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/en-us\/solutions\/cae\/\">Computer-aided engineering (CAE)<\/a> is shifting from human-driven workflows toward AI-driven ones, including physics foundation models that generalize across geometries and operating conditions. Unlike LLMs, which are trained on abundant internet-scale text, these models depend on large volumes of high-fidelity, physics-compliant data that is typically generated by simulation.<\/p>\n\n\n\n<p>Recent <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/arxiv.org\/abs\/2511.20455\">scaling-law work on computational fluid dynamics (CFD) surrogates<\/a> indicates that simulation-generated training data is often the limiting cost in practice. This pushes requirements onto the simulator, which must be GPU-native, fast, and able to plug directly into ML workflows.<\/p>\n\n\n\n<p><a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\">NVIDIA Warp<\/a> is a framework for accelerated simulation, data generation, and spatial computing that bridges CUDA and Python. Warp enables developers to write high-performance kernels as regular Python functions that are JIT-compiled into efficient code for execution on the GPU. 
Unlike tensor-based frameworks, in which developers express computation as operations on entire N-dimensional arrays, Warp lets developers author flexible kernels that execute in parallel across all elements of a computational grid.<\/p>\n\n\n\n<p>Simulation kernels are often expressed on computational grids and rely on data-dependent control flow, such as conditionals, early exits, and selective updates that vary per element. In tensor frameworks, these patterns require composing Boolean masks, which quickly become unwieldy and can waste computation on irrelevant elements. In a Warp kernel, each thread can branch, skip work, or exit independently, so this logic is expressed directly, without masking workarounds.<\/p>\n\n\n\n<p>Furthermore, as this post shows, solvers written in Warp can be made differentiable through Warp's native support for automatic differentiation. They integrate directly with optimization or training workflows and remain interoperable with frameworks such as PyTorch, JAX, and NumPy, for use cases spanning simulation, robotics, perception, and geometry processing.<\/p>\n\n\n\n<p>This post walks through building a 2D Navier\u2013Stokes solver entirely in Warp and explains how the Warp programming model maps onto a PDE solver. It then differentiates through the simulation to solve an optimal perturbation problem end-to-end, and closes with industrial case studies that show what Warp enables in production workflows. 
For more information, see the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/blob\/main\/warp\/examples\/core\/example_fft_poisson_navier_stokes_2d.py\">2D Navier\u2013Stokes solver example<\/a> and <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/blob\/main\/warp\/examples\/optim\/example_navier_stokes_perturbation.py\">2D Navier\u2013Stokes optimal perturbation example<\/a> on the NVIDIA\/warp GitHub repo.<\/p>\n\n\n\n<h2 id=\"how_to_write_a_2d_navier\u2013stokes_solver_using_warp\"  class=\"wp-block-heading\"><strong>How to write a 2D Navier\u2013Stokes solver using Warp<\/strong><a href=\"#how_to_write_a_2d_navier\u2013stokes_solver_using_warp\" aria-label=\"Scroll to How to write a 2D Navier\u2013Stokes solver using Warp section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>To keep the focus on Warp rather than on numerical methods, this post uses a textbook example: 2D decaying turbulence, described by the vorticity\u2013streamfunction formulation of the incompressible Navier\u2013Stokes equations. 
The vorticity \\(\\omega\\) evolves according to the transport equation:<\/p>\n\n\n\n<p class=\"has-text-align-center\">\\(\\frac{\\partial \\omega}{\\partial t} + \\frac{\\partial \\psi}{\\partial y}\\frac{\\partial \\omega}{\\partial x} - \\frac{\\partial \\psi}{\\partial x}\\frac{\\partial \\omega}{\\partial y} = \\frac{1}{\\text{Re}}\\nabla^2 \\omega \\tag{1}\\)<\/p>\n\n\n\n<p>and the streamfunction \\(\\psi\\) is recovered from the vorticity through the Poisson equation:<\/p>\n\n\n\n<p class=\"has-text-align-center\">\\(\\nabla^2 \\psi = -\\omega \\tag{2}\\)<\/p>\n\n\n\n<p>With periodic boundary conditions, Eq. 2 reduces to an algebraic equation in Fourier space, which removes the need for an iterative solver:<\/p>\n\n\n\n<p class=\"has-text-align-center\">\\(\\hat{\\psi}_{m,n} = \\frac{\\hat{\\omega}_{m,n}}{k_x^2 + k_y^2} \\tag{3}\\)<\/p>\n\n\n\n<p>where \\((k_x, k_y)\\) is the wavenumber pair of Fourier mode \\((m, n)\\). Eq. 3 is undefined for the \\(k_x = k_y = 0\\) mode, which only sets the mean of \\(\\psi\\) and is conventionally set to zero. The solver uses the Fast Fourier Transform (FFT) to move \\(\\omega\\) and \\(\\psi\\) between physical space and Fourier space.<\/p>\n\n\n\n<p>Each timestep has two parts (Figure 1). First, the vorticity transport equation is discretized on an \\(N \\times N\\) grid over an \\(L \\times L\\) square domain, and the solution is advanced in time by \\(\\Delta t\\) using a third-order <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/api.drum.lib.umd.edu\/server\/api\/core\/bitstreams\/faa96f40-0fbb-4af0-84a3-5e3b80c908df\/content\">strong stability-preserving Runge\u2013Kutta (SSP-RK3) scheme<\/a> to obtain \\(\\omega(t+\\Delta t)\\). 
Second, the Poisson equation is solved in the Fourier space to obtain the updated \\(\\psi(t+\\Delta t)\\).&nbsp;<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e1480&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e1480\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"1791\" height=\"418\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver.webp\" alt=\"Flowchart of one solver timestep: starting with $\\omega(t)$ and $\\psi(t)$, discretization\/time marching computes $\\omega(t+\\Delta t)$, then a Fourier Poisson solver computes $\\psi(t+\\Delta t)$, which feeds back to the next timestep.\n\" class=\"wp-image-113069\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver.webp 1791w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-179x42.png 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-300x70.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-768x179.png 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-625x146.png 625w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-1536x358.png 1536w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-645x151.png 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-500x117.png 500w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-160x37.png 160w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-362x84.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-471x110.png 471w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-1024x239.png 1024w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/schematic-single-timestep-loop-solver-960x224.png 960w\" sizes=\"auto, (max-width: 1791px) 100vw, 1791px\" \/><figcaption class=\"wp-element-caption\"><em>Figure 1. 
Schematic of a single solver timestep<\/em><\/figcaption><\/figure>\n<\/div>\n\n\n<p>The forward solver therefore has two building blocks, described in the following sections:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>A Warp kernel for the discretization and time marching<\/li>\n\n\n\n<li>An FFT-based Poisson solver<\/li>\n<\/ul>\n\n\n\n<h3 id=\"building_block_1_finite-difference_discretization_and_time_marching\"  class=\"wp-block-heading\"><strong>Building block 1: Finite-difference discretization and time marching<\/strong><a href=\"#building_block_1_finite-difference_discretization_and_time_marching\" aria-label=\"Scroll to Building block 1: Finite-difference discretization and time marching section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The advection and diffusion terms in the vorticity transport equation are approximated with the second-order central finite-difference stencils shown in Figure 2. 
Higher-order schemes could be used instead; the second-order central scheme is chosen here for simplicity.<\/p>\n\n\n\n<figure class=\"wp-block-gallery has-nested-images columns-default is-cropped wp-block-gallery-2 is-layout-flex wp-block-gallery-is-layout-flex\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e209f&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e209f\" class=\"wp-block-image size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"944\" height=\"746\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" data-id=\"113128\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1.webp\" alt=\"Finite-difference stencils for $\\omega$\" class=\"wp-image-113128\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1.webp 944w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-146x115.png 146w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-300x237.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-768x607.png 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-625x494.png 625w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-645x510.png 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-380x300.png 380w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-114x90.png 114w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-362x286.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-139x110.png 139w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image6-1-683x540.png 683w\" sizes=\"auto, (max-width: 944px) 100vw, 944px\" \/><\/figure>\n\n\n\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e29ab&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e29ab\" class=\"wp-block-image size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"944\" height=\"746\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" data-id=\"113129\" 
src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3.webp\" alt=\"Finite-difference stencils for $\\psi$\" class=\"wp-image-113129\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3.webp 944w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-146x115.png 146w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-300x237.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-768x607.png 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-625x494.png 625w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-645x510.png 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-380x300.png 380w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-114x90.png 114w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-362x286.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-139x110.png 139w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/image4-3-683x540.png 683w\" sizes=\"auto, (max-width: 944px) 100vw, 944px\" 
\/><\/figure>\n<figcaption class=\"blocks-gallery-caption wp-element-caption\"><em>Figure 2. Finite-difference stencils for \\(\\omega\\) and \\(\\psi\\)<\/em><\/figcaption><\/figure>\n\n\n\n<p>The following <code>rk3_update()<\/code> kernel computes the diffusion and advection terms and performs a single SSP-RK3 substep. 
The <code>step()<\/code> function launches this kernel three times per timestep, once per SSP-RK3 stage, with stage-specific coefficients (<code>coeff0<\/code>, <code>coeff1<\/code>, <code>coeff2<\/code>). In the standard Shu\u2013Osher form of SSP-RK3, the three stages correspond to (1, 0, 1), (3\/4, 1\/4, 1\/4), and (1\/3, 2\/3, 2\/3).<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport warp as wp\n\n\n@wp.func\ndef cyclic_index(idx: int, n: int) -&gt; int:\n    # Periodic wrap into &#x5B;0, n) for idx &gt;= -n. Minimal stand-in (an assumption);\n    # the full example provides its own cyclic_index helper.\n    return (idx + n) % n\n\n\n@wp.kernel\ndef rk3_update(\n    n: int, h: float, re: float, dt: float,\n    coeff0: float, coeff1: float, coeff2: float,\n    omega_0: wp.array2d(dtype=float),\n    omega_1: wp.array2d(dtype=float),\n    psi: wp.array2d(dtype=float),\n    omega_out: wp.array2d(dtype=float)\n):\n    &quot;&quot;&quot;Perform a single substep of SSP-RK3.&quot;&quot;&quot;\n\n    i, j = wp.tid()  # one thread per grid point (i, j)\n\n    # Periodic neighbor indices\n    left = cyclic_index(i - 1, n)\n    right = cyclic_index(i + 1, n)\n    top = cyclic_index(j + 1, n)\n    down = cyclic_index(j - 1, n)\n\n    # Diffusion: five-point Laplacian of omega\n    inv_h2 = 1.0 \/ (h * h)\n    laplacian = (\n        omega_1&#x5B;right, j] + omega_1&#x5B;left, j] + omega_1&#x5B;i, top] + omega_1&#x5B;i, down] - 4.0 * omega_1&#x5B;i, j]\n    ) * inv_h2\n\n    # Advection: central differences for the two products in Eq. 
1\n    inv_2h = 1.0 \/ (2.0 * h)\n    j1 = ((omega_1&#x5B;right, j] - omega_1&#x5B;left, j]) * inv_2h) * ((psi&#x5B;i, top] - psi&#x5B;i, down]) * inv_2h)  # omega_x * psi_y\n    j2 = ((omega_1&#x5B;i, top] - omega_1&#x5B;i, down]) * inv_2h) * ((psi&#x5B;right, j] - psi&#x5B;left, j]) * inv_2h)  # omega_y * psi_x\n\n    rhs = (1.0 \/ re) * laplacian + j2 - j1\n\n    # Blend the timestep-start state, the current stage, and dt * rhs with the stage coefficients\n    omega_out&#x5B;i, j] = coeff0 * omega_0&#x5B;i, j] + coeff1 * omega_1&#x5B;i, j] + coeff2 * dt * rhs\n<\/pre><\/div>\n\n\n<p>The <code>rk3_update()<\/code> kernel follows the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/docs.nvidia.com\/cuda\/cuda-programming-guide\/01-introduction\/programming-model.html#warps-and-simt\">single-instruction, multiple-threads (SIMT)<\/a> paradigm, in which each thread maps to one grid point of the computational domain, and all \\(N \\times N\\) points are updated in parallel by a single <code>wp.launch()<\/code> call.<\/p>\n\n\n<div 
class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nwp.launch(\n    rk3_update,\n    dim=(self.n, self.n),  # one thread per grid point\n    inputs=&#x5B;\n        self.n, self.h, self.re, self.dt,\n        stage_coeff&#x5B;0], stage_coeff&#x5B;1], stage_coeff&#x5B;2],\n        self.omega_0,\n        self.omega_1,\n        self.psi,\n    ],\n    outputs=&#x5B;self.omega_tmp],\n)\n<\/pre><\/div>\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e3638&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e3638\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"1905\" height=\"795\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid.webp\" alt=\"SIMT update on an $N\\times N$ grid: $N^2$ threads run in parallel, one per cell $(i,j)$; each thread reads the five-point stencil values from timestep $n-1$ and writes the updated $\\omega_{i,j}^n$ for timestep $n$.\n\" class=\"wp-image-113607\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid.webp 1905w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-179x75.png 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-300x125.png 300w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-768x321.png 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-625x261.png 625w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-1536x641.png 1536w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-645x269.png 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-500x209.png 500w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-160x67.png 160w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-362x151.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-264x110.png 264w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-1024x427.png 1024w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/simt-2d-grid-960x401.png 960w\" sizes=\"auto, (max-width: 1905px) 100vw, 1905px\" \/><figcaption class=\"wp-element-caption\"><em>Figure 3. SIMT update of \\(\\omega\\) on the 2D grid. 
Thread (i, j) updates cell (i, j) to the next timestep using values from neighboring cells in the stencil at the current timestep<\/em><\/figcaption><\/figure>\n<\/div>\n\n\n<h3 id=\"building_block_2_fft_poisson_solver\"  class=\"wp-block-heading\"><strong>Building block 2: FFT Poisson solver<\/strong><a href=\"#building_block_2_fft_poisson_solver\" aria-label=\"Scroll to Building block 2: FFT Poisson solver section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p><a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer.nvidia.com\/blog\/introducing-tile-based-programming-in-warp-1-5-0\/\">Warp tile-based primitives<\/a> make it possible to solve the Poisson equation in Fourier space directly inside Warp kernels. The key operations are <code>wp.tile_fft()<\/code> and <code>wp.tile_ifft()<\/code>, which compute the forward and inverse FFT of a tile in place; here, each tile holds one row of the grid. A full 2D FFT on an \\(N \\times N\\) array is then decomposed into three steps: row-wise FFT -&gt; transpose -&gt; row-wise FFT. 
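<\/p>\n\n\n\n<p>You can check this decomposition on the CPU before writing any Warp code. The following NumPy sketch is not part of the Warp example, and its function names are hypothetical. It assumes a \\(2\\pi\\)-periodic square domain and mirrors the row-wise FFT, transpose, row-wise FFT sequence with <code>np.fft.fft<\/code> along rows. It then confirms that the composition returns the standard 2D FFT in transposed order, and that applying Eq. 3 in that layout recovers a manufactured streamfunction.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nimport numpy as np\n\n\ndef fft2_rows_transpose_rows(x):\n    &quot;&quot;&quot;Row-wise FFT -&gt; transpose -&gt; row-wise FFT (hypothetical NumPy mirror of the Warp kernels).&quot;&quot;&quot;\n    y = np.fft.fft(x, axis=1)  # FFT along each row\n    return np.fft.fft(y.T, axis=1)  # transpose, then FFT along each row again\n\n\ndef ifft2_rows_transpose_rows(x):\n    &quot;&quot;&quot;Inverse counterpart: row-wise IFFT -&gt; transpose -&gt; row-wise IFFT.&quot;&quot;&quot;\n    y = np.fft.ifft(x, axis=1)\n    return np.fft.ifft(y.T, axis=1)\n\n\nn, length = 64, 2.0 * np.pi  # assumed grid size and 2*pi-periodic domain length\nrng = np.random.default_rng(0)\nx = rng.standard_normal((n, n))\n\n# The composition equals the standard 2D FFT, stored in transposed order.\nassert np.allclose(fft2_rows_transpose_rows(x), np.fft.fft2(x).T)\n\n# Manufactured solution psi = sin(x) cos(2y), so omega = -lap(psi) = 5 psi (Eq. 
2).\ngrid = np.arange(n) * (length \/ n)\nxx, yy = np.meshgrid(grid, grid, indexing=&quot;ij&quot;)\npsi_exact = np.sin(xx) * np.cos(2.0 * yy)\nomega = 5.0 * psi_exact\n\n# Angular wavenumbers; k2 = kx^2 + ky^2 is symmetric because the grid is square,\n# so dividing in the transposed layout gives the same result.\nk = 2.0 * np.pi * np.fft.fftfreq(n, d=length \/ n)\nkx, ky = np.meshgrid(k, k, indexing=&quot;ij&quot;)\nk2 = kx**2 + ky**2\nk2&#x5B;0, 0] = 1.0  # placeholder so the mean mode does not divide by zero\n\npsi_hat = fft2_rows_transpose_rows(omega) \/ k2  # Eq. 3 in the transposed layout\npsi_hat&#x5B;0, 0] = 0.0  # set the undefined mean of psi to zero\npsi = ifft2_rows_transpose_rows(psi_hat).real  # inverse pass restores the original layout\n\nassert np.allclose(psi, psi_exact)\n<\/pre><\/div>\n\n\n<p>Replacing the zero-wavenumber divisor with 1 and then zeroing that mode is one common way to handle the undefined mean of \\(\\psi\\). This choice is an assumption of the sketch and does not necessarily describe the Warp example.<\/p>\n\n\n\n<p>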
Figure 4 shows how <code>fft_tiled()<\/code> and <code>ifft_tiled()<\/code> compute the forward and inverse FFT row by row. In both kernels, <code>N_GRID<\/code> is a global constant equal to \\(N\\), because tile shapes must be known at compile time.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n@wp.kernel\ndef fft_tiled(x: wp.array2d(dtype=wp.vec2f), y: wp.array2d(dtype=wp.vec2f)):\n    &quot;&quot;&quot;Row-wise FFT using tile primitives.&quot;&quot;&quot;\n    i, _, _ = wp.tid()  # i selects the row handled by this thread block\n    # Load row i (N_GRID complex values stored as vec2f) into a tile\n    a = wp.tile_load(x, shape=(1, N_GRID), offset=(i, 0))\n    wp.tile_fft(a)  # in-place FFT, computed cooperatively by the block\n    wp.tile_store(y, a, offset=(i, 0))\n<\/pre><\/div>\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n@wp.kernel\ndef ifft_tiled(x: wp.array2d(dtype=wp.vec2f), y: wp.array2d(dtype=wp.vec2f)):\n    &quot;&quot;&quot;Row-wise inverse FFT using tile primitives.&quot;&quot;&quot;\n    i, _, _ = wp.tid()  # i selects the row handled by this thread block\n    a = wp.tile_load(x, shape=(1, N_GRID), offset=(i, 0))\n    wp.tile_ifft(a)  # in-place inverse FFT\n    wp.tile_store(y, a, offset=(i, 0))\n<\/pre><\/div>\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e41f0&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e41f0\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"1862\" height=\"1033\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" 
src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid.webp\" alt=\"Row-wise GPU FFT on an $N\\times N$ grid: one thread block per row loads the row into a register tile, performs an in-place FFT cooperatively, and stores the result to a new array in the frequency domain.\n\" 
class=\"wp-image-113088\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid.webp 1862w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-179x99.png 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-300x166.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-768x426.png 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-625x347.png 625w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-1536x852.png 1536w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-645x358.png 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-500x277.png 500w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-160x90.png 160w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-362x201.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-198x110.png 198w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-1024x568.png 1024w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/spatial-spectral-row-wise-tile-nxn-grid-960x533.png 960w\" sizes=\"auto, (max-width: 1862px) 100vw, 1862px\" \/><figcaption class=\"wp-element-caption\"><em>Figure 4. Row-wise <code>tile_fft<\/code> on an \\(N \\times N\\) grid. Each block loads one row into a register tile, computes the FFT cooperatively, and stores the result back to global memory<\/em><\/figcaption><\/figure>\n<\/div>\n\n\n<p>A 2D FFT also requires a transpose between the row-wise passes. The transpose can be written in either the SIMT or the tile paradigm (the latter through <code>wp.tile_transpose<\/code>). For simplicity, the SIMT version is shown below:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n@wp.kernel\ndef transpose(x: wp.array2d(dtype=wp.vec2f), y: wp.array2d(dtype=wp.vec2f)):\n    &quot;&quot;&quot;SIMT transpose: each thread copies one complex value.&quot;&quot;&quot;\n    i, j = wp.tid()\n    y&#x5B;i, j] = x&#x5B;j, i]\n<\/pre><\/div>\n\n\n<p>Composing these kernels as <code>fft_tiled<\/code> -&gt; <code>transpose<\/code> -&gt; <code>fft_tiled<\/code> gives a full 2D forward FFT, stored in transposed order. Because \\(k_x^2 + k_y^2\\) is symmetric on a square grid, Eq. 3 can be applied directly in that layout. 
The inverse follows the same pattern with <code>ifft_tiled<\/code>, and its transpose step restores the original orientation.<\/p>\n\n\n\n<h3 id=\"putting_the_building_blocks_together\"  class=\"wp-block-heading\"><strong>Putting the building blocks together<\/strong><a href=\"#putting_the_building_blocks_together\" aria-label=\"Scroll to Putting the building blocks together section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The <code>step()<\/code> function in the example relies on a few other helper kernels that are not discussed here. For their definitions, see the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/blob\/main\/warp\/examples\/core\/example_fft_poisson_navier_stokes_2d.py\">2D Navier\u2013Stokes solver example<\/a> on the NVIDIA\/warp GitHub repo. With all the building blocks in place, a single <code>step()<\/code> call advances the simulation by one timestep. For modularity, the <code>self._solve_poisson()<\/code> method in the example encapsulates the \\(\\omega(t+\\Delta t) \\xrightarrow{\\text{FFT}} \\hat{\\omega} \\xrightarrow{\\text{Eq.\\,3}} \\hat{\\psi} \\xrightarrow{\\text{IFFT}} \\psi(t+\\Delta t)\\) pipeline.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n    def step(self) -&gt; None:\n        &quot;&quot;&quot;Advance simulation by one timestep using SSP-RK3.&quot;&quot;&quot;\n        for stage_coeff in self.rk3_coeffs:\n            wp.launch(\n                rk3_update,\n                dim=(self.n, self.n),\n                inputs=&#x5B;\n                    self.n, self.h, self.re, self.dt,\n                    stage_coeff&#x5B;0], stage_coeff&#x5B;1], stage_coeff&#x5B;2],\n                    self.omega_0,\n                    self.omega_1,\n                    self.psi,\n                ],\n                outputs=&#x5B;self.omega_tmp],\n            )\n            # Swap buffers for next RK3 substep\n            
self.omega_1, self.omega_tmp = self.omega_tmp, self.omega_1\n\n            # Update the streamfunction from the latest stage vorticity before the next RK3 substep\n            self._solve_poisson()\n        \n        # Copy updated vorticity to self.omega_0 for the next timestep\n        wp.copy(self.omega_0, self.omega_1)\n<\/pre><\/div>\n\n\n<p>Running the solver produces the decaying turbulence field shown in Figure 5. On the GPU, the <code>step()<\/code> function is captured into a <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/docs.nvidia.com\/cuda\/cuda-programming-guide\/04-special-topics\/cuda-graphs.html\">CUDA Graph<\/a> through <code>wp.ScopedCapture<\/code> and replayed with <code>wp.capture_launch()<\/code> for all subsequent frames, greatly reducing per-launch overhead.<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e4fb9&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e4fb9\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"400\" height=\"400\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif.gif\" alt=\"Pseudocolor GIF of two-dimensional decaying turbulence at $\\mathrm{Re}=1000$, showing intertwined vortical filaments and eddy structures across the domain.\" class=\"wp-image-113151\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif.gif 400w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-115x115.gif 115w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-300x300.gif 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-90x90.gif 90w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-32x32.gif 32w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-50x50.gif 50w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-64x64.gif 64w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-96x96.gif 96w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-128x128.gif 128w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-150x150.gif 150w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-362x362.gif 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/2d-decaying-turbulence-gif-110x110.gif 110w\" sizes=\"auto, (max-width: 400px) 100vw, 400px\" \/><figcaption class=\"wp-element-caption\"><em><em>Figure 5. Two-dimensional decaying turbulence at Re = 1,000<\/em><\/em><\/figcaption><\/figure>\n<\/div>\n\n\n<h2 id=\"differentiating_through_the_solver\"  class=\"wp-block-heading\"><strong>Differentiating through the solver<\/strong><a href=\"#differentiating_through_the_solver\" aria-label=\"Scroll to Differentiating through the solver section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>With a working solver in place, the next step is to make it differentiable.<\/p>\n\n\n\n<p><em>Automatic differentiation<\/em> (AD) computes exact derivatives of a program by applying the chain rule to each elementary operation in the computational graph. Unlike finite differences, AD needs no step-size tuning and yields gradients accurate to machine precision. For PDE solvers, the key advantage of reverse-mode AD is scaling. On a large grid, each forward solve is already expensive, and finite differences need \\(O(n)\\) full solves to obtain gradients with respect to \\(n\\) inputs.<\/p>\n\n\n\n<p>Reverse-mode AD computes all \\(\\partial \\mathcal{L}\/\\partial x_i\\) in roughly one forward pass plus one backward pass, making gradient-based optimization practical at production resolution. 
This is the same idea as backpropagation in neural networks, and it is why both deep learning and large-scale physics optimization can handle millions of degrees of freedom.<\/p>\n\n\n\n<p>The <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/nvidia.github.io\/warp\/user_guide\/differentiability.html\">Warp automatic differentiation<\/a> system generates two versions of each kernel at compile time, which together make a simulation differentiable:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>Forward version<\/strong>: The code that takes physical inputs (initial conditions, discretized governing laws, and so on) and computes the simulation output (fields, derived quantities), along with the intermediate arrays the adjoint version needs.<\/li>\n\n\n\n<li><strong>Adjoint version<\/strong>: An automatically generated counterpart to the forward code. It takes the sensitivities of a chosen quantity of interest with respect to the simulation outputs and propagates them all the way back to the inputs. This backward propagation reuses intermediate arrays from the forward execution to apply the chain rule across the entire solver, yielding the simulation adjoint without constructing large symbolic expressions.<\/li>\n<\/ul>\n\n\n\n<p>Developers write the forward physics, and Warp handles the gradient computation. Any <code>wp.array<\/code> that should be differentiable is allocated with <code>requires_grad=True<\/code>, which tells Warp to allocate a companion <code>.grad<\/code> array of the same shape for adjoint storage. The resulting adjoints can be used standalone (as in this example) or passed to PyTorch or JAX for end-to-end optimization, including training ML models. 
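<\/p>

<p>To make the forward/adjoint split concrete, the following plain-Python sketch hand-writes the adjoint of a toy explicit solver. It is not Warp's generated code, and the 1D periodic reaction-diffusion model and function names are hypothetical. The forward pass stores every state. The backward pass walks those states in reverse and reuses each one for the nonlinear term, producing the gradient with respect to all n initial values in a single sweep. A central finite-difference check needs 2n extra forward solves to produce the same result.<\/p>

```python
import random


def forward(u0, steps, dt):
    # Toy explicit solver on a periodic grid with unit spacing:
    #   u_new[i] = u[i] + dt * (u[i-1] - 2*u[i] + u[i+1] - u[i]**3)
    # Every state is kept, as a differentiable solver must do.
    n = len(u0)
    states = [list(u0)]
    for _ in range(steps):
        u = states[-1]
        states.append([u[i] + dt * (u[i - 1] - 2.0 * u[i] + u[(i + 1) % n] - u[i] ** 3)
                       for i in range(n)])
    return states


def loss(u):
    # Scalar quantity of interest: 0.5 * ||u_T||^2.
    return 0.5 * sum(v * v for v in u)


def backward(states, dt):
    # Adjoint sweep: start from dL/du_T = u_T and apply the transposed Jacobian
    # of each step in reverse, reusing the stored state u for the u**3 term.
    n = len(states[0])
    adj = list(states[-1])
    for u in reversed(states[:-1]):
        prev = [0.0] * n
        for i in range(n):
            prev[i] += adj[i] * (1.0 - dt * (2.0 + 3.0 * u[i] ** 2))
            prev[i - 1] += adj[i] * dt
            prev[(i + 1) % n] += adj[i] * dt
        adj = prev
    return adj


def finite_difference_grad(u0, steps, dt, eps=1e-5):
    # Central differences: two full forward solves per input.
    grad = []
    for i in range(len(u0)):
        up, down = list(u0), list(u0)
        up[i] += eps
        down[i] -= eps
        grad.append((loss(forward(up, steps, dt)[-1])
                     - loss(forward(down, steps, dt)[-1])) / (2.0 * eps))
    return grad


random.seed(1)
u0 = [random.uniform(-1.0, 1.0) for _ in range(8)]
states = forward(u0, steps=20, dt=0.1)
g_adjoint = backward(states, dt=0.1)                 # one forward + one backward sweep
g_fd = finite_difference_grad(u0, steps=20, dt=0.1)  # 16 extra forward solves
print(max(abs(a - b) for a, b in zip(g_adjoint, g_fd)))
```

<p>The adjoint sweep costs about as much as one forward solve, however many inputs there are. The finite-difference cost, by contrast, grows with n. Warp generates the per-kernel equivalent of <code>backward<\/code> automatically; the hand-written version here only makes that structure visible.<\/p>

<p>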
Currently, Warp supports reverse-mode AD only.<\/p>\n\n\n\n<p>To illustrate, this post tackles the optimal perturbation problem outlined in <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.cambridge.org\/core\/journals\/journal-of-fluid-mechanics\/article\/prediction-and-control-of-twodimensional-decaying-turbulence-using-generative-adversarial-networks\/EAE377A4E1F784D135520DB1EE3F25A0\">Prediction and Control of Two-Dimensional Decaying Turbulence Using Generative Adversarial Networks<\/a>. In turbulent flows, small perturbations to the initial conditions can amplify over time and significantly alter the trajectory of the flow. Identifying which perturbations grow fastest is a stepping stone toward flow control and toward understanding which structures in the flow are dynamically significant. Concretely, the goal is to find the initial vorticity perturbation \\(\\Delta\\omega\\) that maximizes the divergence between the perturbed and unperturbed trajectories at a lead time \\(\\tau\\).<\/p>\n\n\n\n<p>Let \\(F^{\\tau}\\) denote the forward solver applied for \\(\\tau\\) time units. The unperturbed trajectory is \\(Y^{*} = F^{\\tau}(\\omega_0)\\), and the perturbed trajectory is \\(\\tilde{Y} = F^{\\tau}(\\omega_0 + \\Delta\\omega)\\). The loss, defined as the negative mean squared error (MSE) between the two fields over the \\(N \\times N\\) grid,<\/p>\n\n\n\n<p class=\"has-text-align-center\">\\(\\mathcal{L} = -\\mathrm{MSE}(Y^*, \\tilde{Y}) = -\\frac{1}{N^2}\\left\\| Y^* - \\tilde{Y} \\right\\|_2^2 \\tag{4}\\)<\/p>\n\n\n\n<p>is minimized, where the negative sign turns maximization of trajectory divergence into a minimization problem. 
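<\/p>

<p>A minimal plain-Python sketch of Eq. 4 follows, using nested lists for the N x N fields. It also includes one simple way to enforce the 20% RMS bound introduced just below: rescale the perturbation whenever it leaves the allowed ball. That projection step, and the names <code>negative_mse<\/code> and <code>project_rms<\/code>, are assumptions for illustration. The linked example may enforce the constraint differently.<\/p>

```python
import math


def negative_mse(y_ref, y_pert):
    # Eq. 4: L = -(1/N^2) * sum over the N x N grid of (Y*_ij - Y~_ij)^2
    n = len(y_ref)
    total = sum((a - b) ** 2
                for row_a, row_b in zip(y_ref, y_pert)
                for a, b in zip(row_a, row_b))
    return -total / (n * n)


def rms(field):
    # Root-mean-square over all grid points.
    count = sum(len(row) for row in field)
    return math.sqrt(sum(v * v for row in field for v in row) / count)


def project_rms(delta, omega0, ratio=0.2):
    # Hypothetical projection: scale delta down (never up) so that
    # rms(delta) stays within ratio * rms(omega0).
    current = rms(delta)
    if current == 0.0:
        return delta
    scale = min(1.0, ratio * rms(omega0) / current)
    return [[scale * v for v in row] for row in delta]


omega0 = [[math.sin(i + 2 * j) for j in range(4)] for i in range(4)]
delta = project_rms([[1.0] * 4 for _ in range(4)], omega0)
print(rms(delta) / rms(omega0))  # 0.2, up to rounding
print(negative_mse([[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]))  # -1.0
```

<p>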
To constrain the optimization, the perturbation must satisfy \\(\\mathrm{rms}(\\Delta\\omega) \\leq 0.2 \\times \\mathrm{rms}(\\omega_0)\\); that is, its RMS must not exceed 20% of the RMS of the initial vorticity field \\(\\omega_0\\).<\/p>\n\n\n\n<p>For more details, see the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/blob\/main\/warp\/examples\/optim\/example_navier_stokes_perturbation.py\">2D Navier-Stokes optimal perturbation example<\/a> on the NVIDIA\/warp GitHub repo. The following sections focus on the three key changes that make the forward solver differentiable.<\/p>\n\n\n\n<h3 id=\"no_in-place_modifications\"  class=\"wp-block-heading\"><strong>No in-place modifications<\/strong><a href=\"#no_in-place_modifications\" aria-label=\"Scroll to No in-place modifications section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p><code>wp.Tape()<\/code> records kernel launches in the forward pass and replays them in reverse to compute gradients. That only works if the intermediate values needed by the backward pass are still available, so arrays cannot be freely overwritten in place. This is the key difference from the nondifferentiable solver. In the forward-only version, the updated vorticity in <code>omega_1<\/code> could simply be copied back into <code>omega_0<\/code> at the end of each timestep, overwriting the previous state:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Forward-only solver: overwrite the previous state in place\nwp.copy(omega_0, omega_1)\n<\/pre><\/div>\n\n\n<p>For the differentiable solver, the RHS computation and the RK3 update are split into separate kernels that write to separate arrays, so a single RK3 update becomes the following. 
Note that <code>omega_1<\/code> can no longer be copied into <code>omega_0<\/code> at the end of each timestep as before.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\n# Body of the RK3 update kernel: inputs are read-only, and the result goes to a separate omega_out\nomega_out&#x5B;i, j] = coeff0 * omega_0&#x5B;i, j] + coeff1 * omega_in&#x5B;i, j] + coeff2 * dt * rhs&#x5B;i, j]\n<\/pre><\/div>\n\n\n<p>In Warp, the user must explicitly allocate every intermediate array. This means pre-allocating separate arrays for every RK substep of every timestep, which is typically the dominant GPU memory cost of a differentiable solver.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nself.omega_timestep = &#x5B;wp.zeros((n, n), dtype=wp.float32, requires_grad=True) for _ in range(T + 1)]\n\n# Intermediate arrays for each RK3 substep for each timestep\nself.omega_stage = &#x5B;]\nself.psi_stage = &#x5B;]\nself.rhs_stage = &#x5B;]\nself.fft_arrays = &#x5B;]\n\nfor _ in range(T):\n    s_omega, s_psi, s_rhs, s_fft = &#x5B;], &#x5B;], &#x5B;], &#x5B;]\n    for _ in range(3):\n        s_omega.append(wp.zeros((n, n), dtype=wp.float32, requires_grad=True))\n        s_psi.append(wp.zeros((n, n), dtype=wp.float32, requires_grad=True))\n        s_rhs.append(wp.zeros((n, n), dtype=wp.float32, requires_grad=True))\n        s_fft.append({&quot;omega_complex&quot;: wp.zeros((n, n), dtype=wp.vec2f, requires_grad=True),\n                      # ... plus 4 FFT scratch arrays, each (n, n) vec2f\n                      })\n    self.omega_stage.append(s_omega)\n    self.psi_stage.append(s_psi)\n    self.rhs_stage.append(s_rhs)\n    self.fft_arrays.append(s_fft)\n<\/pre><\/div>\n\n\n<p>The memory needed to store every intermediate state grows linearly with the number of timesteps, which becomes prohibitive for long runs. 
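<\/p>

<p>A back-of-the-envelope estimate shows how quickly this adds up. The sketch below counts only the arrays allocated in the code above, under these assumptions:<\/p>
<ul class=\"wp-block-list\">
<li>4 bytes per <code>wp.float32<\/code> element and 8 bytes per <code>wp.vec2f<\/code> element.<\/li>
<li>Five <code>vec2f<\/code> arrays per FFT dictionary: <code>omega_complex<\/code> plus the four scratch arrays.<\/li>
<li>A same-sized <code>.grad<\/code> companion for every array created with <code>requires_grad=True<\/code>.<\/li>
<\/ul>
<p>Because it ignores every other allocation, treat the result as a lower bound. The function name and the grid and step counts are hypothetical.<\/p>

```python
def stored_state_bytes(n, T, rk_stages=3, fft_vec2f_arrays=5):
    # Bytes per grid cell for one RK stage: omega, psi, rhs as float32 (4 B each)
    # plus the FFT dictionary's vec2f arrays (8 B each).
    per_stage = 3 * 4 + fft_vec2f_arrays * 8
    # All stages of all timesteps, plus the T + 1 float32 omega_timestep snapshots.
    per_cell = T * rk_stages * per_stage + (T + 1) * 4
    # requires_grad=True allocates a .grad array of the same size: double it.
    return 2 * per_cell * n * n


for n, T in [(256, 100), (512, 1000)]:
    gib = stored_state_bytes(n, T) / 2**30
    print(f'n={n}, T={T}: {gib:.1f} GiB')
```

<p>For a hypothetical 512 x 512 grid and 1,000 timesteps, these arrays alone come to roughly 78 GiB (about 84 GB).<\/p>

<p>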
One common remedy is <em>gradient checkpointing<\/em>: save only selected states, then recompute the missing segments with the forward solver during the backward pass. This trades extra forward compute for a much smaller memory footprint. For an example showing how to implement gradient checkpointing in Warp, see the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/blob\/main\/warp\/examples\/optim\/example_fluid_checkpoint.py\">fluid checkpoint example<\/a> on the NVIDIA\/warp GitHub repo.<\/p>\n\n\n\n<h3 id=\"recording_gradients_with_wptape\"  class=\"wp-block-heading\"><strong>Recording gradients with <code>wp.Tape()<\/code><\/strong><a href=\"#recording_gradients_with_wptape\" aria-label=\"Scroll to Recording gradients with wp.Tape() section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>With the pre-allocated arrays in place, recording and differentiating the forward pass is straightforward:<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nwith wp.Tape() as tape:\n    forward()  # wp.launch calls that advance omega from t0 to t0 + tau and compute the loss (Eq. 4)\ntape.backward(loss)  # Reverse-mode AD: derivatives of loss w.r.t. the recorded Warp arrays\n<\/pre><\/div>\n\n\n<p>The <code>wp.Tape()<\/code> context records every <code>wp.launch()<\/code> call made inside it. <code>tape.backward(loss)<\/code> replays those launches in reverse, computing the derivatives of <code>loss<\/code> (a single-element Warp array) with respect to the Warp arrays allocated with <code>requires_grad=True<\/code>. 
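<\/p>

<p>The record-and-replay idea behind <code>wp.Tape()<\/code> can be illustrated with a toy tape in plain Python. This is not Warp's implementation or API. <code>ToyTape<\/code>, the dictionary-of-lists arrays, and the two hand-written forward/adjoint pairs are all hypothetical. Each launch runs its forward function and records its adjoint. <code>backward<\/code> then replays the adjoints in reverse order, and <code>zero<\/code> clears the accumulated gradients.<\/p>

```python
class ToyTape:
    # Toy record-and-replay tape; names and structure are illustrative only.
    def __init__(self, arrays):
        # Arrays are pre-allocated; each gets a same-sized gradient buffer.
        self.arrays = arrays
        self.grads = {k: [0.0] * len(v) for k, v in arrays.items()}
        self.adjoints = []

    def launch(self, forward_fn, adjoint_fn):
        # Run the forward 'kernel' now and remember its adjoint for later.
        forward_fn(self.arrays)
        self.adjoints.append(adjoint_fn)

    def backward(self, loss_name):
        # Seed d(loss)/d(loss) = 1, then replay the adjoints in reverse order.
        self.grads[loss_name][0] = 1.0
        for adjoint_fn in reversed(self.adjoints):
            adjoint_fn(self.arrays, self.grads)

    def zero(self):
        for g in self.grads.values():
            g[:] = [0.0] * len(g)


def scale_fwd(a):
    # y = 3 * x
    a['y'][:] = [3.0 * v for v in a['x']]


def scale_adj(a, g):
    # dL/dx += 3 * dL/dy
    for i, gy in enumerate(g['y']):
        g['x'][i] += 3.0 * gy


def sumsq_fwd(a):
    # loss = sum(y**2)
    a['loss'][0] = sum(v * v for v in a['y'])


def sumsq_adj(a, g):
    # dL/dy += 2 * y * dL/dloss, reading the y values stored by the forward pass
    for i, v in enumerate(a['y']):
        g['y'][i] += 2.0 * v * g['loss'][0]


tape = ToyTape({'x': [1.0, -2.0, 0.5], 'y': [0.0] * 3, 'loss': [0.0]})
tape.launch(scale_fwd, scale_adj)
tape.launch(sumsq_fwd, sumsq_adj)
tape.backward('loss')
print(tape.grads['x'])  # [18.0, -36.0, 9.0], i.e. 18 * x
tape.zero()
```

<p><code>sumsq_adj<\/code> reads the stored values of <code>y<\/code>. Overwriting <code>y<\/code> in place after it was recorded would therefore silently corrupt the gradient, for the same reason that the differentiable solver avoids in-place updates.<\/p>

<p>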
Here the quantity of interest is the gradient of <code>loss<\/code> with respect to \\(\\Delta{\\omega}\\), which is stored in <code>delta_omega.grad<\/code>.<\/p>\n\n\n\n<h3 id=\"optimization_loop\"  class=\"wp-block-heading\"><strong>Optimization loop<\/strong><a href=\"#optimization_loop\" aria-label=\"Scroll to Optimization loop section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The following code block shows one optimization step. The <code>forward()<\/code> function runs on the perturbed initial vorticity to produce the final field and the loss (the negative MSE against the unperturbed run, Eq. 4). The tape records the kernel launches during this pass. <code>tape.backward(loss)<\/code> then backpropagates through the recorded launches to compute gradients with respect to the perturbation, and <code>optimizer.step()<\/code> updates the perturbation to reduce the loss. Finally, <code>tape.zero()<\/code> clears the accumulated gradients before the next iteration.<\/p>\n\n\n<div class=\"wp-block-syntaxhighlighter-code \"><pre class=\"brush: python; title: ; notranslate\" title=\"\">\nwith wp.Tape() as tape:\n    forward()  # Loss is computed inside forward()\n\ntape.backward(loss)  # Fills delta_omega.grad\noptimizer.step(&#x5B;delta_omega.grad.flatten()])  # Optimizer update using the flattened gradient\ntape.zero()  # Reset gradients before the next iteration\n<\/pre><\/div>\n\n\n<p>After 1,000 iterations, the optimizer discovers a structured perturbation \\(\\Delta\\omega\\) that amplifies trajectory divergence, driving the MSE from near-zero to ~250. 
The perturbation field obtained from the solver-in-the-loop optimization qualitatively resembles the one reported in <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.cambridge.org\/core\/journals\/journal-of-fluid-mechanics\/article\/prediction-and-control-of-twodimensional-decaying-turbulence-using-generative-adversarial-networks\/EAE377A4E1F784D135520DB1EE3F25A0\">Prediction and Control of Two-Dimensional Decaying Turbulence Using Generative Adversarial Networks<\/a>.<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e6471&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e6471\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"2000\" height=\"1010\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1.gif\" alt=\"Optimization GIF for 1,000 iterations. 
$\\mathrm{MSE}(Y^*,\\tilde{Y})$ decreases over iterations (left), with field snapshots showing the baseline $\\omega_0$ (top center), learned perturbation $\\Delta\\omega$ (top right), target $Y^*$ (bottom center), and optimized output $\\tilde{Y}$ (bottom right) on a $[0,2\\pi]\\times[0,2\\pi]$ domain.\" class=\"wp-image-113117\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1.gif 2000w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-179x90.gif 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-300x152.gif 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-768x388.gif 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-625x316.gif 625w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-1536x776.gif 1536w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-645x326.gif 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-500x253.gif 500w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-160x81.gif 160w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-362x183.gif 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-218x110.gif 218w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-1024x517.gif 1024w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/optimization-progressing-1000-iterations-discovered-perturbation-1-960x485.gif 960w\" sizes=\"auto, (max-width: 2000px) 100vw, 2000px\" \/><figcaption class=\"wp-element-caption\"><em><em>Figure 6. 
Optimization progress over 1,000 iterations, with the discovered perturbation (top right)<\/em><\/em><\/figcaption><\/figure>\n<\/div>\n\n\n<p>To learn more, explore the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/tree\/main\/warp\/examples\/optim\">additional differentiable-solver examples<\/a> beyond CFD in the NVIDIA\/warp GitHub repo. See also the growing list of <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/nvidia.github.io\/warp\/user_guide\/publications.html\">research publications that leverage Warp<\/a>.<\/p>\n\n\n\n<h2 id=\"warp_in_practice_case_studies_of_ai-driven_industrial_workflows\"  class=\"wp-block-heading\"><strong>Warp in practice: Case studies of AI-driven industrial workflows<\/strong><a href=\"#warp_in_practice_case_studies_of_ai-driven_industrial_workflows\" aria-label=\"Scroll to Warp in practice: Case studies of AI-driven industrial workflows section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>In real AI workflows, simulation and geometry sit inside larger systems (surrogate models, RL, design optimization, and so on). PyTorch and JAX handle training and tensor ops, but the simulation layer adds staged timestepping, stencil updates, and large spatial queries. Warp targets that kernel-heavy layer: developers control execution, fuse kernels to cut memory traffic and launch counts, and use CUDA Graphs to reduce repeated dispatch overhead. 
It also interoperates zero-copy with PyTorch and JAX tensors.<\/p>\n\n\n\n<h3 id=\"autodesk_xlb_\"  class=\"wp-block-heading\">Autodesk XLB <a href=\"#autodesk_xlb_\" aria-label=\"Scroll to Autodesk XLB  section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Autodesk Research built <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.research.autodesk.com\/publications\/xlb-differentiable-massively-parallel-lattice-boltzmann-library-python\/\">XLB<\/a>, a differentiable Lattice Boltzmann solver in Python with both Warp and JAX backends, enabling a direct comparison on the same formulation and hardware. On a ~134-million-cell lid-driven cavity benchmark, Warp ran about 8x faster than JAX on a single 40 GB <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/en-us\/data-center\/a100\/\">NVIDIA A100 Tensor Core GPU<\/a>, roughly matching the throughput that JAX needed eight A100 Tensor Core GPUs to reach. At larger sizes, Warp used ~2.5x\u20133x less memory and completed the largest case, on which JAX ran out of memory on the same GPU.<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e6fcd&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e6fcd\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"1999\" height=\"854\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax.webp\" alt=\"Two bar charts comparing Warp and JAX on NVIDIA A100. 
Left: throughput in MLUPS: Warp single-GPU (8879) versus JAX single-GPU (1139) and JAX 8-GPU (8397). Right: memory usage at 128^3, 256^3, and  512^3 domain sizes \u2014 JAX OOMs at 512^3 while Warp fits in 24 GB.\n\" class=\"wp-image-113120\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax.webp 1999w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-179x76.png 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-300x128.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-768x328.png 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-625x267.png 625w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-1536x656.png 1536w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-645x276.png 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-500x214.png 500w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-160x68.png 160w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-362x155.png 362w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-257x110.png 257w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-1024x437.png 1024w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/throughput-memory-usage-comparison-warp-jax-960x410.png 960w\" sizes=\"auto, (max-width: 1999px) 100vw, 1999px\" \/><figcaption class=\"wp-element-caption\"><em>Figure 7. 
Throughput and memory usage comparison between Warp and JAX<\/em><\/figcaption><\/figure>\n<\/div>\n\n\n<p>To learn more, see <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer.nvidia.com\/blog\/autodesk-research-brings-warp-speed-to-computational-fluid-dynamics-on-nvidia-gh200\/?ncid=so-link-133860\">Autodesk Research Brings Warp Speed to Computational Fluid Dynamics on NVIDIA GH200<\/a>.<\/p>\n\n\n\n<h3 id=\"google_deepmind_mujoco\"  class=\"wp-block-heading\">Google DeepMind MuJoCo<a href=\"#google_deepmind_mujoco\" aria-label=\"Scroll to Google DeepMind MuJoCo section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>Google DeepMind recently released MuJoCo Warp (MJWarp), a Warp-based backend for large-scale multibody dynamics. MJWarp reaches speedups of up to 252x (locomotion) and 475x (manipulation) over MuJoCo MJX, the JAX-based backend, on comparable hardware. It gets there by exploiting sparse matrix operations and speculative execution to dispatch compute more precisely, while remaining plug-compatible with JAX training.<\/p>\n\n\n\n<figure class=\"wp-block-gallery has-nested-images columns-default is-cropped wp-block-gallery-3 is-layout-flex wp-block-gallery-is-layout-flex\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e7a0f&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e7a0f\" class=\"wp-block-image size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"600\" height=\"371\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" data-id=\"113573\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step.webp\" alt=\"MJWarp physics step throughput 
versus MuJoCo MJX on LEAP benchmarks.\" class=\"wp-image-113573\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step.webp 600w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step-179x111.png 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step-300x186.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step-485x300.png 485w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step-146x90.png 146w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step-362x224.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/leap-hand-manipulation-physics-step-178x110.png 178w\" sizes=\"auto, (max-width: 600px) 100vw, 600px\" \/><\/figure>\n\n\n\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e826e&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e826e\" class=\"wp-block-image size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"600\" height=\"371\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" data-id=\"113575\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step.webp\" alt=\"MJWarp physics step throughput versus MuJoCo MJX on Apptronik benchmarks.\" class=\"wp-image-113575\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step.webp 600w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step-179x111.png 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step-300x186.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step-485x300.png 485w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step-146x90.png 146w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step-362x224.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/apptronik-locomotion-physics-step-178x110.png 178w\" sizes=\"auto, (max-width: 600px) 100vw, 
600px\" \/><button\n\t\t\tclass=\"lightbox-trigger\"\n\t\t\ttype=\"button\"\n\t\t\taria-haspopup=\"dialog\"\n\t\t\taria-label=\"Enlarge\"\n\t\t\tdata-wp-init=\"callbacks.initTriggerButton\"\n\t\t\tdata-wp-on--click=\"actions.showLightbox\"\n\t\t\tdata-wp-style--right=\"state.imageButtonRight\"\n\t\t\tdata-wp-style--top=\"state.imageButtonTop\"\n\t\t>\n\t\t\t<svg xmlns=\"https:\/\/2.zoppoz.workers.dev:443\/http\/www.w3.org\/2000\/svg\" width=\"12\" height=\"12\" fill=\"none\" viewBox=\"0 0 12 12\">\n\t\t\t\t<path fill=\"#fff\" d=\"M2 0a2 2 0 0 0-2 2v2h1.5V2a.5.5 0 0 1 .5-.5h2V0H2Zm2 10.5H2a.5.5 0 0 1-.5-.5V8H0v2a2 2 0 0 0 2 2h2v-1.5ZM8 12v-1.5h2a.5.5 0 0 0 .5-.5V8H12v2a2 2 0 0 1-2 2H8Zm2-12a2 2 0 0 1 2 2v2h-1.5V2a.5.5 0 0 0-.5-.5H8V0h2Z\" \/>\n\t\t\t<\/svg>\n\t\t<\/button><\/figure>\n<figcaption class=\"blocks-gallery-caption wp-element-caption\"><em><em>Figure 8. MJWarp physics step throughput versus MuJoCo MJX on LEAP hand manipulation and Apptronik locomotion benchmarks<\/em><\/em><\/figcaption><\/figure>\n\n\n\n<p>To learn more, see the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/google-deepmind\/mujoco\/discussions\/3094\">MuJoCo Warp release announcement<\/a>.<\/p>\n\n\n\n<h3 id=\"c-infinity_autoassembler&nbsp;\"  class=\"wp-block-heading\">C-Infinity AutoAssembler&nbsp;<a href=\"#c-infinity_autoassembler&nbsp;\" aria-label=\"Scroll to C-Infinity AutoAssembler&nbsp; section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p>The C-Infinity AutoAssembler ASI Engine shows the value of Warp in AI-driven industrial workflows beyond physics simulation. It converts full-fidelity CAD assemblies into motion constraints for AI planning by computing contact, interference, and clearance directly from raw geometry. 
Current CAD systems do not support these critical queries, which are required to construct manufacturing process plans, evaluate design changes, and generate execution instructions.&nbsp;<\/p>\n\n\n\n<p>The AutoAssembler ASI engine makes it possible to build a manufacturing compiler that transforms engineering CAD data directly into assembly instructions for human workers or robots. The engine's spatial intelligence is implemented with Warp kernels optimized for large-scale geometric processing.&nbsp;<\/p>\n\n\n\n<p>On an <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/en-us\/data-center\/l4\/\">NVIDIA L4 Tensor Core<\/a> GPU, the Warp GPU backend achieved speedups of up to 669x over optimized CPU baselines built on state-of-the-art libraries (FCL and Embree). The technology is already in use in enterprise manufacturing workflows at top OEMs.<\/p>\n\n\n<div class=\"wp-block-image\">\n<figure data-wp-context=\"{&quot;imageId&quot;:&quot;6a2f7c13e8df1&quot;}\" data-wp-interactive=\"core\/image\" data-wp-key=\"6a2f7c13e8df1\" class=\"aligncenter size-full wp-lightbox-container\"><img loading=\"lazy\" decoding=\"async\" width=\"1999\" height=\"1174\" data-wp-class--hide=\"state.isContentHidden\" data-wp-class--show=\"state.isContentVisible\" data-wp-init=\"callbacks.setButtonStyles\" data-wp-on--click=\"actions.showLightbox\" data-wp-on--load=\"callbacks.setButtonStyles\" data-wp-on-window--resize=\"callbacks.setButtonStyles\" src=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance.webp\" alt=\"Bar chart on log scale comparing CPU and GPU execution time for liaison graph construction on five meshes (59K\u201315M triangles). 
GPU speedups range from 57x to 669x.\n\" class=\"wp-image-113123\" srcset=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance.webp 1999w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-179x105.png 179w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-300x176.png 300w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-768x451.png 768w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-625x367.png 625w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-1536x902.png 1536w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-645x379.png 645w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-500x294.png 500w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-153x90.png 153w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-362x213.png 362w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-187x110.png 187w, https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-1024x601.png 1024w, 
https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/02\/liason-graph-construction-performance-919x540.png 919w\" sizes=\"auto, (max-width: 1999px) 100vw, 1999px\" \/><button\n\t\t\tclass=\"lightbox-trigger\"\n\t\t\ttype=\"button\"\n\t\t\taria-haspopup=\"dialog\"\n\t\t\taria-label=\"Enlarge\"\n\t\t\tdata-wp-init=\"callbacks.initTriggerButton\"\n\t\t\tdata-wp-on--click=\"actions.showLightbox\"\n\t\t\tdata-wp-style--right=\"state.imageButtonRight\"\n\t\t\tdata-wp-style--top=\"state.imageButtonTop\"\n\t\t>\n\t\t\t<svg xmlns=\"https:\/\/2.zoppoz.workers.dev:443\/http\/www.w3.org\/2000\/svg\" width=\"12\" height=\"12\" fill=\"none\" viewBox=\"0 0 12 12\">\n\t\t\t\t<path fill=\"#fff\" d=\"M2 0a2 2 0 0 0-2 2v2h1.5V2a.5.5 0 0 1 .5-.5h2V0H2Zm2 10.5H2a.5.5 0 0 1-.5-.5V8H0v2a2 2 0 0 0 2 2h2v-1.5ZM8 12v-1.5h2a.5.5 0 0 0 .5-.5V8H12v2a2 2 0 0 1-2 2H8Zm2-12a2 2 0 0 1 2 2v2h-1.5V2a.5.5 0 0 0-.5-.5H8V0h2Z\" \/>\n\t\t\t<\/svg>\n\t\t<\/button><figcaption class=\"wp-element-caption\"><em>Figure 9. 
Liaison graph construction: CPU (FCL\/Embree) versus AutoAssembler ASI Engine (GPU) across five CAD assemblies of increasing complexity<\/em><\/figcaption><\/figure>\n<\/div>\n\n\n<p>To learn more, see <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/c-infinity.ai\/blog\/autoassembler-asi-accelerated-spatial-intelligence\">AutoAssembler ASI: Accelerated Spatial Intelligence, C-Infinity<\/a>.<\/p>\n\n\n\n<h2 id=\"get_started_with_warp_for_computational_physics_applications\"  class=\"wp-block-heading\"><strong>Get started with Warp for computational physics applications<\/strong><a href=\"#get_started_with_warp_for_computational_physics_applications\" aria-label=\"Scroll to Get started with Warp for computational physics applications section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h2>\n\n\n\n<p>Warp lets you write physics and geometry code as GPU kernels in Python, without forcing every computation into whole-array tensor operations. In CFD, time stepping and differentiable solves map cleanly onto kernels, preserving the structure of the physics.<\/p>\n\n\n\n<p>This programming model is already used in industrial AI workflows, including the Autodesk differentiable CFD solver (XLB), the Google DeepMind multibody dynamics work (MuJoCo Warp), and the C-Infinity spatial reasoning engine (AutoAssembler). 
Zero-copy interoperability with PyTorch and JAX lets Warp plug into ML pipelines while preserving the per-element control flow these workloads need, and the case studies above report measured gains in performance, memory usage, and scalability.<\/p>\n\n\n\n<p>To get started with Warp for computational physics applications, check out these resources:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/accelerated-computing-hub\/blob\/32fe3d5a448446fd52c14a6726e1b867cbfed2d9\/Accelerated_Python_User_Guide\/notebooks\/Chapter_12_Intro_to_NVIDIA_Warp.ipynb\">Introduction to NVIDIA Warp notebook<\/a><\/li>\n\n\n\n<li><a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/blob\/main\/warp\/examples\/core\/example_fft_poisson_navier_stokes_2d.py\">2D Navier\u2013Stokes solver example<\/a><\/li>\n\n\n\n<li><a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/github.com\/NVIDIA\/warp\/blob\/main\/warp\/examples\/optim\/example_navier_stokes_perturbation.py\">2D Navier\u2013Stokes optimal perturbation example<\/a><\/li>\n\n\n\n<li><a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/nvidia.github.io\/warp\/\">NVIDIA Warp documentation<\/a><\/li>\n<\/ul>\n\n\n\n<p>To learn more, join the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/gtc\/\">NVIDIA GTC 2026<\/a> session, <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/gtc\/session-catalog\/sessions\/gtc26-dlit81837\/\">How to Use NVIDIA Warp to Build GPU-Accelerated Computational Physics Simulations [DLIT81837]<\/a>. 
Watch the <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/gtc\/keynote\/\">GTC keynote<\/a> with NVIDIA founder and CEO Jensen Huang, and explore more <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/gtc\/sessions\/physical-ai-days\/\">physical AI<\/a>, <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/gtc\/sessions\/robotics\/\">robotics<\/a>, and <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/www.nvidia.com\/gtc\/sessions\/computer-vision-and-video-analytics\/\">vision AI<\/a> GTC sessions.<\/p>\n\n\n\n<h3 id=\"acknowledgments&nbsp;\"  class=\"wp-block-heading\">Acknowledgments&nbsp;<a href=\"#acknowledgments&nbsp;\" aria-label=\"Scroll to Acknowledgments&nbsp; section\" class=\"heading-anchor-link\"><i class=\"fas fa-link\"><\/i><\/a><\/h3>\n\n\n\n<p><em>Thanks to Felix Meyer for contributing to this post and project.<\/em><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Computer-aided engineering (CAE) is shifting from human-driven workflows toward AI-driven ones, including physics foundation models that generalize across geometries and operating conditions. Unlike LLMs, these models depend on large volumes of high-fidelity, physics-compliant data.\u00a0 Recent scaling-law work on computational fluid dynamics (CFD) surrogates indicates that simulation-generated training data is often the limiting cost in practice. 
&hellip; <a href=\"https:\/\/2.zoppoz.workers.dev:443\/https\/developer.nvidia.com\/blog\/build-accelerated-differentiable-computational-physics-code-for-ai-with-nvidia-warp\/\">Continued<\/a><\/p>\n","protected":false},"author":2941,"featured_media":113150,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"_acf_changed":false,"publish_to_discourse":"","publish_post_category":"318","wpdc_auto_publish_overridden":"1","wpdc_topic_tags":"","wpdc_pin_topic":"","wpdc_pin_until":"","discourse_post_id":"1773107","discourse_permalink":"https:\/\/2.zoppoz.workers.dev:443\/https\/forums.developer.nvidia.com\/t\/build-accelerated-differentiable-computational-physics-code-for-ai-with-nvidia-warp\/363313","wpdc_publishing_response":"success","wpdc_publishing_error":"","nv_subtitle":"","ai_post_summary":"<ul><li>NVIDIA Warp provides a Python-based, GPU-native framework for authoring high-performance, differentiable simulation kernels, enabling direct integration with ML workflows and supporting advanced control flow on computational grids without the limitations of tensor-centric frameworks.<\/li><li>The example 2D Navier\u2013Stokes solver in Warp leverages FFT-based Poisson solvers, second-order finite difference discretization, and SSP-RK3 time integration, with each computational grid point updated in parallel using the SIMT paradigm for maximal GPU utilization and efficiency.<\/li><li>Warp's automatic differentiation system generates both forward and adjoint versions of simulation code, supporting reverse-mode AD for scalable, production-resolution optimization; memory management for differentiable solvers is addressed by explicit intermediate storage and gradient checkpointing.<\/li><li>Industrial case studies, including Autodesk XLB, Google DeepMind MuJoCo Warp, and C-Infinity AutoAssembler, demonstrate that Warp delivers significant speedups (up to 669x over CPU and &gt;250x over JAX in some cases), reduced memory usage, and seamless 
interoperability with PyTorch and JAX, accelerating AI-driven engineering, robotics, and spatial computing workflows on NVIDIA GPUs.<\/li><\/ul>","footnotes":"","_links_to":"","_links_to_target":""},"categories":[4146,63,503],"tags":[1916,453,55,61,1877,5032],"coauthors":[4755,4273,5030,5031,4794],"class_list":["post-113051","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-development","category-robotics","category-simulation-modeling-design","tag-computational-fluid-dynamics","tag-featured","tag-physics","tag-python","tag-research","tag-warp","tagify_workload-generative-ai","tagify_workload-data-science","tagify_workload-simulation-modeling-design"],"acf":{"post_industry":["HPC \/ Scientific Computing"],"post_products":["General"],"post_learning_levels":["Advanced Technical"],"post_content_types":["Tutorial"],"post_collections":""},"jetpack_featured_media_url":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-content\/uploads\/2026\/03\/decaying-turbulence.webp","primary_category":{"category":"Simulation \/ Modeling \/ Design","link":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer.nvidia.com\/blog\/category\/simulation-modeling-design\/","id":503,"data_source":""},"nv_translations":[{"language":"zh_CN","title":"\u4f7f\u7528 NVIDIA Warp \u4e3a AI 
\u6784\u5efa\u52a0\u901f\u7684\u53ef\u5fae\u5206\u8ba1\u7b97\u7269\u7406\u4ee3\u7801","post_id":17091}],"jetpack_shortlink":"https:\/\/2.zoppoz.workers.dev:443\/https\/wp.me\/pcCQAL-tpp","jetpack_likes_enabled":true,"jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/113051","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/users\/2941"}],"replies":[{"embeddable":true,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/comments?post=113051"}],"version-history":[{"count":88,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/113051\/revisions"}],"predecessor-version":[{"id":114063,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/posts\/113051\/revisions\/114063"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/media\/113150"}],"wp:attachment":[{"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/media?parent=113051"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/categories?post=113051"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\/wp-json\/wp\/v2\/tags?post=113051"},{"taxonomy":"author","embeddable":true,"href":"https:\/\/2.zoppoz.workers.dev:443\/https\/developer-blogs.nvidia.com\
/wp-json\/wp\/v2\/coauthors?post=113051"}],"curies":[{"name":"wp","href":"https:\/\/2.zoppoz.workers.dev:443\/https\/api.w.org\/{rel}","templated":true}]}}