[Pallas] Lower aten gather using one_hot + sum for TPU compatibility, unblocking cross_entropy #2060

Open: AmesingFlank wants to merge 1 commit into AmesingFlank/stack/25 from AmesingFlank/stack/26


@AmesingFlank (Contributor) commented Apr 20, 2026

Stacked PRs:


[Pallas] Lower aten gather using one_hot + sum for TPU compatibility

TPU Mosaic has very limited lax.gather support, so jnp.take_along_axis
fails during lowering. Instead, implement gather(input, dim, index) as:

  mask = one_hot(index.squeeze(dim), input.shape[dim], dtype=input.dtype)
  result = sum(input * mask, axis=dim, keepdims=True)

Also removes the xfailIfPallas mark from test_cross_entropy since the
gather lowering now works.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
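
For readers who want to try the idea outside the PR, here is a minimal JAX sketch of the one_hot + sum formulation described above. It is not the code from this PR: the function name `gather_via_one_hot` is hypothetical, and the sketch assumes `index.shape[dim] == 1` and that `index` matches `input` on every other axis. It passes `axis=dim` to `jax.nn.one_hot` so that the mask lines up with the input when `dim` is not the last axis. Note that an out-of-range index yields 0 here, and non-finite input values would propagate through the multiply.

```python
import jax
import jax.numpy as jnp


def gather_via_one_hot(x, dim, index):
    """Emulate gather(x, dim, index) for index.shape[dim] == 1 without lax.gather.

    Hypothetical helper for illustration; assumes index matches x on all
    axes other than `dim`.
    """
    dim = dim % x.ndim  # normalize negative dims
    if index.shape[dim] != 1:
        raise ValueError("this sketch only handles index.shape[dim] == 1")
    # Drop the size-1 axis, then put the one-hot axis back at `dim`
    # so that the mask has the same shape as x.
    mask = jax.nn.one_hot(
        jnp.squeeze(index, axis=dim), x.shape[dim], dtype=x.dtype, axis=dim
    )
    # At most one entry per slice along `dim` is nonzero, so the sum selects it.
    return jnp.sum(x * mask, axis=dim, keepdims=True)


logits = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
targets = jnp.array([[1], [3], [0]])

picked = gather_via_one_hot(logits, 1, targets)
reference = jnp.take_along_axis(logits, targets, axis=1)
```

On a backend where `jnp.take_along_axis` lowers normally (for example CPU), `picked` and `reference` should agree, which is a convenient way to sanity-check the formulation before running it through Pallas.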

@meta-cla (Bot) added the CLA Signed label (managed by the Meta Open Source bot) Apr 20, 2026
@AmesingFlank force-pushed the AmesingFlank/stack/26 branch from 1a3b7f5 to 696b52e April 20, 2026 21:39
AmesingFlank added a commit that referenced this pull request Apr 20, 2026
stack-info: PR: #2060, branch: AmesingFlank/stack/26
@AmesingFlank marked this pull request as draft April 20, 2026 22:00
@AmesingFlank changed the base branch from AmesingFlank/stack/25 to main April 20, 2026 22:00
@AmesingFlank force-pushed the AmesingFlank/stack/26 branch from 696b52e to c0cabf4 April 20, 2026 22:00
@AmesingFlank changed the title from "[Pallas] Lower aten gather using one_hot + sum for TPU compatibility" to "[Pallas] Lower aten gather using one_hot + sum for TPU compatibility, unblocking cross_entropy" Apr 20, 2026
@AmesingFlank changed the base branch from main to AmesingFlank/stack/25 April 20, 2026 22:00
@AmesingFlank marked this pull request as ready for review April 20, 2026 22:00
@norx1991 (Contributor)

Is this somewhat related to #2054 ?

@thcmbs (Collaborator) commented Apr 23, 2026

> Is this somewhat related to #2054 ?

My 2 cents: they are orthogonal.

I think both can land independently. They share the same one-hot-matmul trick, so we can extract a shared helper as a follow-up if it's worth the churn.
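
For context on the "one-hot-matmul trick" mentioned above, here is a generic JAX sketch of the pattern: an indexed row lookup written as a one-hot matrix multiplied by a table, which avoids a gather op entirely. The contents of #2054 are not shown in this thread, so this is only an illustration of the general technique, and the helper name `one_hot_rows` is hypothetical rather than something taken from either PR.

```python
import jax
import jax.numpy as jnp


def one_hot_rows(table, idx):
    """Select table[idx] via a one-hot matmul instead of a gather.

    Hypothetical helper for illustration: table is (num_rows, width),
    idx is a 1-D integer array of row indices.
    """
    # (len(idx), num_rows) mask with a single 1 per row.
    mask = jax.nn.one_hot(idx, table.shape[0], dtype=table.dtype)
    # Each output row is a weighted sum of table rows with exactly one weight of 1.
    return mask @ table


table = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
idx = jnp.array([2, 0, 3])

rows = one_hot_rows(table, idx)
expected = table[idx]
```

The gather lowering in this PR uses an elementwise multiply followed by a sum rather than a matmul, but both forms rely on the same one-hot mask, which is presumably what a shared helper would factor out.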
