Skip to content

Conversation

@samanklesaria
Copy link
Collaborator

What does this PR do?

  • Adds an optional output_sharding to standard layers just like in jax for use with explicit sharding.

@samanklesaria samanklesaria force-pushed the output_sharding branch 2 times, most recently from 1195087 to 96f0a64 Compare November 19, 2025 19:21
@samanklesaria samanklesaria changed the title Add out_sharding argument to call methods for standard layers Add out_sharding argument to call methods for layers with jax calls that support it Nov 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant