Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pallas-Triton] Fix squeeze lowering required sharding argument #25412

Merged
merged 1 commit into from
Dec 18, 2024

Conversation

shangz-ai
Copy link
Contributor

Error comes from this PR: #25103 if there is no sharding argument passed to squeeze.

I observed my pallas kernel (using squeeze) failure due to the above PR, and the kernel was fine before that PR was merged.

jax/_src/pallas/triton/lowering.py Outdated Show resolved Hide resolved
@shangz-ai
Copy link
Contributor Author

@superbobry Here is a very naive reproducer

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from functools import partial

def squeeze_kernel(k_ref, o_ref):
    # Read inputs
    k = k_ref[...]

    # Squeeze and store
    o_ref[...] = jnp.squeeze(k)

@partial(jax.jit, static_argnames='out_type')
def squeeze_k(k: jax.Array, out_type=jnp.float32) -> jax.Array:
    return pl.pallas_call(
        squeeze_kernel,
        out_shape=jax.ShapeDtypeStruct(jnp.squeeze(k).shape, out_type)
    )(k)

if __name__ == "__main__":
    k = jnp.ones((4, 2, 1, 1, 2, 1, 4), dtype=jnp.float32)

    # Run the kernel and print result
    print("Shape before squeezing: \n", k.shape)
    res = squeeze_k(k, out_type=jnp.float32)
    print("Shape after squeezing: \n", res.shape)

The main branch jax will report error as:

msg=_reshape_lowering_rule() missing 1 required keyword-only argument: 'sharding'

but with the PR fix, it can run correctly as:

Shape before squeezing: 
 (4, 2, 1, 1, 2, 1, 4)
Shape after squeezing: 
 (4, 2, 2, 4)

Please let me know if that makes sense to you or need any other information.
Thanks!

@superbobry
Copy link
Collaborator

Can you squash the commits please?

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Dec 17, 2024
@shangz-ai
Copy link
Contributor Author

Can you squash the commits please?

Done

@copybara-service copybara-service bot merged commit 66c8a2e into jax-ml:main Dec 18, 2024
14 of 19 checks passed
@shangz-ai shangz-ai deleted the patch-3 branch December 18, 2024 17:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants