Skip to content

Conversation

@jackopenn
Copy link

What does this PR do?

  • Add out_sharding arguments to linear layers where supported, these include:
    • Conv : jax.lax.conv_general_dilated
    • Embed.attend : jnp.dot
  • Update tests for all linear layers and new methods

Fixes #5155

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

@samanklesaria
Copy link
Collaborator

Thanks for the PR! This is stuff I should have added in #5102 - apologies for missing these modules. This all looks good to me!

@samanklesaria samanklesaria self-requested a review December 29, 2025 15:23
@samanklesaria
Copy link
Collaborator

@jackopenn looks like there's some issues with the tests.

@jackopenn
Copy link
Author

@jackopenn looks like there's some issues with the tests.

Oops. Fixed now :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add support for out_sharding in more layers

2 participants