AI Model Training with JAX

Last Updated : 23 Aug, 2025

JAX is a machine learning and numerical computing library developed by Google that combines a NumPy-like API with automatic differentiation, just-in-time (JIT) compilation and vectorization for efficient model training. It compiles code with XLA to run on CPUs, GPUs and TPUs, and arrays are placed on the default device automatically, so explicit device-placement calls such as .cuda() in PyTorch are usually unnecessary.

Building on JAX, Flax is a neural network library that provides higher-level abstractions such as nn.Module for building deep neural architectures in a modular way. Flax supports features including checkpointing, regularization and multi-device training, which makes it suitable for both research and production workflows that rely on JAX's performance and accelerator support.

JAX provides:

  • NumPy API: A familiar interface (jax.numpy) for those who use NumPy, with the same code able to run on accelerators.
  • Automatic Differentiation: Computes gradients of Python functions (jax.grad), which is essential for deep learning.
  • JIT Compilation (jax.jit): Compiles functions with XLA into optimized code for the target device.
  • Auto-vectorization and Parallelization (jax.vmap, jax.pmap): jax.vmap vectorizes a function over a batch axis, while jax.pmap runs it in parallel across multiple devices.
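
As a minimal sketch of how these transformations compose, the snippet below applies jax.grad, jax.jit and jax.vmap to a small scalar function. The function f and the variable names are illustrative assumptions, not part of the tutorial's model.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function: f(x) = x^2 + 3x
    return x ** 2 + 3.0 * x

# grad: builds a function that returns df/dx = 2x + 3
df = jax.grad(f)
print(df(2.0))             # 7.0

# jit: compiles f with XLA on the first call, then reuses the compiled code
fast_f = jax.jit(f)
print(fast_f(2.0))         # 10.0

# vmap: applies the scalar derivative to every element of a batch
xs = jnp.array([0.0, 1.0, 2.0])
print(jax.vmap(df)(xs))    # [3. 5. 7.]
```

The transformations are ordinary higher-order functions, so they can be nested freely, for example jax.jit(jax.vmap(jax.grad(f))).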

Implementation

Let's see an example of building a model using JAX:

Step 1: Importing Required Libraries

JAX provides a NumPy-like API (jnp) for high-performance arrays and mathematical operations and supports automatic differentiation.

Python
import jax
import jax.numpy as jnp

Step 2: Defining Model Initialization

Set up the weights and bias for your linear regression model:

Python
def init_params(rng_key, input_dim, output_dim):
    # Use separate keys so W and b are drawn independently
    w_key, b_key = jax.random.split(rng_key)
    W = jax.random.normal(w_key, (input_dim, output_dim))  # weight matrix
    b = jax.random.normal(b_key, (output_dim,))            # bias vector
    return {'W': W, 'b': b}

Step 3: Defining the Model (Linear Layer)

Python
def model(params, x):
    return jnp.dot(x, params['W']) + params['b']
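
To make the shapes concrete, here is a small sketch that runs the same model on a hand-written batch. The parameter values are chosen by hand for illustration (they are not trained), and the jax.vmap call is an optional alternative that maps the model over single examples.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return jnp.dot(x, params['W']) + params['b']

# Hand-picked (untrained) parameters, for illustration only
params = {'W': jnp.array([[1.5], [-2.0]]), 'b': jnp.array([0.5])}

# A batch of 3 examples with 2 features each: shape (3, 2)
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# (3, 2) @ (2, 1) + (1,) gives shape (3, 1)
batch_out = model(params, x)
print(batch_out.shape)   # (3, 1)

# Equivalent: treat model as a per-example function and map over axis 0 of x,
# while params (in_axes=None) is shared by all examples
per_example = jax.vmap(model, in_axes=(None, 0))(params, x)
print(per_example.shape)  # (3, 1)
```

Both calls produce the same values; the batched matrix product is simply the more direct way to write a linear layer.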

Step 4: Defining the Loss Function

We use Mean Squared Error (MSE) as the loss function.

Python
def loss_fn(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)
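
A short worked example of the MSE computation, using made-up numbers, is sketched below. It also shows a common shape pitfall: if the targets have shape (n,) while the predictions have shape (n, 1), broadcasting silently produces an (n, n) matrix of differences instead of n residuals.

```python
import jax.numpy as jnp

# Made-up predictions and targets, both with shape (3, 1)
preds = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([[1.0], [2.5], [2.0]])

# Residuals: 0, -0.5, 1  ->  squares: 0, 0.25, 1  ->  mean: 1.25 / 3
mse = jnp.mean((preds - y) ** 2)
print(f"{mse:.4f}")            # 0.4167

# Pitfall: targets flattened to shape (3,) broadcast against (3, 1)
y_flat = y.reshape(-1)
wrong_diff = preds - y_flat
print(wrong_diff.shape)        # (3, 3), not (3, 1)
```

Keeping predictions and targets in the same shape, as the tutorial's data does with (n, 1) targets, avoids this problem.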

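Step 5: Defining One Gradient Update Step

The @jax.jit decorator compiles the update step with XLA on its first call and reuses the compiled version afterwards, so repeated calls run efficiently on CPU, GPU or TPU. jax.grad differentiates loss_fn with respect to its first argument (params).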

Python
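@jax.jit
def update(params, x, y, lr=0.01):
    # Gradients of the loss w.r.t. params (same dict structure as params)
    grads = jax.grad(loss_fn)(params, x, y)
    # Plain gradient descent: return new params instead of mutating the old ones
    return {k: v - lr * grads[k] for k, v in params.items()}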
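
One possible variation, sketched here under the assumption that you also want the loss value from each step, uses jax.value_and_grad to get the loss and gradients in a single pass and jax.tree_util.tree_map to update every parameter. The name update_with_loss and the tiny inputs are hypothetical and only serve the illustration.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    preds = jnp.dot(x, params['W']) + params['b']
    return jnp.mean((preds - y) ** 2)

@jax.jit
def update_with_loss(params, x, y, lr=0.01):
    # One forward/backward pass returns both the loss and its gradients
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # tree_map applies the update to every leaf of the params dict
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Tiny hand-made problem: zero parameters, targets all equal to 1
params = {'W': jnp.zeros((2, 1)), 'b': jnp.zeros((1,))}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([[1.0], [1.0]])

new_params, loss = update_with_loss(params, x, y)
print(loss)              # 1.0 (loss before the update)
print(new_params['W'])   # each weight moves from 0 to 0.01
print(new_params['b'])   # bias moves from 0 to 0.02
```

Returning the loss from the compiled step avoids a second forward pass when logging, and tree_map keeps working if the parameter structure becomes nested.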

Step 6: Generating Training and Testing Data

We generate 256 training samples and 64 test samples (an 80/20 split) from a known linear function with added Gaussian noise.

Python
key = jax.random.PRNGKey(0)
# Split once so every random draw (and the later parameter init) gets its own key
key, x_key, noise_key, xt_key, nt_key = jax.random.split(key, 5)
n_train, n_test = 256, 64
true_w = jnp.array([[1.5], [-2.0]])
true_b = jnp.array([0.5])

x_train = jax.random.normal(x_key, (n_train, 2))
y_train = x_train @ true_w + true_b + 0.1 * jax.random.normal(noise_key, (n_train, 1))

x_test = jax.random.normal(xt_key, (n_test, 2))
y_test = x_test @ true_w + true_b + 0.1 * jax.random.normal(nt_key, (n_test, 1))
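
Random data generation in JAX depends on how keys are handled, so a brief sketch of the behaviour may help. Passing the same key to two calls returns identical numbers, while keys obtained from jax.random.split give independent draws. The variable names below are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Same key, same shape: the two draws are identical
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(bool(jnp.array_equal(a, b)))   # True

# Split the key to obtain two independent subkeys
k1, k2 = jax.random.split(key)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))
print(bool(jnp.array_equal(c, d)))   # False
```

For this reason each random draw in a program, such as inputs, noise and parameter initialization, should normally receive its own subkey.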

Step 7: Initializing Model Parameters

Python
params = init_params(key, input_dim=2, output_dim=1)

Step 8: Training Loop

Each epoch performs one full-batch gradient update on the entire training set. We run 100 epochs and print the training loss every 20 epochs.

Python
epochs = 100
for epoch in range(epochs):
    params = update(params, x_train, y_train)
    if (epoch+1) % 20 == 0:
        train_loss = loss_fn(params, x_train, y_train)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")

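Step 9: Evaluating the Model

We evaluate the model on the held-out test set. A test loss close to the training loss indicates that the model generalizes; since the data contains Gaussian noise with standard deviation 0.1, the MSE cannot be expected to fall much below about 0.01.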

Python
test_loss = loss_fn(params, x_test, y_test)
print(f"Test Loss: {test_loss:.4f}")

Step 10: Making a Sample Prediction

Python
sample_x = jnp.array([[0.0, 0.0]])
pred_y = model(params, sample_x)
print("Prediction for input [0.0, 0.0]:", pred_y)

Output:

[Screenshot of the program output]

Google Colab Link : AI Model Training with JAX

Best Practices and Common Pitfalls

  • Pure Functions: All JAX-transformed functions (those passed to jit, vmap or grad) must be pure: no side effects and the same output for the same inputs.
  • Statelessness: Keep parameters explicit and always pass them to your functions as arguments.
  • Randomness: JAX uses functional random number generation; you manage RNG keys explicitly, which makes results reproducible.
  • No In-place Mutation: JAX arrays are immutable, so updates must create new arrays (for example with x.at[i].set(v)) rather than changing values in place as in NumPy.
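
Two of these pitfalls can be illustrated with a short sketch. The function impure_double and the list trace_log are hypothetical names used only for this demonstration: the side effect inside the jitted function runs during tracing only, and item assignment on a JAX array raises an error, with .at[...].set(...) as the functional alternative.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def impure_double(x):
    # Side effect: executed while tracing, not on cached compiled calls
    trace_log.append("traced")
    return 2 * x

impure_double(jnp.ones(3))
impure_double(jnp.ones(3))   # same shape and dtype: compiled version is reused
print(len(trace_log))        # 1

x = jnp.zeros(3)
try:
    x[0] = 1.0               # JAX arrays do not support item assignment
    mutation_failed = False
except TypeError:
    mutation_failed = True
print(mutation_failed)       # True

# Functional update: returns a new array and leaves x unchanged
y = x.at[0].set(1.0)
print(y)                     # [1. 0. 0.]
```

The first part shows why logging or counters inside jitted functions behave unexpectedly; the second shows the idiom that replaces NumPy-style in-place writes.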

Practical Use Case: Training on Real Datasets

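  • The dataset is split into mini-batches, typically by a data loader.
  • Training and evaluation steps are JIT-compiled for speed; keeping batch shapes fixed avoids recompilation.
  • Model parameters stay on the accelerator for the whole training run, and only the data batches are transferred from the host.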
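
The points above can be sketched as a mini-batch training loop. This is a simplified illustration under stated assumptions: synthetic data stands in for a real dataset, and the generator batches is a hypothetical stand-in for a real data loader. The names train_step and batches, the batch size and the learning rate are all choices made for this example.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    return jnp.mean((x @ params['W'] + params['b'] - y) ** 2)

@jax.jit
def train_step(params, x, y, lr):
    # JIT-compiled step: gradients plus a plain gradient descent update
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

def batches(key, x, y, batch_size):
    # Hypothetical stand-in for a data loader: shuffle, then yield slices
    perm = jax.random.permutation(key, x.shape[0])
    for start in range(0, x.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        yield x[idx], y[idx]

# Synthetic data in place of a real dataset (assumption for this sketch)
key = jax.random.PRNGKey(0)
key, x_key, noise_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (256, 2))
y = x @ jnp.array([[1.5], [-2.0]]) + 0.5 + 0.1 * jax.random.normal(noise_key, (256, 1))

params = {'W': jnp.zeros((2, 1)), 'b': jnp.zeros((1,))}
for epoch in range(20):
    key, shuffle_key = jax.random.split(key)   # fresh shuffle each epoch
    for xb, yb in batches(shuffle_key, x, y, batch_size=32):
        params = train_step(params, xb, yb, 0.05)

final_loss = loss_fn(params, x, y)
print(f"Final loss: {final_loss:.4f}")
```

Because 256 is divisible by 32, every batch has the same shape and train_step is compiled only once. With a real dataset, a trailing smaller batch would trigger one extra compilation unless it is dropped or padded.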