Skip to content

[Pallas] Use jnp.where() for branches with tensor-derived conditions but no in-place tensor updates#1936

Closed
AmesingFlank wants to merge 1 commit into
AmesingFlank/stack/4from
AmesingFlank/stack/5
Closed

[Pallas] Use jnp.where() for branches with tensor-derived conditions but no in-place tensor updates#1936
AmesingFlank wants to merge 1 commit into
AmesingFlank/stack/4from
AmesingFlank/stack/5

Conversation

@AmesingFlank
Copy link
Copy Markdown
Contributor

@AmesingFlank AmesingFlank commented Apr 2, 2026

AmesingFlank added a commit that referenced this pull request Apr 2, 2026
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from d7903c3 to 9d03081 Compare April 2, 2026 21:41
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 2, 2026
AmesingFlank added a commit that referenced this pull request Apr 3, 2026
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank marked this pull request as draft April 3, 2026 00:57
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/4 to main April 3, 2026 00:57
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from 9d03081 to fa86c9d Compare April 3, 2026 00:58
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/4 April 3, 2026 00:58
@AmesingFlank AmesingFlank marked this pull request as ready for review April 3, 2026 00:58
AmesingFlank added a commit that referenced this pull request Apr 3, 2026
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank marked this pull request as draft April 3, 2026 01:02
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/4 to main April 3, 2026 01:02
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from fa86c9d to cf7beac Compare April 3, 2026 01:02
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/4 April 3, 2026 01:02
@AmesingFlank AmesingFlank marked this pull request as ready for review April 3, 2026 01:02
AmesingFlank added a commit that referenced this pull request Apr 3, 2026
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank marked this pull request as draft April 3, 2026 01:23
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/4 to main April 3, 2026 01:23
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from cf7beac to 0a69ce6 Compare April 3, 2026 01:23
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/4 April 3, 2026 01:23
@AmesingFlank AmesingFlank marked this pull request as ready for review April 3, 2026 01:23
AmesingFlank added a commit that referenced this pull request Apr 3, 2026
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank marked this pull request as draft April 3, 2026 03:40
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/4 to main April 3, 2026 03:40
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from 0a69ce6 to af01424 Compare April 3, 2026 03:40
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/4 April 3, 2026 03:40
@AmesingFlank AmesingFlank marked this pull request as ready for review April 3, 2026 03:40
AmesingFlank added a commit that referenced this pull request Apr 3, 2026
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank marked this pull request as draft April 3, 2026 03:46
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/4 to main April 3, 2026 03:47
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from af01424 to 5a3c816 Compare April 3, 2026 03:47
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/4 April 3, 2026 03:47
@AmesingFlank AmesingFlank marked this pull request as ready for review April 3, 2026 03:47
AmesingFlank added a commit that referenced this pull request Apr 3, 2026
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank marked this pull request as draft April 3, 2026 05:01
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/4 to main April 3, 2026 05:01
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from 5a3c816 to d5a3bab Compare April 3, 2026 05:01
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/4 April 3, 2026 05:02
@AmesingFlank AmesingFlank marked this pull request as ready for review April 3, 2026 05:02
…but no in-place tensor updates

stack-info: PR: #1936, branch: AmesingFlank/stack/5
@AmesingFlank AmesingFlank marked this pull request as draft April 3, 2026 05:06
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/4 to main April 3, 2026 05:06
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/5 branch from d5a3bab to 6faf98b Compare April 3, 2026 05:06
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/4 April 3, 2026 05:06
@AmesingFlank AmesingFlank marked this pull request as ready for review April 3, 2026 05:07
if graph_info.has_inplace_writes:
raise BackendUnsupported(
"pallas",
"if-statements with tensor-derived predicates and branches has in-place tensor writes. "
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

huh why does this matter? tensor writes should be vmem mutations, which should be supported?

Copy link
Copy Markdown
Contributor Author

@AmesingFlank AmesingFlank Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that the condition in 'lax.cond' must be a scalar value. If it is not guaranteed to be a scalar, then we can't use lax.cond. However, if the branches have no side effects (vmem mutations), the we can use jnp.where, which does allow non-scalar conditions.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when you say scalar, what do you mean? i think even in triton if conditions must be size-1 tensors

Copy link
Copy Markdown
Contributor Author

@AmesingFlank AmesingFlank Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By scalar i specifically mean actual scalars, and not size-1 arrays or 0D arrays. See api doc here https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html. This PR is for handling when cond is NOT an actually scalar, in which case lax.cond can't be used, but we can use 'jnp.where' which does allow arrays, provided there's no side effects in the branches

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so by scalar you mean a compile-time bool? surely it works for traced bools, which are 0-dim bool arrays?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, i think you are right. I was looking at an example (similar to the the one in the test in this PR) where the conditional is an actual size-4 array, but I realized just now that this is because we're now handling tl.grid correctly in Pallas, so the shape of the conditional is wrong.

@AmesingFlank AmesingFlank requested a review from v0i0 April 3, 2026 17:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants