A jagged_hstu_attention example that works on Pallas TPU #2218
Open
AmesingFlank wants to merge 1 commit into
AmesingFlank added a commit that referenced this pull request May 3, 2026
Add a new jagged HSTU attention example that works on both Triton and Pallas backends, unlike the original jagged_hstu_attn.py which is Triton-only.

Key differences from jagged_hstu_attn.py:
- Uses hl.grid([num_batches, H]) + hl.tile(start, end) with data-dependent bounds instead of hl.tile([B, H, max_seq_len]) with if-guards. This avoids wasted computation on padding and eliminates the max_seq_len parameter.
- Compatible with both Triton and Pallas backends (the original uses tensor-derived if-predicates that Pallas doesn't support).
- Uses >= for causal mask (original uses >, excluding self-attention).
- Provides autotune_baseline_fn so the autotuner works even when the default config doesn't fit in TPU VMEM.

Benchmark results (B=8, max_seq=128, H=8, D=64):
- Triton (H100): ~210x over PyTorch reference
- Pallas (TPU v4): ~15x over PyTorch reference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2218, branch: AmesingFlank/stack/40
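The difference between iterating a padded `max_seq_len` grid with guards and iterating data-dependent `[start, end)` ranges can be sketched in plain Python. This is only an illustration of the loop structure and the `>=` versus `>` causal mask, not the Helion kernel itself: the function names, the single-head layout, the `alpha` scale, and the omission of any normalization factor are all assumptions made for this sketch.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def _accumulate(out, q, k, v, i, j, alpha):
    # out[i] += silu(alpha * <q[i], k[j]>) * v[j]
    score = silu(alpha * sum(a * b for a, b in zip(q[i], k[j])))
    for d in range(len(v[j])):
        out[i][d] += score * v[j][d]


def jagged_attention(q, k, v, seq_offsets, alpha, include_self=True):
    """Loop only over each sequence's own [start, end) token range."""
    out = [[0.0] * len(v[0]) for _ in q]
    visited = 0
    for b in range(len(seq_offsets) - 1):
        start, end = seq_offsets[b], seq_offsets[b + 1]
        for i in range(start, end):
            for j in range(start, end):
                visited += 1
                # include_self=True is the ">=" mask, False is the ">" mask.
                if (i >= j) if include_self else (i > j):
                    _accumulate(out, q, k, v, i, j, alpha)
    return out, visited


def padded_attention(q, k, v, seq_offsets, max_seq_len, alpha, include_self=True):
    """Same math on a dense max_seq_len grid, skipping padding with guards."""
    out = [[0.0] * len(v[0]) for _ in q]
    visited = 0
    for b in range(len(seq_offsets) - 1):
        start, end = seq_offsets[b], seq_offsets[b + 1]
        for qi in range(max_seq_len):
            for kj in range(max_seq_len):
                visited += 1
                i, j = start + qi, start + kj
                if i >= end or j >= end:  # guard against padding positions
                    continue
                if (i >= j) if include_self else (i > j):
                    _accumulate(out, q, k, v, i, j, alpha)
    return out, visited


# Two sequences of lengths 2 and 3, concatenated along the token axis.
q = [[0.1 * (t + 1), -0.2 * t] for t in range(5)]
k = [[0.3 * t + 0.2, 0.1] for t in range(5)]
v = [[1.0 * t, 2.0] for t in range(5)]
seq_offsets = [0, 2, 5]

jag, jag_visited = jagged_attention(q, k, v, seq_offsets, alpha=0.5)
pad, pad_visited = padded_attention(q, k, v, seq_offsets, max_seq_len=4, alpha=0.5)
strict, _ = jagged_attention(q, k, v, seq_offsets, alpha=0.5, include_self=False)

print(jag_visited, pad_visited)  # 13 32
```

Both functions produce the same output rows, but the padded variant visits every position of the dense grid and needs `max_seq_len` as an extra argument. With the strict `>` mask, the first token of each sequence attends to nothing and its output row stays zero.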
243e692 to 1bc1938
affa93a to 93ddba7
This was referenced May 3, 2026
Merged
93ddba7 to 0da969e
0da969e to 7a9f432
23d7cdb to 129e8a2
AmesingFlank added a commit that referenced this pull request May 4, 2026
Add a new jagged HSTU attention example that works on both Triton and Pallas backends, unlike the original jagged_hstu_attn.py, which is Triton-only.

This kernel works on Pallas TPU because it does not depend on any integer indexing. Instead, it uses tiled access on a jagged dimension, which translates to a DMA that dynamically slices dimension -3 while loading dimensions -2 and -1 in full (`pltpu.make_async_copy(q[pl.ds(start, size), :, :])`), the same way RaggedPagedAttention handles jagged data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2218, branch: AmesingFlank/stack/40
129e8a2 to 24e75f8
24e75f8 to ea56f6e
jansel approved these changes May 5, 2026
norx1991 approved these changes May 6, 2026
This was referenced May 14, 2026
Add a new jagged HSTU attention example that works on Pallas backends, unlike the original jagged_hstu_attn.py, which is Triton-only.

This kernel works on Pallas TPU because it does not depend on any integer indexing. Instead, it uses tiled access on a jagged dimension, which translates to a DMA that dynamically slices dimension -3 while loading dimensions -2 and -1 in full (`pltpu.make_async_copy(q[pl.ds(start, size), :, :])`), the same way RaggedPagedAttention handles jagged data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
stack-info: PR: #2218, branch: AmesingFlank/stack/40
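The access pattern described above, contiguous tiles along the token axis with the head and head-dim axes loaded whole, can be mimicked with nested Python lists. This is a stand-in for the idea only, not Pallas code: `tile_ranges` and `load_tile` are hypothetical helpers, and shortening the last tile is an assumption of this sketch, since the PR does not say how a partial final tile is handled in the real lowering.

```python
def tile_ranges(start, end, block):
    """Yield (tile_start, tile_size) pairs that cover [start, end)."""
    pos = start
    while pos < end:
        size = min(block, end - pos)  # last tile is shortened in this sketch
        yield pos, size
        pos += size


def load_tile(q, tile_start, size):
    # One contiguous slice of the leading (token) axis. Every head and the
    # full head_dim come along, so no per-element integer indexing is needed.
    return q[tile_start:tile_start + size]


# q has layout [total_tokens][heads][head_dim].
total_tokens, heads, head_dim = 7, 2, 3
q = [
    [[float(t * 100 + h * 10 + d) for d in range(head_dim)] for h in range(heads)]
    for t in range(total_tokens)
]

# Two jagged sequences: tokens [0, 3) and [3, 7).
seq_offsets = [0, 3, 7]
tiles = [
    list(tile_ranges(seq_offsets[b], seq_offsets[b + 1], block=2))
    for b in range(len(seq_offsets) - 1)
]

for per_sequence in tiles:
    for tile_start, size in per_sequence:
        tile = load_tile(q, tile_start, size)
        print(tile_start, size, len(tile), len(tile[0]), len(tile[0][0]))
```

Each tile is described by a start offset that depends on `seq_offsets` and a size, which mirrors the role of `pl.ds(start, size)` on the leading axis in the snippet quoted in the commit message.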
Stacked PRs:
A jagged_hstu_attention example that works on Pallas TPU