Skip to content

Conversation

@buildwithsuhana
Copy link
Contributor

@buildwithsuhana buildwithsuhana commented Oct 28, 2025

This change introduces core building blocks for tensor parallelism by adding two key components.

First, it adds crucial collective operations, all_reduce and all_gather, to the JAX backend. These allow multiple devices to synchronize data by summing tensors (like gradients) or gathering individual slices back into a full tensor. Second, it adds the high-level tensor sharding logic (split_tensor_for_parallelism), which uses ops.array_split to intelligently slice large tensors, even unevenly, for distribution across devices. New tests confirm this new parallel logic, including the uneven splitting, works as expected.

The tests on this PR will pass after the PR #21697 gets merged

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes foundational components for tensor parallelism within the JAX backend, crucial for Autosharding. It provides core collective communication primitives like all_reduce and all_gather and introduces a flexible tensor splitting utility, split_tensor_for_parallelism, designed to efficiently distribute tensors across multiple devices, even when uneven splitting is required.

Highlights

  • JAX Collective Operations: Introduced all_reduce (sum, mean) and all_gather functions to the JAX backend for inter-device communication, essential for distributed computing.
  • Tensor Sharding Logic: Added split_tensor_for_parallelism to intelligently slice tensors, including support for uneven distributions, for efficient device parallelism.
  • Comprehensive Testing: Included new test cases to validate the correct functionality of both the collective operations and the tensor splitting logic, ensuring robustness.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces foundational components for tensor parallelism in Keras, specifically for the JAX backend. It adds all_reduce and all_gather collective operations, which are essential for distributed computations. Additionally, it provides a split_tensor_for_parallelism utility for sharding tensors across devices. The changes are well-tested, covering both even and uneven tensor splitting. My review includes a few suggestions to improve documentation accuracy and code simplicity, and to align with the repository's style guide regarding docstring examples.

buildwithsuhana and others added 3 commits October 28, 2025 10:41
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@hertschuh
Copy link
Collaborator

Can you rebase to make the tests pass?

@codecov-commenter
Copy link

codecov-commenter commented Nov 6, 2025

Codecov Report

❌ Patch coverage is 59.70149% with 135 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.25%. Comparing base (d8e0b4a) to head (8bb39f6).
⚠️ Report is 10 commits behind head on master.

Files with missing lines Patch % Lines
...tribution/tensor_parallel/coordinated_optimizer.py 50.00% 89 Missing and 15 partials ⚠️
...ras/src/distribution/tensor_parallel/autoconfig.py 78.89% 10 Missing and 13 partials ⚠️
keras/src/backend/jax/core.py 33.33% 6 Missing ⚠️
.../src/distribution/tensor_parallel/tensor_layout.py 77.77% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21792      +/-   ##
==========================================
- Coverage   82.66%   82.25%   -0.42%     
==========================================
  Files         577      580       +3     
  Lines       59453    59843     +390     
  Branches     9320     9423     +103     
==========================================
+ Hits        49148    49222      +74     
- Misses       7902     8181     +279     
- Partials     2403     2440      +37     
Flag Coverage Δ
keras 82.07% <59.70%> (-0.42%) ⬇️
keras-jax 62.75% <59.70%> (-0.56%) ⬇️
keras-numpy 57.45% <38.80%> (-0.09%) ⬇️
keras-openvino 34.63% <38.80%> (+0.28%) ⬆️
keras-tensorflow 63.98% <38.80%> (-0.14%) ⬇️
keras-torch 63.47% <38.80%> (-0.13%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

buildwithsuhana added a commit to buildwithsuhana/keras that referenced this pull request Nov 18, 2025
A tensor slice corresponding to the given `index`.
"""
if dim == -1:
split_dim = ops.ndim(tensor) - 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is -1 the only possible case? Usually, any negative number is possible and then you do:

split_dim = ops.ndim(tensor) + dim

Comment on lines +78 to +81
@pytest.mark.skipif(
jax.local_device_count() < 2,
reason="Requires multiple local devices for testing.",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests will never run because we don't have a setup with 2 accelerators right now (we have 1 and 8).

You should move these tests to /Users/fhertschuh/Documents/keras/keras/src/backend/jax/distribution_lib_test.py in which we simulate 8 devices.

return jax.checkpoint(f)


def all_reduce(x, op="sum", axis_name="model"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I see this, it's a bit weird to have these ops in core because all backends must implement all ops in core because they are core ops.

Also, this is distribution related. Can you move them to keras/src/backend/jax/distribution_lib.py at least for now.

Maybe we'll need a specific namespace for distribution ops, but since they're not exported, I think distribution_lib is fine for now.

children_to_add = []

if hasattr(current_layer, "layers") and current_layer.layers:
for sub_layer in current_layer.layers:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again model.layers will return the full flat list of all the layers in the model, even if they're deeply nested.

So:

  • processed_layers is not needed
  • stack is not needed
  • prefix is not needed, it will always be "", you will never get a deeply nested path. If you wanted that, you'll need to code this very differently
  • children_to_add is not needed
  • line 234-260 are not needed
  • lines 224-232 are probably not needed

The code could be a lot shorter. This whole function get_default_config can be:

def get_default_config(module, device_ids):
    device_count = len(device_ids)
    state_rules = {}
    output_rules = {}
    for layer in module.layers:
        _apply_layer_sharding_rules(
            layer, layer.name, device_count, state_rules, output_rules
        )
    return LayoutMap(state_rules=state_rules, output_rules=output_rules)

Comment on lines +224 to +228
for specific_attr in [
"token_embedding",
"embeddings",
"position_embedding",
]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this needs a special case, can you explain?

Comment on lines +177 to +178
def get_default_config(module, device_ids):
"""Generates a default tensor parallelism configuration for a Keras model.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't use the term module in Keras. We have layers and models, which extend layers. I think you should just can it model, unless you do need to call it on layers, in which case you can call it layer.

Comment on lines +7 to +25
_split_fn_internal = split_tensor_for_parallelism


def _split_rule(device_count, dim):
"""
Creates a sharding rule for a specific dimension.
Returns a lambda function compatible with LayoutMap that defines
how a tensor should be split across the available devices.
Args:
device_count: The total number of devices available for parallelism.
dim: The dimension of the tensor to split.
Returns:
callable: A lambda function accepting (tensor, index) that returns the
sharded layout.
"""
return lambda x, index: _split_fn_internal(x, index, device_count, dim=dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should just declare this _apply_layer_sharding_rules.

I would do it with a function.partial for 2 reasons:

  • the code would be much shorter (no need to document this)
  • you can bind device_count once for all, this way, you don't have to pass it to split_rule, which would make each call shorter and would allow them to be on 1 line instead of 3 lines, which will make them way more readable:
def _apply_layer_sharding_rules(...):

    def split_rule(dim):
        return functools.partial(split_tensor_for_parallelism, device_count=device_count, dim=dim)

    ...
    state_rules[f"{full_name}.kernel"] = split_rule(dim=1)

input_dim = None
output_dim = None

if hasattr(layer, "kernel") and layer.kernel is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "real" kernel is layer._kernel for Dense and EinsumDense layers in Keras.

kernel is a property that may modify the kernel before returning it (for LoRa and quantization).

It may not matter for most of the your code, but it will matter if you want the variable and not just the value of the kernel (for instance to apply sharding to the variable itself).

Comment on lines +110 to +114
if layer.use_bias:
state_rules[f"{full_name}.bias"] = _split_rule(
device_count, dim=0
)
output_rules[f"{full_name}"] = {0: "gather"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The state_rules are actual Python function.

The output_rules are strings, which I assume are later parsed to call the corresponding function (gather or reduce).

Why not be consistent and use either functions everywhere or strings everywhere?

break

if not key_found:
clean_name = weight.name.split("/")[-1].split(":")[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When do variable names contain ":"?

To me this was a legacy thing in Keras 2 that we don't have in Keras 3.

What about "/"? It should be in the path, not the name.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants