@Cui-yshoho
Contributor

What does this PR do?

Context parallelism

Context parallelism splits input sequences across multiple NPUs to reduce memory usage. Each NPU processes its own slice of the sequence.

Use set_attention_backend() to switch to a more optimized attention backend. Currently, flash is the only optimized backend supported.

Ring Attention

Ring Attention passes key (K) and value (V) representations between devices in a ring, so that each sequence split eventually sees every other token's K/V. Each NPU computes attention against the K/V block it currently holds and then passes that block to the next NPU in the ring. No single NPU holds the full sequence, which reduces per-device memory usage.
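The rotation described above can be illustrated with a single-process simulation. This is only a sketch in plain Python, not the PR's TemplatedRingAttention: the function names are made up for illustration, one attention head is assumed, and the partial results are merged with a running (online) softmax, which is one common way to combine per-block attention.

```python
import math


def full_attention(q, k, v):
    """Reference softmax attention for one head; q, k, v are lists of vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


def ring_attention(q_shards, k_shards, v_shards):
    """Simulate ring attention: each rank keeps its Q shard while K/V blocks rotate."""
    world = len(q_shards)
    scale = 1.0 / math.sqrt(len(q_shards[0][0]))
    dim_v = len(v_shards[0][0])
    outputs = []
    for rank in range(world):
        q_local = q_shards[rank]
        # Running softmax statistics per local query: max, denominator, weighted sum.
        run_max = [-math.inf] * len(q_local)
        run_den = [0.0] * len(q_local)
        run_num = [[0.0] * dim_v for _ in q_local]
        for step in range(world):
            # At step s, this rank holds the K/V block that started on rank - s.
            src = (rank - step) % world
            for i, qi in enumerate(q_local):
                for kj, vj in zip(k_shards[src], v_shards[src]):
                    score = scale * sum(a * b for a, b in zip(qi, kj))
                    new_max = max(run_max[i], score)
                    # Rescale what was accumulated so far to the new maximum.
                    corr = math.exp(run_max[i] - new_max)
                    w = math.exp(score - new_max)
                    run_den[i] = run_den[i] * corr + w
                    run_num[i] = [n * corr + w * x for n, x in zip(run_num[i], vj)]
                    run_max[i] = new_max
        outputs.append([[n / run_den[i] for n in run_num[i]]
                        for i in range(len(q_local))])
    return outputs


def shard(x):
    """Split four tokens into two contiguous shards, one per simulated device."""
    return [x[:2], x[2:]]


# Four tokens with 2-dimensional features, split across two simulated devices.
q = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.5, 0.0]]
k = [[0.2, 0.1], [-0.3, 0.2], [0.4, 0.4], [0.0, -0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 3.0]]

ring_out = [row for part in ring_attention(shard(q), shard(k), shard(v)) for row in part]
ref_out = full_attention(q, k, v)
max_diff = max(abs(a - b) for ra, rb in zip(ring_out, ref_out) for a, b in zip(ra, rb))
```

In this toy setup, max_diff should be at the level of floating-point rounding, which shows that a rank never needs the whole K/V sequence at once to reproduce full attention for its own queries.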

Pass a ContextParallelConfig to the parallel_config argument of the transformer model. The config supports the ring_degree argument that determines how many devices to use for Ring Attention.

import numpy as np
import mindspore as ms
from mindspore import mint
from mindone.diffusers import AutoModel, FluxPipeline, ContextParallelConfig

mint.distributed.init_process_group()
rank = mint.distributed.get_rank()

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    mindspore_dtype=ms.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=2),
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, mindspore_dtype=ms.bfloat16
)
pipeline.transformer.set_attention_backend("flash")

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

# Must specify generator so all ranks start with same latents (or pass your own).
# Assumption: mindone.diffusers pipelines take a seeded NumPy generator here.
generator = np.random.Generator(np.random.PCG64(42))
image = pipeline(prompt, num_inference_steps=50, generator=generator)[0][0]

if rank == 0:
    image.save("output.png")

Ulysses Attention

Ulysses Attention splits a sequence across NPUs and performs an all-to-all communication (every device sends data to and receives data from every other device). Each NPU ends up with all tokens for only a subset of attention heads. Each NPU computes attention locally on all tokens for its heads, then performs another all-to-all to regroup the results by token for the next layer.
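The two all-to-all exchanges can be sketched as a layout change on nested lists. This is a single-process illustration with hypothetical function names, not the PR's TemplatedUlyssesAttention, and it assumes the number of heads divides evenly by the number of devices. Each entry is just a label, so the regrouping is easy to follow.

```python
def all_to_all_seq_to_head(seq_shards):
    """seq_shards[rank][token][head] -> per rank: all tokens, only that rank's heads."""
    world = len(seq_shards)
    heads = len(seq_shards[0][0])
    per_rank = heads // world
    out = []
    for rank in range(world):
        head_ids = range(rank * per_rank, (rank + 1) * per_rank)
        # Rank r receives its head slice from every rank, in sequence order.
        out.append([[tok[h] for h in head_ids] for shard in seq_shards for tok in shard])
    return out


def all_to_all_head_to_seq(head_shards, tokens_per_rank):
    """Inverse exchange: each rank gets all heads again, for its own tokens only."""
    world = len(head_shards)
    out = []
    for rank in range(world):
        start = rank * tokens_per_rank
        rows = []
        for t in range(start, start + tokens_per_rank):
            # Concatenate the head slices that each rank holds for token t.
            rows.append([head for hs in head_shards for head in hs[t]])
        out.append(rows)
    return out


# 4 tokens, 2 heads, 2 simulated devices; each entry is a label "t<token>h<head>".
seq_shards = [
    [[f"t{t}h{h}" for h in range(2)] for t in range(rank * 2, rank * 2 + 2)]
    for rank in range(2)
]
head_shards = all_to_all_seq_to_head(seq_shards)
restored = all_to_all_head_to_seq(head_shards, tokens_per_rank=2)
```

After the first exchange, simulated device 0 holds head 0 for tokens t0 to t3, which is what lets it run ordinary attention for that head. The second exchange restores the original token-sharded layout.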

ContextParallelConfig supports Ulysses Attention through the ulysses_degree argument. This determines how many devices to use for Ulysses Attention.

Pass the ContextParallelConfig to enable_parallelism().

pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))

Launch the script with msrun and use the --worker_num and --local_worker_num arguments to set the number of NPUs. For example:

msrun --worker_num=2 --local_worker_num=2 --master_port=8118 --log_dir=msrun_log --join=True --cluster_time_out=300 net.py

Currently, only native attention and flash attention support context parallelism.
However, the native attention backend only supports Ulysses Attention, and does not support Ring Attention.
To enable Ring Attention when using context parallelism, you must set:

pipeline.transformer.set_attention_backend("flash")
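The backend constraints stated above can be written down as a small lookup. The helper below is hypothetical and is not part of mindone.diffusers; it also assumes that a degree greater than 1 means the corresponding strategy is in use, which the description does not spell out.

```python
# Hypothetical helper, not part of mindone.diffusers.
SUPPORTED_STRATEGIES = {
    "native": {"ulysses"},
    "flash": {"ulysses", "ring"},
}


def check_backend(backend, ring_degree=1, ulysses_degree=1):
    """Raise ValueError if the backend cannot run the requested strategies.

    Assumes a degree greater than 1 means the strategy is in use.
    """
    if backend not in SUPPORTED_STRATEGIES:
        raise ValueError(f"{backend!r} does not support context parallelism")
    requested = set()
    if ring_degree > 1:
        requested.add("ring")
    if ulysses_degree > 1:
        requested.add("ulysses")
    unsupported = requested - SUPPORTED_STRATEGIES[backend]
    if unsupported:
        raise ValueError(f"{backend!r} backend does not support: {sorted(unsupported)}")
    return True
```

With this sketch, check_backend("native", ring_degree=2) raises a ValueError, while the same call with "flash" passes.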

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the
    documentation guidelines
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (better record it for data loading, model inference, or training tasks)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@Cui-yshoho Cui-yshoho requested a review from vigo999 as a code owner November 25, 2025 08:13
@gemini-code-assist
Contributor

Summary of Changes

Hello @Cui-yshoho, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the MindSpore Diffusers library by introducing advanced context parallelism strategies, namely Ring Attention and Ulysses Attention. These techniques enable more memory-efficient and scalable inference on multi-NPU setups by distributing input sequences and attention computations. The changes also include a robust refactoring of the attention mechanism through a centralized dispatcher, allowing for flexible selection of optimized attention backends and laying the groundwork for future parallelism and performance improvements across various transformer models.

Highlights

  • Context Parallelism Implementation: Introduced ContextParallelConfig, ParallelConfig, and associated hooks (ContextParallelSplitHook, ContextParallelGatherHook) to enable splitting and gathering tensors across multiple NPUs for reduced memory usage.
  • Ring Attention Support: Implemented Ring Attention (TemplatedRingAttention) where key and value representations communicate between devices in a ring topology, ensuring each NPU sees all tokens' K/V.
  • Ulysses Attention Support: Implemented Ulysses Attention (TemplatedUlyssesAttention) which splits sequences across NPUs and uses all-to-all communication for attention heads.
  • Centralized Attention Dispatcher: Refactored attention logic into a new attention_dispatch.py module with dispatch_attention_fn to dynamically select between different attention backends (e.g., native, Flash Attention) and integrate context parallelism.
  • Model Integration: Updated ModelMixin with enable_parallelism and set_attention_backend methods, and integrated context parallelism plans (_cp_plan) into FluxTransformer2DModel, LTXVideoTransformer3DModel, QwenImageTransformer2DModel, SkyReelsV2Transformer3DModel, and WanTransformer3DModel.
  • Documentation: Added new documentation for 'Parallel inference' and updated the _toctree.yml to reflect this new feature.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant new functionality for context parallelism, including Ring and Ulysses Attention, and refactors several models to use a new unified attention dispatcher. This is a great step towards improving performance and scalability. My review has identified several critical issues, primarily in the backward pass implementations of the new custom attention cells and in some helper functions. These will need to be addressed to ensure correct gradient calculations during training. I've also found a few other bugs and a documentation typo. Overall, the architectural changes are solid, but the implementation details require careful correction.

grad_chunks = mint.chunk(dout, self.world_size, dim=self.dim)
return (grad_chunks[self.rank],)
Contributor

critical

The bprop implementation seems incomplete. The gradient chunks are calculated, but only the chunk corresponding to the current rank is returned. For the backward pass of an all-gather operation, the gradients should be summed across all devices (an all-reduce operation) and then scattered. Returning only the local chunk will result in incorrect gradients during training.
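The difference between the two behaviours can be seen in a small single-process simulation. This is a sketch with hypothetical function names, using flat Python lists in place of tensors; it does not reproduce the PR's bprop. Whether the simplified version is acceptable in practice depends on how the gathered output is consumed downstream, which this thread does not settle.

```python
def reduce_scatter_backward(douts, world_size):
    """Backward of all-gather as described in the review.

    douts[s] is the gradient of rank s's gathered output (a flat list).
    """
    chunk = len(douts[0]) // world_size
    # Reduce: sum the gradients element-wise across ranks.
    summed = [sum(vals) for vals in zip(*douts)]
    # Scatter: rank r keeps the slice that matches its original input.
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def local_chunk_only(douts, world_size):
    """What returning only grad_chunks[rank] amounts to: no cross-rank sum."""
    chunk = len(douts[0]) // world_size
    return [douts[r][r * chunk:(r + 1) * chunk] for r in range(world_size)]


# Two simulated ranks, each with a different upstream gradient.
douts = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
grads_summed = reduce_scatter_backward(douts, world_size=2)
grads_local = local_chunk_only(douts, world_size=2)
```

Here grads_summed includes the contribution of every rank to each input slice, while grads_local drops the contributions that arrived on other ranks.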

out = out.permute(0, 2, 1, 3)
return out

grad_query_t, grad_key_t, grad_value_t = ms.grad(forward_fn, grad_position=(0, 1, 2))(query_t, key_t, value_t)
Contributor

critical

The backward pass implementation for NativeAttentionCell is incorrect. The dout tensor, which represents the incoming gradient from the subsequent layer, is not being used in the gradient calculation with ms.grad. This will result in incorrect gradients for any dout that is not a tensor of ones. You should use sens_param=True and pass dout to the gradient function to compute the correct Vector-Jacobian Product (VJP).

Suggested change
- grad_query_t, grad_key_t, grad_value_t = ms.grad(forward_fn, grad_position=(0, 1, 2))(query_t, key_t, value_t)
+ grad_query_t, grad_key_t, grad_value_t = ms.grad(forward_fn, grad_position=(0, 1, 2), sens_param=True)(query_t, key_t, value_t, dout.permute(0, 2, 1, 3))
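To see why ignoring dout matters, consider a plain-Python sketch of a vector-Jacobian product for a linear map. This is not MindSpore code, and vjp_linear is a name made up for the illustration. Leaving out the upstream gradient is equivalent to assuming it is all ones.

```python
def vjp_linear(weight, dout):
    """Vector-Jacobian product of y = weight @ x, that is, weight^T @ dout."""
    return [sum(weight[i][j] * dout[i] for i in range(len(weight)))
            for j in range(len(weight[0]))]


weight = [[1.0, 2.0], [3.0, 4.0]]

# What you get when dout is ignored (implicitly all ones).
grad_with_ones = vjp_linear(weight, [1.0, 1.0])

# What the upstream gradient actually requires.
grad_with_dout = vjp_linear(weight, [0.5, -1.0])
```

The two results differ in both magnitude and sign, so any layer after the attention cell that produces a non-uniform gradient would be trained with the wrong values.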

Comment on lines 589 to 616
def bprop(
    self,
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    enable_gqa,
    return_lse,
    _save_ctx,
    _parallel_config,
    out,
    dout,
):
    grad_query, grad_key, grad_value = mint.empty_like(query), mint.empty_like(key), mint.empty_like(value)

    # Head dimension may have been padded
    grad_query = grad_query[..., : dout.shape[-1]]
    grad_key = grad_key[..., : dout.shape[-1]]
    grad_value = grad_value[..., : dout.shape[-1]]

    return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
Contributor

critical

The backward pass for FlashAttentionCell is not implemented correctly. It currently returns tensors created with mint.empty_like, which are uninitialized, so the gradients will be incorrect. The backward pass needs to be fully implemented, likely using ops.operations.nn_ops.FlashAttentionScoreGrad.

Comment on lines 807 to 810
# grad_kv_buffer = _wait_tensor(next_grad_kv)
grad_key_numel = grad_key.numel()
grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
Contributor

critical

The variable grad_kv_buffer is used here before it has been assigned a value within this if block. It seems you intended to assign next_grad_kv to it. Without this assignment, this will raise a NameError during execution.

Suggested change
- # grad_kv_buffer = _wait_tensor(next_grad_kv)
- grad_key_numel = grad_key.numel()
- grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
- grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
+ grad_kv_buffer = next_grad_kv
+ grad_key_numel = grad_key.numel()
+ grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
+ grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)

Comment on lines +681 to +700
if (
    (attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask)
    or head_dim > 512
    or backend == "math"
    or enable_gqa
):
    out = math_attention_op(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)

if query.dtype in (ms.float16, ms.bfloat16):
    out = flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale)
else:
    out = flash_attention_op(
        query.to(ms.float16),
        key.to(ms.float16),
        value.to(ms.float16),
        attn_mask,
        keep_prob=1 - dropout_p,
        scale=scale,
    ).to(query.dtype)
return out
Contributor

critical

There is a logical issue in this function. If the first if condition is met, out is calculated using math_attention_op. However, the execution then falls through to the next if/else block, which overwrites out with the result from flash_attention_op. The result from math_attention_op is always discarded. You should add a return statement inside the first if block to prevent this.

Suggested change
- if (
-     (attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask)
-     or head_dim > 512
-     or backend == "math"
-     or enable_gqa
- ):
-     out = math_attention_op(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
- if query.dtype in (ms.float16, ms.bfloat16):
-     out = flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale)
- else:
-     out = flash_attention_op(
-         query.to(ms.float16),
-         key.to(ms.float16),
-         value.to(ms.float16),
-         attn_mask,
-         keep_prob=1 - dropout_p,
-         scale=scale,
-     ).to(query.dtype)
- return out
+ if (
+     (attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask)
+     or head_dim > 512
+     or backend == "math"
+     or enable_gqa
+ ):
+     return math_attention_op(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
+ if query.dtype in (ms.float16, ms.bfloat16):
+     out = flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale)
+ else:
+     out = flash_attention_op(
+         query.to(ms.float16),
+         key.to(ms.float16),
+         value.to(ms.float16),
+         attn_mask,
+         keep_prob=1 - dropout_p,
+         scale=scale,
+     ).to(query.dtype)
+ return out
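The fall-through can be reproduced in a few lines. In this sketch the two attention ops are replaced by string labels and the function names are invented for the illustration; only the control flow mirrors the code under review.

```python
def dispatch_buggy(use_math):
    """Mirrors the original control flow: the first branch does not return."""
    if use_math:
        out = "math"
    # No return above, so this assignment always runs and overwrites `out`.
    out = "flash"
    return out


def dispatch_fixed(use_math):
    """Mirrors the suggested change: return early from the math branch."""
    if use_math:
        return "math"
    return "flash"
```

Calling dispatch_buggy(True) yields the flash result even though the math path was selected, whereas dispatch_fixed(True) keeps the math result.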

Comment on lines 116 to 118
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
Contributor

critical

The apply_rotary_emb function expects freqs_cos and freqs_sin as separate arguments, but rotary_emb is being passed as a single argument. Since rotary_emb is a tuple containing both tensors, you should unpack it using the splat operator (*).

Suggested change
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
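The effect of the splat operator can be shown with a stand-in function. The apply_rotary_emb below is a hypothetical simplification whose signature follows the review comment's description (separate freqs_cos and freqs_sin arguments); it is not the library's implementation.

```python
import math


def apply_rotary_emb(x, freqs_cos, freqs_sin):
    """Hypothetical stand-in: rotate consecutive pairs of x by the given cos/sin."""
    out = []
    for i in range(0, len(x), 2):
        c, s = freqs_cos[i // 2], freqs_sin[i // 2]
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out


angle = math.pi / 2
rotary_emb = ([math.cos(angle)], [math.sin(angle)])  # (freqs_cos, freqs_sin) tuple

# The splat unpacks the tuple into the two separate arguments.
rotated = apply_rotary_emb([1.0, 0.0], *rotary_emb)

try:
    # Without the splat, the whole tuple is bound to freqs_cos and freqs_sin is missing.
    apply_rotary_emb([1.0, 0.0], rotary_emb)
    missing_arg_error = False
except TypeError:
    missing_arg_error = True
```

With this stand-in, the call without the splat fails with a TypeError because the third positional argument is never supplied.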


Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times.

::: mindone.diffusersParallelConfig
Contributor

medium

There appears to be a typo in the class name. It should be mindone.diffusers.ParallelConfig.

Suggested change
- ::: mindone.diffusersParallelConfig
+ ::: mindone.diffusers.ParallelConfig

_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_cp_plan = {
Collaborator

Is the cp_plan related to the number of cards?

Contributor Author

@Cui-yshoho Cui-yshoho Nov 27, 2025

Yes. In the related diffusers PR, cp_plan is introduced as follows:
[image]

return output[0] if is_tensor else tuple(output)


class AllGatherFunction(ms.nn.Cell):
Collaborator

You can use Function directly:

from mindspore.common._grad_function import _Function as Function

Contributor Author

I'll go fix that.

@Cui-yshoho Cui-yshoho force-pushed the cp branch 3 times, most recently from 3bde6ad to 78fb7b8 Compare November 27, 2025 08:06

3 participants