This readme has a subsection for every example *.py file.
Please follow the instructions in README.md to install torchax, then install the requirements for all of the examples with

```bash
pip install -r requirements.txt
```
## basic_training.py

This file was constructed by first copying and pasting code fragments from this PyTorch training tutorial: https://pytorch.org/tutorials/beginner/introyt/trainingyt.html, then adding a few lines of code to move each torch.Tensor onto an XLA device.
Example:

```python
state_dict = pytree.tree_map_only(torch.Tensor,
                                  torchax.tensor.move_to_device,
                                  state_dict)
```
This fragment moves every tensor in the state_dict to an XLA device; the converted state_dict is then passed back to the model via load_state_dict.
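For context, here is a minimal sketch of how this fragment might fit into the model setup. GarmentClassifier is the model defined in the linked tutorial; the pytree import path and the assign=True argument are assumptions for this sketch, not necessarily what basic_training.py does:

```python
import torch
import torch.utils._pytree as pytree
import torchax

# GarmentClassifier is assumed to be defined as in the linked tutorial.
model = GarmentClassifier()

# Convert every torch.Tensor in the state_dict into an XLA tensor.
state_dict = pytree.tree_map_only(torch.Tensor,
                                  torchax.tensor.move_to_device,
                                  model.state_dict())

# assign=True (an assumption here) makes the XLA tensors replace the
# parameters outright instead of being copied into the CPU parameters.
model.load_state_dict(state_dict, assign=True)
```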
Then you can train the model as usual. This shows the minimum needed to train a model on XLA devices. The performance is not optimal because we didn't use jax.jit; this is intentional, as the example is meant to showcase the minimum code change.
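To make the "minimum code change" concrete, here is a hedged sketch of one training step; loss_fn, optimizer, model, and training_loader are assumed to be set up as in the linked tutorial, and the only torchax-specific lines are the ones moving the batch to the XLA device:

```python
# A minimal sketch: the tutorial's training loop, unchanged except
# for moving each batch onto the XLA device before the forward pass.
for inputs, labels in training_loader:
    inputs = torchax.tensor.move_to_device(inputs)
    labels = torchax.tensor.move_to_device(labels)

    optimizer.zero_grad()
    outputs = model(inputs)          # runs on the XLA device
    loss = loss_fn(outputs, labels)
    loss.backward()                  # plain torch autograd, no jax.jit
    optimizer.step()
```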
Example run:
```
(xla2) hanq-macbookpro:examples hanq$ python basic_training.py
Training set has 60000 instances
Validation set has 10000 instances
Bag Dress Sneaker T-shirt/top
tensor([[0.8820, 0.3807, 0.3010, 0.9266, 0.7253, 0.9265, 0.0688, 0.4567, 0.7035,
0.2279],
[0.3253, 0.1558, 0.1274, 0.2776, 0.2590, 0.4169, 0.1881, 0.7423, 0.4561,
0.5985],
[0.5067, 0.4514, 0.9758, 0.6088, 0.7438, 0.6811, 0.9609, 0.3572, 0.4504,
0.8738],
[0.1850, 0.1217, 0.8551, 0.2120, 0.9902, 0.7623, 0.1658, 0.6980, 0.3086,
0.5709]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.325265645980835
EPOCH 1:
batch 1000 loss: 1.041275198560208
batch 2000 loss: 0.6450189483696595
batch 3000 loss: 0.5793989677671343
batch 4000 loss: 0.5170258888280951
batch 5000 loss: 0.4920090722264722
batch 6000 loss: 0.48910293977567926
batch 7000 loss: 0.48058812761632724
batch 8000 loss: 0.47159107415075413
batch 9000 loss: 0.4712311488997657
batch 10000 loss: 0.4675815168160479
batch 11000 loss: 0.43210567891132085
batch 12000 loss: 0.445208148030797
batch 13000 loss: 0.4119230824254337
batch 14000 loss: 0.4190662656680215
batch 15000 loss: 0.4094535468676477
LOSS train 0.4094535468676477 valid XLA
```
## basic_training_jax.py

This file was constructed by first copying and pasting code fragments from the same PyTorch training tutorial: https://pytorch.org/tutorials/beginner/introyt/trainingyt.html, then replacing the torch optimizer with an optax optimizer and computing gradients with jax.grad instead of torch.Tensor.backward().
Then you can train the model using the JAX ecosystem's training loop. This is meant to showcase how easy it is to integrate with JAX; a sketch of the resulting training step is shown below.
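For illustration, a hedged sketch of what that replacement might look like; model_fn, params, loss_fn, and training_loader are assumed names for this sketch, not necessarily the ones used in basic_training_jax.py:

```python
import jax
import optax

# An optax optimizer takes the place of the torch optimizer.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

def compute_loss(params, inputs, labels):
    # model_fn is assumed to be a pure function over JAX arrays.
    outputs = model_fn(params, inputs)
    return loss_fn(outputs, labels)

for inputs, labels in training_loader:
    # jax.grad replaces torch.Tensor.backward(): it returns the
    # gradient of the scalar loss with respect to params.
    grads = jax.grad(compute_loss)(params, inputs, labels)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
```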
Example run:
```
(xla2) hanq-macbookpro:examples hanq$ python basic_training_jax.py
Training set has 60000 instances
Validation set has 10000 instances
Pullover Ankle Boot Pullover Ankle Boot
tensor([[0.5279, 0.8340, 0.3131, 0.8608, 0.3668, 0.6192, 0.7453, 0.3261, 0.8872,
0.1854],
[0.7414, 0.8309, 0.8127, 0.8866, 0.2475, 0.2664, 0.0327, 0.6918, 0.6010,
0.2766],
[0.3304, 0.9135, 0.2762, 0.6737, 0.0480, 0.6150, 0.5610, 0.5804, 0.9607,
0.6450],
[0.9464, 0.9439, 0.3122, 0.1814, 0.1194, 0.5012, 0.2058, 0.1170, 0.7377,
0.7453]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.4054245948791504
EPOCH 1:
batch 1000 loss: 1.0705260595591972
batch 2000 loss: 1.0997755021179327
batch 3000 loss: 1.0186579653513108
batch 4000 loss: 0.9090727646966116
batch 5000 loss: 0.8309370622411024
batch 6000 loss: 0.8702225417760783
batch 7000 loss: 0.8750176187023462
batch 8000 loss: 0.9652624803795453
batch 9000 loss: 0.8688667197711766
batch 10000 loss: 0.8021814124770199
batch 11000 loss: 0.8000540231048071
batch 12000 loss: 0.9150884484921057
batch 13000 loss: 0.819690621060171
batch 14000 loss: 0.8569030471532278
batch 15000 loss: 0.8740896808278603
LOSS train 0.8740896808278603 valid 2.3132264614105225
```