[Pallas] Use jnp.where() for branches with tensor-derived conditions but no in-place tensor updates #1936
AmesingFlank wants to merge 1 commit into
…but no in-place tensor updates stack-info: PR: #1936, branch: AmesingFlank/stack/5
d7903c3 to 9d03081
9d03081 to fa86c9d
fa86c9d to cf7beac
cf7beac to 0a69ce6
0a69ce6 to af01424
af01424 to 5a3c816
5a3c816 to d5a3bab
d5a3bab to 6faf98b
if graph_info.has_inplace_writes:
    raise BackendUnsupported(
        "pallas",
        "if-statements with tensor-derived predicates and branches has in-place tensor writes. "
huh why does this matter? tensor writes should be vmem mutations, which should be supported?
The issue is that the condition in lax.cond must be a scalar value. If it is not guaranteed to be a scalar, then we can't use lax.cond. However, if the branches have no side effects (vmem mutations), then we can use jnp.where, which does allow non-scalar conditions.
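For illustration, here is a minimal sketch of the distinction in plain JAX (outside Pallas, and not the code this PR generates). The function names and branch bodies are made up for the example; the only assumption is that both branches are pure, so evaluating both of them is harmless.

```python
import jax.numpy as jnp
from jax import lax

def branch_with_cond(pred, x):
    # lax.cond runs exactly one branch, so pred must be a single boolean value.
    return lax.cond(pred, lambda v: v * 2.0, lambda v: v - 1.0, x)

def branch_with_where(pred, x):
    # jnp.where computes both branch values and selects elementwise.
    # This is only equivalent to branching because neither branch has side effects.
    return jnp.where(pred, x * 2.0, x - 1.0)

x = jnp.arange(4.0)   # [0., 1., 2., 3.]
mask = x > 1.0        # non-scalar predicate with shape (4,)

where_out = branch_with_where(mask, x)         # per-element selection
cond_out = branch_with_cond(x.sum() > 1.0, x)  # one predicate for the whole array
```

With these inputs, `where_out` picks `x - 1.0` for the first two elements and `x * 2.0` for the last two, while `cond_out` takes the `x * 2.0` branch for the whole array.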
when you say scalar, what do you mean? i think even in triton if conditions must be size-1 tensors
By scalar I specifically mean actual scalars, and not size-1 arrays or 0-D arrays. See the API doc here: https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html. This PR is for handling the case where cond is NOT an actual scalar, in which case lax.cond can't be used, but we can use jnp.where, which does allow arrays, provided there are no side effects in the branches.
so by scalar you mean a compile-time bool? surely it works for traced bools, which are 0-dim bool arrays?
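A quick way to check this in plain JAX (a sketch, not Pallas-specific; `select` and `cond_rejects_vector_pred` are hypothetical names for this example): under jax.jit the predicate below is a traced 0-dim bool array, and lax.cond accepts it, while a shape-(4,) predicate is rejected.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def select(x):
    pred = x.sum() > 0.0       # traced under jit: a 0-dim bool array
    assert pred.shape == ()    # shapes are static, so this check runs at trace time
    return lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)

pos = select(jnp.ones(4))      # predicate is True at run time
neg = select(-jnp.ones(4))     # predicate is False at run time

def cond_rejects_vector_pred():
    # A predicate with shape (4,) is not a scalar, so lax.cond should raise.
    x = jnp.arange(4.0)
    try:
        lax.cond(x > 1.0, lambda v: v + 1.0, lambda v: v - 1.0, x)
    except Exception:
        return True
    return False
```

So the restriction is on the predicate's shape, not on whether its value is known at compile time.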
Hmm, I think you are right. I was looking at an example (similar to the one in the test in this PR) where the conditional is an actual size-4 array, but I realized just now that this is because we're not handling tl.grid correctly in Pallas, so the shape of the conditional is wrong.
Stacked PRs:
[Pallas] Use jnp.where() for branches with tensor-derived conditions but no in-place tensor updates