A jagged_hstu_attention example that works on Pallas TPU #2218

Open
AmesingFlank wants to merge 1 commit into main from AmesingFlank/stack/40

Conversation

@AmesingFlank (Contributor) commented May 3, 2026

A jagged_hstu_attention example that works on Pallas TPU

Add a new jagged HSTU attention example that works on
Pallas backends, unlike the original jagged_hstu_attn.py, which is
Triton-only. The kernel works on Pallas TPU because it does not depend
on any integer indexing. Instead, it uses tiled access on the jagged
dimension, which translates to a DMA that dynamically slices the
third-from-last dimension while loading the last two dimensions in full
(pltpu.make_async_copy(q[pl.ds(start, size), :, :]), the same way
RaggedPagedAttention handles jagged data).
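
The access pattern described above can be sketched in plain Python, without Pallas or Helion. This is only an illustration under stated assumptions: the names `jagged_tiles`, `load_tile`, `seq_offsets`, and `tile_size` are hypothetical and not taken from the PR, the tensor is modeled as nested lists of shape [total_len, H, D], and clamping the last tile of each sequence is a simplification that may differ from how the real kernel handles partial tiles.

```python
from typing import Iterator, List, Tuple

# Hypothetical layout: a jagged tensor stored as nested lists with shape
# [total_len, H, D]; rows seq_offsets[b]:seq_offsets[b + 1] belong to sequence b.
Row = List[List[float]]  # one [H, D] slab


def jagged_tiles(seq_offsets: List[int], tile_size: int) -> Iterator[Tuple[int, int, int]]:
    """Yield (batch, start, size) for every tile along the jagged dimension."""
    for b in range(len(seq_offsets) - 1):
        seq_start, seq_end = seq_offsets[b], seq_offsets[b + 1]
        for start in range(seq_start, seq_end, tile_size):
            # Clamp the last tile so it never crosses into the next sequence
            # (an assumption made for this sketch).
            yield b, start, min(tile_size, seq_end - start)


def load_tile(values: List[Row], start: int, size: int) -> List[Row]:
    """Slice only the leading dimension; every row keeps its full [H, D] slab."""
    return values[start:start + size]


# Two sequences of lengths 2 and 3, with H = 1 and D = 2.
q = [[[float(i), float(i) + 0.5]] for i in range(5)]
seq_offsets = [0, 2, 5]
tiles = list(jagged_tiles(seq_offsets, tile_size=2))
```

Each tile is a contiguous range of rows, so no per-element integer index is needed: only a start and a size on the leading dimension, which is the shape of access that `pl.ds(start, size)` expresses in the DMA call quoted above.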

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

AmesingFlank added a commit that referenced this pull request May 3, 2026
Add a new jagged HSTU attention example that works on both Triton and
Pallas backends, unlike the original jagged_hstu_attn.py which is
Triton-only.

Key differences from jagged_hstu_attn.py:
- Uses hl.grid([num_batches, H]) + hl.tile(start, end) with
  data-dependent bounds instead of hl.tile([B, H, max_seq_len]) with
  if-guards. This avoids wasted computation on padding and eliminates
  the max_seq_len parameter.
- Compatible with both Triton and Pallas backends (the original uses
  tensor-derived if-predicates that Pallas doesn't support).
- Uses >= for causal mask (original uses >, excluding self-attention).
- Provides autotune_baseline_fn so the autotuner works even when the
  default config doesn't fit in TPU VMEM.
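
Two of the differences listed above can be illustrated with a small plain-Python sketch. It is not the Helion kernel: `padded_tiles`, `bounded_tiles`, and `causal_mask` are hypothetical helpers, and treating the row index `i` as the query position and the column index `j` as the key position is an assumption made for illustration.

```python
from typing import List, Tuple


def padded_tiles(num_batches: int, max_seq_len: int, tile_size: int) -> List[Tuple[int, int]]:
    """Tiles visited when every sequence is iterated up to max_seq_len.

    Starts are relative to each sequence; tiles past the true length would
    need a guard to skip their work.
    """
    return [(b, s) for b in range(num_batches) for s in range(0, max_seq_len, tile_size)]


def bounded_tiles(seq_offsets: List[int], tile_size: int) -> List[Tuple[int, int]]:
    """Tiles visited when the loop bounds come from the sequence offsets.

    Starts are absolute row offsets into the jagged tensor.
    """
    return [
        (b, s)
        for b in range(len(seq_offsets) - 1)
        for s in range(seq_offsets[b], seq_offsets[b + 1], tile_size)
    ]


def causal_mask(n: int, include_self: bool) -> List[List[bool]]:
    """mask[i][j] is True when position i may attend to position j."""
    return [[(i >= j) if include_self else (i > j) for j in range(n)] for i in range(n)]


# Three sequences of lengths 3, 1 and 8; the longest sets max_seq_len = 8.
seq_offsets = [0, 3, 4, 12]
num_padded = len(padded_tiles(num_batches=3, max_seq_len=8, tile_size=4))
num_bounded = len(bounded_tiles(seq_offsets, tile_size=4))

# Number of allowed (i, j) pairs for a length-3 sequence under each mask.
allowed_with_self = sum(map(sum, causal_mask(3, include_self=True)))
allowed_without_self = sum(map(sum, causal_mask(3, include_self=False)))
```

With these inputs the padded loop visits more tiles than the bounded one, and the `>=` mask differs from the `>` mask only on the diagonal, where each position attends to itself.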

Benchmark results (B=8, max_seq=128, H=8, D=64):
- Triton (H100): ~210x over PyTorch reference
- Pallas (TPU v4): ~15x over PyTorch reference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2218, branch: AmesingFlank/stack/40
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/39 branch from 243e692 to 1bc1938 Compare May 3, 2026 21:54
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/40 branch from affa93a to 93ddba7 Compare May 3, 2026 21:54
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 3, 2026
@AmesingFlank AmesingFlank marked this pull request as draft May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/39 to main May 3, 2026 23:14
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/40 branch from 93ddba7 to 0da969e Compare May 3, 2026 23:14
@AmesingFlank AmesingFlank changed the title Add jagged_hstu_attn_2 example: cross-backend HSTU attention A jagged_hstu_attention example that works on Pallas TPU May 3, 2026
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/39 May 3, 2026 23:14
@AmesingFlank AmesingFlank marked this pull request as ready for review May 3, 2026 23:14
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/39 to main May 4, 2026 01:46
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/40 branch from 0da969e to 7a9f432 Compare May 4, 2026 01:46
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/39 May 4, 2026 01:46
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 01:46
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 01:52
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 03:21
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/39 to main May 4, 2026 03:21
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/40 branch from 23d7cdb to 129e8a2 Compare May 4, 2026 03:22
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/39 May 4, 2026 03:22
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 03:22
AmesingFlank added a commit that referenced this pull request May 4, 2026
Add a new jagged HSTU attention example that works on both Triton and
Pallas backends, unlike the original jagged_hstu_attn.py, which is
Triton-only. The kernel works on Pallas TPU because it does not depend
on any integer indexing. Instead, it uses tiled access on the jagged
dimension, which translates to a DMA that dynamically slices the
third-from-last dimension while loading the last two dimensions in full
(`pltpu.make_async_copy(q[pl.ds(start, size), :, :])`, the same way
RaggedPagedAttention handles jagged data).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

stack-info: PR: #2218, branch: AmesingFlank/stack/40
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 03:33
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/39 to main May 4, 2026 03:33
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/40 branch from 129e8a2 to 24e75f8 Compare May 4, 2026 03:33
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/39 May 4, 2026 03:33
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 03:33
@AmesingFlank AmesingFlank requested review from jansel, norx1991 and oulgen May 4, 2026 15:09
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 16:45
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/39 to main May 4, 2026 16:45
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/39 May 4, 2026 16:45
@AmesingFlank AmesingFlank marked this pull request as ready for review May 4, 2026 16:45
@AmesingFlank AmesingFlank marked this pull request as draft May 4, 2026 17:55
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/39 to main May 4, 2026 17:55
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/40 branch from 24e75f8 to ea56f6e Compare May 4, 2026 17:55
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/39 May 4, 2026 17:55
Labels

CLA Signed This label is managed by the Meta Open Source bot.

3 participants