Skip to content

[SPMD][PoC] XLAShardedTensor & mark_sharding API #3476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
May 31, 2022
Merged

Conversation

yeounoh
Copy link
Contributor

@yeounoh yeounoh commented Apr 6, 2022

This contributes to #3871 . Test GSPMD feature in PyTorch/XLA with a single host 8 core setup. This PR aims to provide example implementations of

  • XLAShardedTensor and mark_sharding API to annotate different sharding strategies (e.g., tiled, replicated)
  • Extended XLATensor and IR node implementation to support XLA sharding annotation and clearing
  • Partitioning HLO computation graphs for GSPMD

Additionally, the full PoC example requires removing a constraint where only one core can be assigned to a replica. This is not feasible with the current runtime client, and I will also investigate and see if we can remove such a restriction and compile & execute the partitioned computation in XRT.

We will also merge with the TPU PjRT sometime soon, replacing XRT with PjRT and continue the SPMD experimentation, use PjRT SPMD as a reference implementation. cc @will-cromar

@yeounoh yeounoh requested review from miladm and JackCaoG April 6, 2022 22:57
@yeounoh yeounoh marked this pull request as draft April 6, 2022 22:58
@miladm miladm added the DO_NOT_MERGE Not for merging. label Apr 6, 2022
@adamantboy
Copy link

hello, could you please tell me the commit id of your pytorch?

@yeounoh yeounoh force-pushed the xla_spmd_test branch 2 times, most recently from 4b4cbac to 109803b Compare May 9, 2022 17:36
@yeounoh yeounoh requested a review from will-cromar May 9, 2022 20:26
@yeounoh yeounoh force-pushed the xla_spmd_test branch 21 times, most recently from fed00ac to ff409e4 Compare May 10, 2022 23:17
@yeounoh yeounoh marked this pull request as ready for review May 27, 2022 18:37
@yeounoh yeounoh merged commit 97ab3b8 into master May 31, 2022
@yeounoh yeounoh deleted the xla_spmd_test branch May 31, 2022 15:05
@yeounoh
Copy link
Contributor Author

yeounoh commented May 31, 2022

Landing, I will create a separate PR for PJRT integration.

@yeounoh yeounoh self-assigned this Jul 6, 2022
@yeounoh yeounoh removed the DO_NOT_MERGE Not for merging. label Aug 27, 2022
@yeounoh yeounoh changed the title XLA SPMD PoC implementation with XRT XLAShardedTensor & mark_sharding PoC Aug 29, 2022
@yeounoh yeounoh changed the title XLAShardedTensor & mark_sharding PoC [SPMD][PoC] XLAShardedTensor & mark_sharding API Aug 29, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
distributed SPMD and other distributed things.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants