Closed
Description
The JAX backend currently offers only very basic device management. We would like to improve it so that users can choose the specific devices to run on, as they already can with our PyTorch models. JAX NumPy arrays should then be moved to the chosen device automatically with jax.device_put, similar to what the @auto_move_data decorator does.
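A minimal sketch of what such a decorator could look like, assuming JAX 0.4 or newer (for jax.Array and jax.Device). The names move_to_device, auto_move_jax_data and weighted_sum are hypothetical and not part of any existing API. Only jax.Array leaves are moved; Python scalars and other arguments pass through unchanged.

```python
import functools
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp


def move_to_device(tree: Any, device: jax.Device) -> Any:
    """Put every jax.Array leaf of a pytree on `device`; other leaves pass through."""
    return jax.tree_util.tree_map(
        lambda x: jax.device_put(x, device) if isinstance(x, jax.Array) else x,
        tree,
    )


def auto_move_jax_data(device: Optional[jax.Device] = None) -> Callable:
    """Hypothetical decorator: move jax.Array arguments to `device` before calling."""
    # Fall back to the first available device when none is specified.
    target = device if device is not None else jax.devices()[0]

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # args and kwargs form one pytree, so a single tree_map handles both.
            args, kwargs = move_to_device((args, kwargs), target)
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# Usage: pin a function's array inputs to the first CPU device.
cpu = jax.devices("cpu")[0]


@auto_move_jax_data(device=cpu)
def weighted_sum(params, x):
    return jnp.sum(params["w"] * x)


result = weighted_sum({"w": jnp.ones(3)}, jnp.arange(3.0))
```

Filtering on jax.Array (rather than calling jax.device_put on the whole argument tree) keeps non-array arguments such as flags or integers from being converted into device arrays.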
Helpful issue: jax-ml/jax#1914
We will also need a way to move an existing train state across devices: google/flax#1783
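One possible approach, sketched under the assumption that the train state is a JAX pytree (as Flax's TrainState is): jax.device_put accepts a whole pytree, so a single call moves every array leaf and rebuilds the same structure. The TrainState class below is a hypothetical NamedTuple stand-in, not Flax's class, and move_train_state is a hypothetical helper name.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class TrainState(NamedTuple):
    """Hypothetical stand-in for a Flax TrainState; NamedTuples are JAX pytrees."""

    step: jax.Array
    params: dict
    opt_state: dict


def move_train_state(state: Any, device: jax.Device) -> Any:
    """Return a copy of `state` whose array leaves all live on `device`."""
    # jax.device_put maps over every leaf of a pytree and rebuilds the same structure.
    return jax.device_put(state, device)


state = TrainState(
    step=jnp.array(0),
    params={"w": jnp.ones((2, 2))},
    opt_state={"mu": jnp.zeros((2, 2))},
)
cpu = jax.devices("cpu")[0]
moved = move_train_state(state, cpu)
```

For Flax's own TrainState, apply_fn and tx are declared as non-pytree fields, so they should be carried along unchanged while params and opt_state are moved.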