# PyTorch on XLA Devices PyTorch runs on XLA devices, like TPUs, with the [torch_xla package](https://github.com/pytorch/xla/). This document describes how to run your models on these devices. ## Creating an XLA Tensor PyTorch/XLA adds a new `xla` device type to PyTorch. This device type works just like other PyTorch device types. For example, here's how to create and print an XLA tensor: ```python import torch import torch_xla import torch_xla.core.xla_model as xm t = torch.randn(2, 2, device=xm.xla_device()) print(t.device) print(t) ``` This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and `xm.xla_device()` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors. For example, XLA tensors can be added together: ```python t0 = torch.randn(2, 2, device=xm.xla_device()) t1 = torch.randn(2, 2, device=xm.xla_device()) print(t0 + t1) ``` Or matrix multiplied: ```python print(t0.mm(t1)) ``` Or used with neural network modules: ```python l_in = torch.randn(10, device=xm.xla_device()) linear = torch.nn.Linear(10, 20).to(xm.xla_device()) l_out = linear(l_in) print(l_out) ``` Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like ```python l_in = torch.randn(10, device=xm.xla_device()) linear = torch.nn.Linear(10, 20) l_out = linear(l_in) print(l_out) # Input tensor is not an XLA tensor: torch.FloatTensor ``` will throw an error since the `torch.nn.Linear` module is on the CPU. ## Running Models on XLA Devices Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device and multiple devices with XLA multi-processing. ### Running on a Single XLA Device The following snippet shows a network training on a single XLA device: ```python import torch_xla.core.xla_model as xm device = xm.xla_device() model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) for data, target in train_loader: optimizer.zero_grad() data = data.to(device) target = target.to(device) output = model(data) loss = loss_fn(output, target) loss.backward() optimizer.step() xm.mark_step() ``` This snippet highlights how easy it is to switch your model to run on XLA. The model definition, dataloader, optimizer and training loop can work on any device. The only XLA-specific code is a couple lines that acquire the XLA device and mark the step. Calling `xm.mark_step()` at the end of each training iteration causes XLA to execute its current graph and update the model's parameters. See [XLA Tensor Deep Dive](#xla-tensor-deep-dive) for more on how XLA creates graphs and runs operations. ### Running on Multiple XLA Devices with Multi-processing PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how: ```python import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl import torch_xla.distributed.xla_multiprocessing as xmp def _mp_fn(index): device = xm.xla_device() mp_device_loader = pl.MpDeviceLoader(train_loader, device) model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) for data, target in mp_device_loader: optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() xm.optimizer_step(optimizer) if __name__ == '__main__': xmp.spawn(_mp_fn, args=()) ``` There are three differences between this multi-device snippet and the previous single device snippet: - `xmp.spawn()` creates the processes that each run an XLA device. - `MpDeviceLoader` loads the training data onto each device. - `xm.optimizer_step(optimizer)` consolidates the gradients between cores and issues the XLA device step computation. The model definition, optimizer definition and training loop remain the same. > **NOTE:** It is important to note that, when using multi-processing, the user can start retrieving and accessing XLA devices only from within the target function of `xmp.spawn()` (or any function which has `xmp.spawn()` as parent in the call stack). See the [full multiprocessing example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py) for more on training a network on multiple XLA devices with multi-processing. ## XLA Tensor Deep Dive Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique. ### XLA Tensors are Lazy CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example. Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device. For more information about our lazy tensor design, you can read [this paper](https://arxiv.org/pdf/2102.13267.pdf). ### XLA Tensors and bFloat16 PyTorch/XLA can use the [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) datatype when running on TPUs. In fact, PyTorch/XLA handles float types (`torch.float` and `torch.double`) differently on TPUs. This behavior is controlled by the `XLA_USE_BF16` and `XLA_DOWNCAST_BF16` environment variable: - By default both `torch.float` and `torch.double` are `torch.float` on TPUs. - If `XLA_USE_BF16` is set, then `torch.float` and `torch.double` are both `bfloat16` on TPUs. - If `XLA_DOWNCAST_BF16` is set, then `torch.float` is `bfloat16` on TPUs and `torch.double` is `float32` on TPUs. - If a PyTorch tensor has `torch.bfloat16` data type, this will be directly mapped to the TPU `bfloat16` (XLA `BF16` primitive type). Developers should note that *XLA tensors on TPUs will always report their PyTorch datatype* regardless of the actual datatype they're using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype. Depending on how your code operates, this conversion triggered by the type of processing unit can be important. ### Memory Layout The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor's memory layout for better performance. ### Moving XLA Tensors to and from the CPU XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved then the data its viewing is also copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it. Again, depending on how your code operates, appreciating and accommodating this transition can be important. ### Saving and Loading XLA Tensors XLA tensors should be moved to the CPU before saving, as in the following snippet: ```python import torch import torch_xla import torch_xla.core.xla_model as xm device = xm.xla_device() t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) tensors = (t0.cpu(), t1.cpu()) torch.save(tensors, 'tensors.pt') tensors = torch.load('tensors.pt') t0 = tensors[0].to(device) t1 = tensors[1].to(device) ``` This lets you put the loaded tensors on any available device, not just the one on which they were initialized. Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it is recommended that you recreate them after the tensors have been loaded and moved to their destination device(s). A utility API is provided to save data by taking care of previously moving it to CPU: ```python import torch import torch_xla import torch_xla.core.xla_model as xm xm.save(model.state_dict(), path) ``` In case of multiple devices, the above API will only save the data for the master device ordinal (0). In case where memory is limited compared to the size of the model parameters, an API is provided that reduces the memory footprint on the host: ```python import torch_xla.utils.serialization as xser xser.save(model.state_dict(), path) ``` This API streams XLA tensors to CPU one at a time, reducing the amount of host memory used, but it requires a matching load API to restore: ```python import torch_xla.utils.serialization as xser state_dict = xser.load(path) model.load_state_dict(state_dict) ``` Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future. ## Further Reading Additional documentation is available at the [PyTorch/XLA repo](https://github.com/pytorch/xla/). More examples of running networks on TPUs are available [here](https://github.com/pytorch-tpu/examples).