JAX is a high-performance numerical computing and machine learning library developed by Google. It combines the familiar NumPy API with composable function transformations (automatic differentiation, just-in-time (JIT) compilation and vectorization) that make model training highly efficient. Through XLA compilation, the same code runs on CPUs, GPUs and TPUs, and arrays are placed on the default accelerator automatically, so there is no need for manual device placement calls like .cuda() in PyTorch.
Built on top of JAX, Flax is a neural network library that provides higher-level abstractions such as nn.Module for defining deep neural architectures in a modular, scalable way, which enables rapid experimentation. Flax supports features including checkpointing, regularization and multi-device training, making it well suited to research and production workflows that take full advantage of JAX's performance on accelerators.
JAX itself provides:
- NumPy API (jax.numpy): A familiar interface for NumPy users, with arrays that run on accelerators for much higher performance.
- Automatic Differentiation (jax.grad): Computes gradients of arbitrary differentiable Python functions, which is essential for deep learning.
- JIT Compilation (jax.jit): Compiles functions with XLA into optimized machine code for the target device.
- Auto-vectorization and Parallelization (jax.vmap, jax.pmap): jax.vmap vectorizes a function over a batch dimension, and jax.pmap runs a computation in parallel across multiple devices.
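A minimal sketch of jax.grad, jax.jit and jax.vmap applied to a toy function. The function f and the input values are illustrative assumptions, not part of the model built below; jax.pmap is not shown because it needs more than one device to be meaningful.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A simple scalar function: f(x) = sum(x ** 2)
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes df/dx = 2 * x
grad_f = jax.grad(f)

# jax.jit compiles grad_f with XLA; the first call traces and compiles it
fast_grad_f = jax.jit(grad_f)

x = jnp.array([1.0, 2.0, 3.0])
print(fast_grad_f(x))  # gradient 2 * x, i.e. 2, 4, 6

# jax.vmap maps f over the leading axis, so each row is reduced independently
batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])
per_row = jax.vmap(f)(batch)
print(per_row)  # 1 + 4 = 5 and 9 + 16 = 25
```

Because the transformations compose, jax.jit(jax.vmap(jax.grad(f))) is also valid and computes per-example gradients in one compiled call.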
Implementation
Let's walk through an example of building and training a simple linear regression model with JAX:
Step 1: Importing Required Libraries
JAX provides a NumPy-like API (jax.numpy, imported as jnp) for high-performance array operations, while the top-level jax module provides transformations such as automatic differentiation (jax.grad) and JIT compilation (jax.jit).
import jax
import jax.numpy as jnp
Step 2: Defining Model Initialization
Set up the weights and bias for your linear regression model:
def init_params(rng_key, input_dim, output_dim):
    # Split the key so the weights and the bias use independent random streams
    w_key, b_key = jax.random.split(rng_key)
    W = jax.random.normal(w_key, (input_dim, output_dim))
    b = jax.random.normal(b_key, (output_dim,))
    return {'W': W, 'b': b}
Step 3: Defining the Model (Linear Layer)
def model(params, x):
    # x: (batch, input_dim) @ W: (input_dim, output_dim), then b: (output_dim,) is broadcast
    return jnp.dot(x, params['W']) + params['b']
Step 4: Defining the Loss Function
We use mean squared error (MSE) as the loss function.
def loss_fn(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)
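For reference, with output_dim = 1 as in this example, loss_fn computes the quantity below, where X is the (N, 2) input matrix with rows x_i, y is the vector of targets and r is the residual vector. Differentiating gives the closed-form gradients that jax.grad computes automatically; this is a standard derivation for the single-output case, sketched here for illustration.

```latex
\mathcal{L}(W, b) = \frac{1}{N} \sum_{i=1}^{N} \left( x_i^\top W + b - y_i \right)^2,
\qquad r = XW + b\,\mathbf{1} - y

\frac{\partial \mathcal{L}}{\partial W} = \frac{2}{N} X^\top r,
\qquad
\frac{\partial \mathcal{L}}{\partial b} = \frac{2}{N} \sum_{i=1}^{N} r_i
```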
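Step 5: Defining One Gradient Update Step
The @jax.jit decorator compiles the update step with XLA so that it runs efficiently on CPU, GPU or TPU. Inside it, jax.grad(loss_fn) computes the gradient of the loss with respect to params (its first argument), and each parameter then takes one gradient descent step.
@jax.jit
def update(params, x, y, lr=0.01):
    # Gradients have the same dict structure as params: {'W': ..., 'b': ...}
    grads = jax.grad(loss_fn)(params, x, y)
    # Plain gradient descent: new value = old value - learning_rate * gradient
    return {k: v - lr * grads[k] for k, v in params.items()}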
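A quick way to see what jax.grad returns is to compare it with the closed-form MSE gradient, dL/dW = (2/N) X^T r and dL/db = (2/N) sum(r), where r = XW + b - y is the residual. The sketch below redefines model and loss_fn so it runs on its own; the array shapes and the seed are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return jnp.dot(x, params['W']) + params['b']

def loss_fn(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)

# Small random problem: 8 samples, 2 features, 1 output
k_x, k_y, k_w = jax.random.split(jax.random.PRNGKey(42), 3)
x = jax.random.normal(k_x, (8, 2))
y = jax.random.normal(k_y, (8, 1))
params = {'W': jax.random.normal(k_w, (2, 1)), 'b': jnp.zeros((1,))}

# Gradient from automatic differentiation: a dict with the same keys as params
auto = jax.grad(loss_fn)(params, x, y)

# Closed-form MSE gradient computed by hand from the residual
n = x.shape[0]
r = model(params, x) - y
manual = {'W': 2.0 / n * x.T @ r, 'b': 2.0 / n * jnp.sum(r, axis=0)}

print("W gradients match:", bool(jnp.allclose(auto['W'], manual['W'], atol=1e-4)))
print("b gradients match:", bool(jnp.allclose(auto['b'], manual['b'], atol=1e-4)))
```

For larger or nested parameter structures, jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads) applies the same update to every leaf without writing a dict comprehension by hand.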
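Step 6: Generating Training and Testing Data
We generate synthetic data from the linear relationship y = x @ true_w + true_b plus small Gaussian noise. Of the 320 samples, 256 (80%) are used for training and 64 (20%) for testing. Each random draw uses its own subkey, because reusing a JAX key reproduces exactly the same numbers.
key = jax.random.PRNGKey(0)
# Give every random draw its own subkey; reusing one key repeats the same numbers
key, k_xtr, k_ntr, k_xte, k_nte = jax.random.split(key, 5)
n_train, n_test = 256, 64  # 256 of 320 samples (80%) train, 64 (20%) test
x_train = jax.random.normal(k_xtr, (n_train, 2))
true_w = jnp.array([[1.5], [-2.0]])
true_b = jnp.array([0.5])
# Targets: y = x @ true_w + true_b plus Gaussian noise with standard deviation 0.1
y_train = x_train @ true_w + true_b + 0.1 * jax.random.normal(k_ntr, (n_train, 1))
x_test = jax.random.normal(k_xte, (n_test, 2))
y_test = x_test @ true_w + true_b + 0.1 * jax.random.normal(k_nte, (n_test, 1))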
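Step 7: Initializing Model Parameters
# Derive a fresh subkey for initialization instead of reusing the data key
key, init_key = jax.random.split(key)
params = init_params(init_key, input_dim=2, output_dim=1)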
Step 8: Training Loop
We run 100 epochs. Because update uses the whole training set at once, each epoch is a single full-batch gradient descent step. The training loss is printed every 20 epochs.
epochs = 100
for epoch in range(epochs):
    # One full-batch gradient descent step per epoch
    params = update(params, x_train, y_train)
    if (epoch + 1) % 20 == 0:
        train_loss = loss_fn(params, x_train, y_train)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")
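The loop above runs an extra forward pass whenever it logs the loss. A variant, sketched here under the assumption that you want the loss at every step, uses jax.value_and_grad to obtain the loss and the gradients from a single call. The name update_with_loss, the noise-free toy data y = 3x - 1, the learning rate of 0.1 and the 200 epochs are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return jnp.dot(x, params['W']) + params['b']

def loss_fn(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

@jax.jit
def update_with_loss(params, x, y, lr=0.01):
    # value_and_grad returns (loss, grads) from one forward and backward pass
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Apply gradient descent to every leaf of the params dict
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Toy data without noise: y = 3x - 1
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 3.0 * x - 1.0
params = {'W': jnp.zeros((1, 1)), 'b': jnp.zeros((1,))}

losses = []
for epoch in range(200):
    params, loss = update_with_loss(params, x, y, lr=0.1)
    losses.append(float(loss))  # loss is computed before this step's update

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

Returning the loss from the jitted step keeps logging inside the same compiled call, so no separate evaluation pass is needed during training.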
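Step 9: Evaluating the Model
We compute the MSE on the held-out test set. Because the targets include Gaussian noise with standard deviation 0.1, the lowest achievable MSE is about 0.1² = 0.01, so a test loss close to that value means the model has learned the underlying relationship well.
test_loss = loss_fn(params, x_test, y_test)
print(f"Test Loss: {test_loss:.4f}")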
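Step 10: Making a Sample Prediction
For the input [0.0, 0.0], the linear model's output is exactly the learned bias b, which moves toward the true bias of 0.5 as training progresses.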
sample_x = jnp.array([[0.0, 0.0]])
pred_y = model(params, sample_x)
print("Prediction for input [0.0, 0.0]:", pred_y)
Best Practices and Common Pitfalls
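- Pure Functions: Functions transformed by jax.jit, jax.vmap or jax.grad must be pure: no side effects, and the same outputs for the same inputs. Python side effects such as print run only while JAX traces the function, not on every call.
- Statelessness: Keep parameters explicit and pass them to your functions as arguments instead of storing them in global or object state.
- Randomness: JAX uses functional random number generation. You manage RNG keys explicitly, splitting them with jax.random.split rather than reusing them, which also makes results reproducible.
- No In-place Mutation: JAX arrays are immutable. Instead of NumPy-style in-place assignment such as x[0] = 1, use x = x.at[0].set(1), which returns a new array.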
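The pitfalls above can be observed in a few lines. The function double and the trace_log list are hypothetical names used only to make the trace-time side effect visible.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same "random" numbers
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print("same key, same draws:", bool(jnp.array_equal(a, b)))

# Splitting yields independent subkeys for independent draws
key, sub1, sub2 = jax.random.split(key, 3)
c = jax.random.normal(sub1, (3,))
d = jax.random.normal(sub2, (3,))
print("split keys, same draws:", bool(jnp.array_equal(c, d)))

# JAX arrays are immutable: item assignment raises a TypeError
x = jnp.zeros(3)
try:
    x[0] = 10.0
except TypeError:
    print("in-place assignment is not allowed")

# .at[...].set(...) returns an updated copy and leaves x unchanged
y = x.at[0].set(10.0)
print(x, y)

# Python side effects inside a jitted function run only during tracing
trace_log = []

@jax.jit
def double(v):
    trace_log.append("traced")  # executes at trace time, not on every call
    return v * 2

double(jnp.float32(1.0))
double(jnp.float32(2.0))  # same shape and dtype, so the compiled version is reused
print("times traced:", len(trace_log))
```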
Practical Use Case: Training on Real Datasets
- The dataset is split into mini-batches by a data loader.
- The training and evaluation steps are JIT-compiled for speed.
- Model parameters stay on the accelerator for the whole of training, avoiding repeated transfers between host and device.
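This workflow can be sketched with the same linear model. iterate_batches is a hypothetical, minimal stand-in for a real data loader, and the dataset size of 1024, the 800/224 split, the batch size of 32, the learning rate and the epoch count are illustrative assumptions. Dropping the last incomplete batch keeps every batch the same shape, so the jitted train_step is compiled only once.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return x @ params['W'] + params['b']

def loss_fn(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y, lr):
    # One SGD step on a mini-batch; also returns the batch loss
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

@jax.jit
def eval_step(params, x, y):
    return loss_fn(params, x, y)

def iterate_batches(key, x, y, batch_size):
    # Hypothetical data loader: shuffle, then yield full batches only
    perm = jax.random.permutation(key, x.shape[0])
    for start in range(0, x.shape[0] - batch_size + 1, batch_size):
        idx = perm[start:start + batch_size]
        yield x[idx], y[idx]

key = jax.random.PRNGKey(0)
key, k_x, k_n, k_w = jax.random.split(key, 4)
true_w = jnp.array([[1.5], [-2.0]])
x_all = jax.random.normal(k_x, (1024, 2))
y_all = x_all @ true_w + 0.5 + 0.1 * jax.random.normal(k_n, (1024, 1))
x_train, y_train = x_all[:800], y_all[:800]
x_test, y_test = x_all[800:], y_all[800:]

# Arrays are created on the default device (the accelerator if one is present)
params = {'W': jax.random.normal(k_w, (2, 1)), 'b': jnp.zeros((1,))}
for epoch in range(10):
    key, shuffle_key = jax.random.split(key)  # new shuffle order every epoch
    for xb, yb in iterate_batches(shuffle_key, x_train, y_train, batch_size=32):
        params, _ = train_step(params, xb, yb, 0.05)

test_loss = float(eval_step(params, x_test, y_test))
print(f"test loss after 10 epochs: {test_loss:.4f}")
```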