Granular device management for Jax #1553

@justjhong

Description

The JAX backend currently has only a very simple device management option. We would like to improve it so that users can specify the devices to run on, as they can with our PyTorch models. This should automatically move JAX NumPy arrays to the selected device using jax.device_put, similar to what the @auto_move_data decorator does.
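A minimal sketch of what such a decorator could look like, assuming the target device is fixed when the decorator is applied. The name `move_to_device` is hypothetical (it is not an existing scvi-tools or JAX API); only `jax.device_put` and `jax.tree_util.tree_map` are real JAX functions. Only array leaves are moved, so Python scalars and strings passed as arguments are left untouched.

```python
import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np


def move_to_device(device: Any) -> Callable:
    """Hypothetical decorator: commit array arguments to `device` before the call.

    Only array leaves (jax.Array or np.ndarray) are moved; other leaves such as
    Python scalars, strings, or None are passed through unchanged.
    """

    def _put_if_array(x: Any) -> Any:
        if isinstance(x, (jax.Array, np.ndarray)):
            return jax.device_put(x, device)
        return x

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # tree_map walks nested tuples/lists/dicts, so a params dict passed
            # as a single argument has all of its arrays moved as well.
            args, kwargs = jax.tree_util.tree_map(_put_if_array, (args, kwargs))
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# Usage: pin a toy loss function to the first CPU device.
cpu = jax.devices("cpu")[0]


@move_to_device(cpu)
def loss(params: dict, x: Any, scale: float = 1.0) -> jax.Array:
    return scale * jnp.sum(params["w"] * x)
```

Because the inputs are committed to `cpu`, the computation inside `loss` runs there as well; on a multi-GPU machine one could pass e.g. `jax.devices("gpu")[1]` instead.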

Helpful issue: jax-ml/jax#1914

We will need a way to move an existing train state across devices: google/flax#1783
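One possible approach, sketched under the assumption that the train state is a pytree whose leaves are arrays: `jax.device_put` accepts a whole pytree, so the state can be moved in one call. `ToyTrainState` and `move_train_state` are hypothetical stand-ins (a NamedTuple is a pytree); whether a real Flax train state moves this cleanly depends on its fields, which is what google/flax#1783 discusses. Note that `jax.device_put` would also convert any non-array leaves (e.g. plain Python ints) into arrays.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyTrainState(NamedTuple):
    """Hypothetical stand-in for a train state; NamedTuples are pytrees."""

    step: jax.Array
    params: dict
    opt_state: dict


def move_train_state(state: Any, device: Any) -> Any:
    """Return a copy of `state` with every leaf committed to `device`."""
    # device_put maps over the pytree and preserves its structure/type.
    return jax.device_put(state, device)


cpu = jax.devices("cpu")[0]
state = ToyTrainState(
    step=jnp.array(0),
    params={"w": jnp.ones((2, 2))},
    opt_state={"mu": jnp.zeros((2, 2))},
)
moved = move_train_state(state, cpu)
```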
