Adding Tensor_layout for Tensor parallelism for Autosharding #21792
base: master
Conversation
Summary of Changes (Gemini Code Assist): This pull request establishes foundational components for tensor parallelism within the JAX backend, which are crucial for Autosharding. It provides core collective communication primitives.
Code Review
This pull request introduces foundational components for tensor parallelism in Keras, specifically for the JAX backend. It adds all_reduce and all_gather collective operations, which are essential for distributed computations. Additionally, it provides a split_tensor_for_parallelism utility for sharding tensors across devices. The changes are well-tested, covering both even and uneven tensor splitting. My review includes a few suggestions to improve documentation accuracy and code simplicity, and to align with the repository's style guide regarding docstring examples.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…keras into tensor_parallel
Can you rebase to make the tests pass?
Codecov Report
❌ Patch coverage details and impacted files:

@@            Coverage Diff             @@
##           master   #21792      +/-   ##
==========================================
- Coverage   82.66%   82.25%    -0.42%
==========================================
  Files         577      580        +3
  Lines       59453    59843      +390
  Branches     9320     9423      +103
==========================================
+ Hits        49148    49222       +74
- Misses       7902     8181      +279
- Partials     2403     2440       +37

View full report in Codecov by Sentry.
        A tensor slice corresponding to the given `index`.
    """
    if dim == -1:
        split_dim = ops.ndim(tensor) - 1
Is -1 the only possible case? Usually, any negative number is possible, and then you do:

    split_dim = ops.ndim(tensor) + dim

    @pytest.mark.skipif(
        jax.local_device_count() < 2,
        reason="Requires multiple local devices for testing.",
    )
These tests will never run because we don't have a setup with 2 accelerators right now (we have 1 and 8).
You should move these tests to /Users/fhertschuh/Documents/keras/keras/src/backend/jax/distribution_lib_test.py in which we simulate 8 devices.
    return jax.checkpoint(f)


def all_reduce(x, op="sum", axis_name="model"):
Now that I see this, it's a bit weird to have these ops in core: since they would be core ops, every backend would have to implement them.
Also, this is distribution related. Can you move them to keras/src/backend/jax/distribution_lib.py at least for now.
Maybe we'll need a specific namespace for distribution ops, but since they're not exported, I think distribution_lib is fine for now.
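The semantics of these two collectives can be illustrated with a single-process simulation. This is only a sketch of what `all_reduce`/`all_gather` compute across a device group, not the JAX implementation (which would use `jax.lax.psum` / `jax.lax.all_gather` under an `axis_name`); the function names here are hypothetical:

```python
import numpy as np

def simulated_all_reduce(shards, op="sum"):
    # Every participating device ends up with the same reduced tensor.
    stacked = np.stack(shards)
    reduced = stacked.sum(axis=0) if op == "sum" else stacked.mean(axis=0)
    return [reduced.copy() for _ in shards]

def simulated_all_gather(shards, axis=0):
    # Every participating device ends up with the full concatenated tensor.
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]
```

With four shards holding values 0..3, `simulated_all_reduce` gives every "device" the elementwise sum, and `simulated_all_gather` gives every "device" the concatenation of all shards.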
    children_to_add = []

    if hasattr(current_layer, "layers") and current_layer.layers:
        for sub_layer in current_layer.layers:
Again, model.layers will return the full flat list of all the layers in the model, even if they're deeply nested.
So:
- `processed_layers` is not needed
- `stack` is not needed
- `prefix` is not needed, it will always be `""`, you will never get a deeply nested path. If you wanted that, you'd need to code this very differently
- `children_to_add` is not needed
- lines 234-260 are not needed
- lines 224-232 are probably not needed

The code could be a lot shorter. This whole function `get_default_config` can be:

    def get_default_config(module, device_ids):
        device_count = len(device_ids)
        state_rules = {}
        output_rules = {}
        for layer in module.layers:
            _apply_layer_sharding_rules(
                layer, layer.name, device_count, state_rules, output_rules
            )
        return LayoutMap(state_rules=state_rules, output_rules=output_rules)

    for specific_attr in [
        "token_embedding",
        "embeddings",
        "position_embedding",
    ]:
I don't understand why this needs a special case, can you explain?
def get_default_config(module, device_ids):
    """Generates a default tensor parallelism configuration for a Keras model.
We don't use the term module in Keras. We have layers and models, which extend layers. I think you should just call it model, unless you do need to call it on layers, in which case you can call it layer.
_split_fn_internal = split_tensor_for_parallelism


def _split_rule(device_count, dim):
    """
    Creates a sharding rule for a specific dimension.

    Returns a lambda function compatible with LayoutMap that defines
    how a tensor should be split across the available devices.

    Args:
        device_count: The total number of devices available for parallelism.
        dim: The dimension of the tensor to split.

    Returns:
        callable: A lambda function accepting (tensor, index) that returns
            the sharded layout.
    """
    return lambda x, index: _split_fn_internal(x, index, device_count, dim=dim)
I think you should just declare this inside `_apply_layer_sharding_rules`.
I would do it with a `functools.partial` for 2 reasons:
- the code would be much shorter (no need to document this)
- you can bind `device_count` once and for all; this way, you don't have to pass it to `split_rule`, which makes each call shorter and allows each call to fit on 1 line instead of 3, which will make them way more readable:

    def _apply_layer_sharding_rules(...):
        def split_rule(dim):
            return functools.partial(
                split_tensor_for_parallelism, device_count=device_count, dim=dim
            )
        ...
        state_rules[f"{full_name}.kernel"] = split_rule(dim=1)

    input_dim = None
    output_dim = None

    if hasattr(layer, "kernel") and layer.kernel is not None:
The "real" kernel is layer._kernel for Dense and EinsumDense layers in Keras.
kernel is a property that may modify the kernel before returning it (for LoRa and quantization).
It may not matter for most of your code, but it will matter if you want the variable and not just the value of the kernel (for instance, to apply sharding to the variable itself).
    if layer.use_bias:
        state_rules[f"{full_name}.bias"] = _split_rule(
            device_count, dim=0
        )
    output_rules[f"{full_name}"] = {0: "gather"}
The state_rules are actual Python functions.
The output_rules are strings, which I assume are later parsed to call the corresponding function (gather or reduce).
Why not be consistent and use either functions everywhere or strings everywhere?
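One way to reconcile the two representations is to resolve strings into callables at one boundary, so downstream code only ever deals with functions. This is a hypothetical helper, not part of the PR; the collectives dictionary and helper name are assumptions:

```python
def resolve_output_rule(rule, collectives):
    # Accept either a callable rule or a string key such as "gather"/"reduce"
    # and always return a callable.
    if callable(rule):
        return rule
    try:
        return collectives[rule]
    except KeyError:
        raise ValueError(f"Unknown output rule: {rule!r}")
```

With `collectives = {"gather": all_gather, "reduce": all_reduce}`, both `"gather"` and a raw function pass through to the same call site unchanged.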
        break

    if not key_found:
        clean_name = weight.name.split("/")[-1].split(":")[0]
When do variable names contain ":"?
To me this was a legacy thing in Keras 2 that we don't have in Keras 3.
What about "/"? It should be in the path, not the name.
This change introduces core building blocks for tensor parallelism by adding two key components.
First, it adds crucial collective operations, all_reduce and all_gather, to the JAX backend. These allow multiple devices to synchronize data by summing tensors (like gradients) or gathering individual slices back into a full tensor. Second, it adds the high-level tensor sharding logic (split_tensor_for_parallelism), which uses ops.array_split to intelligently slice large tensors, even unevenly, for distribution across devices. New tests confirm this new parallel logic, including the uneven splitting, works as expected.
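The uneven-splitting behavior described above can be sketched with `numpy.array_split`; this is a hypothetical standalone version (the PR's `split_tensor_for_parallelism` uses `ops.array_split` from Keras), shown only to illustrate the semantics:

```python
import numpy as np

def split_for_parallelism(tensor, index, device_count, dim=0):
    # array_split tolerates dimensions that don't divide evenly across
    # devices: earlier shards get one extra slice when there is a remainder.
    if dim < 0:
        dim = tensor.ndim + dim
    return np.array_split(tensor, device_count, axis=dim)[index]
```

For example, splitting 10 rows across 4 devices yields shard sizes 3, 3, 2, 2 rather than failing on the uneven division.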
The tests on this PR will pass after PR #21697 is merged.