Skip to content

Commit 7e667f3

Browse files
committed
add cp
1 parent 2c574e1 commit 7e667f3

17 files changed

+2472
-301
lines changed

docs/diffusers/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868
title: Accelerate inference
6969
- local: optimization/memory
7070
title: Reduce memory usage
71+
- local: api/parallel
72+
title: Parallel inference
7173
- title: Community optimizations
7274
sections:
7375
- local: optimization/xformers

docs/diffusers/api/parallel.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License. -->
11+
12+
# Parallelism
13+
14+
Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times.
15+
16+
::: mindone.diffusers.ParallelConfig
17+
18+
::: mindone.diffusers.ContextParallelConfig
19+
20+
::: mindone.diffusers.hooks.apply_context_parallel

mindone/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"CogView4Transformer2DModel",
7979
"ConsisIDTransformer3DModel",
8080
"ConsistencyDecoderVAE",
81+
"ContextParallelConfig",
8182
"ControlNetModel",
8283
"ControlNetUnionModel",
8384
"ControlNetXSAdapter",
@@ -106,6 +107,7 @@
106107
"MultiAdapter",
107108
"MultiControlNetModel",
108109
"OmniGenTransformer2DModel",
110+
"ParallelConfig",
109111
"PixArtTransformer2DModel",
110112
"PriorTransformer",
111113
"QwenImageTransformer2DModel",
@@ -464,6 +466,7 @@
464466
CogView4Transformer2DModel,
465467
ConsisIDTransformer3DModel,
466468
ConsistencyDecoderVAE,
469+
ContextParallelConfig,
467470
ControlNetModel,
468471
ControlNetUnionModel,
469472
ControlNetXSAdapter,
@@ -492,6 +495,7 @@
492495
MultiAdapter,
493496
MultiControlNetModel,
494497
OmniGenTransformer2DModel,
498+
ParallelConfig,
495499
PixArtTransformer2DModel,
496500
PriorTransformer,
497501
QwenImageTransformer2DModel,

mindone/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from .context_parallel import apply_context_parallel
1516
from .faster_cache import FasterCacheConfig, apply_faster_cache
1617
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
1718
from .hooks import HookRegistry, ModelHook
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import inspect
16+
from dataclasses import dataclass
17+
from typing import Dict, List, Type, Union
18+
19+
import mindspore as ms
20+
from mindspore import mint
21+
22+
from ..models._modeling_parallel import (
23+
ContextParallelConfig,
24+
ContextParallelInput,
25+
ContextParallelModelPlan,
26+
ContextParallelOutput,
27+
)
28+
from ..utils import get_logger
29+
from ..utils.mindspore_utils import unwrap_module
30+
from .hooks import HookRegistry, ModelHook
31+
32+
logger = get_logger(__name__) # pylint: disable=invalid-name
33+
34+
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
35+
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
36+
37+
38+
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
39+
@dataclass
40+
class ModuleForwardMetadata:
41+
cached_parameter_indices: Dict[str, int] = None
42+
_cls: Type = None
43+
44+
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
45+
kwargs = kwargs or {}
46+
47+
if identifier in kwargs:
48+
return kwargs[identifier], True, None
49+
50+
if self.cached_parameter_indices is not None:
51+
index = self.cached_parameter_indices.get(identifier, None)
52+
if index is None:
53+
raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
54+
return args[index], False, index
55+
56+
if self._cls is None:
57+
raise ValueError("Model class is not set for metadata.")
58+
59+
parameters = list(inspect.signature(self._cls.construct).parameters.keys())
60+
parameters = parameters[1:] # skip `self`
61+
self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
62+
63+
if identifier not in self.cached_parameter_indices:
64+
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
65+
66+
index = self.cached_parameter_indices[identifier]
67+
68+
if index >= len(args):
69+
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
70+
71+
return args[index], False, index
72+
73+
74+
def apply_context_parallel(
75+
module: ms.nn.Cell,
76+
parallel_config: ContextParallelConfig,
77+
plan: Dict[str, ContextParallelModelPlan],
78+
) -> None:
79+
"""Apply context parallel on a model."""
80+
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
81+
82+
for module_id, cp_model_plan in plan.items():
83+
submodule = _get_submodule_by_name(module, module_id)
84+
if not isinstance(submodule, list):
85+
submodule = [submodule]
86+
87+
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
88+
89+
for m in submodule:
90+
if isinstance(cp_model_plan, dict):
91+
hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
92+
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
93+
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
94+
if isinstance(cp_model_plan, ContextParallelOutput):
95+
cp_model_plan = [cp_model_plan]
96+
if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
97+
raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
98+
hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
99+
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
100+
else:
101+
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
102+
registry = HookRegistry.check_if_exists_or_initialize(m)
103+
registry.register_hook(hook, hook_name)
104+
105+
106+
def remove_context_parallel(module: ms.nn.Cell, plan: Dict[str, ContextParallelModelPlan]) -> None:
107+
for module_id, cp_model_plan in plan.items():
108+
submodule = _get_submodule_by_name(module, module_id)
109+
if not isinstance(submodule, list):
110+
submodule = [submodule]
111+
112+
for m in submodule:
113+
registry = HookRegistry.check_if_exists_or_initialize(m)
114+
if isinstance(cp_model_plan, dict):
115+
hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
116+
elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
117+
hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
118+
else:
119+
raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
120+
registry.remove_hook(hook_name)
121+
122+
123+
class ContextParallelSplitHook(ModelHook):
124+
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
125+
super().__init__()
126+
self.metadata = metadata
127+
self.parallel_config = parallel_config
128+
self.module_forward_metadata = None
129+
130+
def initialize_hook(self, module):
131+
cls = unwrap_module(module).__class__
132+
self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
133+
return module
134+
135+
def pre_construct(self, module, *args, **kwargs):
136+
args_list = list(args)
137+
138+
for name, cpm in self.metadata.items():
139+
if isinstance(cpm, ContextParallelInput) and cpm.split_output:
140+
continue
141+
142+
# Maybe the parameter was passed as a keyword argument
143+
input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
144+
name, args_list, kwargs
145+
)
146+
147+
if input_val is None:
148+
continue
149+
150+
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
151+
# the output instead of input for a particular layer by setting split_output=True
152+
if isinstance(input_val, ms.Tensor):
153+
input_val = self._prepare_cp_input(input_val, cpm)
154+
elif isinstance(input_val, (list, tuple)):
155+
if len(input_val) != len(cpm):
156+
raise ValueError(
157+
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
158+
)
159+
sharded_input_val = []
160+
for i, x in enumerate(input_val):
161+
if ms.is_tensor(x) and not cpm[i].split_output:
162+
x = self._prepare_cp_input(x, cpm[i])
163+
sharded_input_val.append(x)
164+
input_val = sharded_input_val
165+
else:
166+
raise ValueError(f"Unsupported input type: {type(input_val)}")
167+
168+
if is_kwarg:
169+
kwargs[name] = input_val
170+
elif index is not None and index < len(args_list):
171+
args_list[index] = input_val
172+
else:
173+
raise ValueError(
174+
f"An unexpected error occurred while processing the input '{name}'. Please open an "
175+
f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
176+
f"example along with the full stack trace."
177+
)
178+
179+
return tuple(args_list), kwargs
180+
181+
def post_construct(self, module, output):
182+
is_tensor = isinstance(output, ms.Tensor)
183+
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, ms.Tensor) for x in output)
184+
185+
if not is_tensor and not is_tensor_list:
186+
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
187+
188+
output = [output] if is_tensor else list(output)
189+
for index, cpm in self.metadata.items():
190+
if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
191+
continue
192+
if index >= len(output):
193+
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
194+
current_output = output[index]
195+
current_output = self._prepare_cp_input(current_output, cpm)
196+
output[index] = current_output
197+
198+
return output[0] if is_tensor else tuple(output)
199+
200+
def _prepare_cp_input(self, x: ms.Tensor, cp_input: ContextParallelInput) -> ms.Tensor:
201+
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
202+
raise ValueError(
203+
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
204+
)
205+
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
206+
207+
208+
class ContextParallelGatherHook(ModelHook):
209+
def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
210+
super().__init__()
211+
self.metadata = metadata
212+
self.parallel_config = parallel_config
213+
214+
def post_construct(self, module, output):
215+
is_tensor = isinstance(output, ms.Tensor)
216+
217+
if is_tensor:
218+
output = [output]
219+
elif not (isinstance(output, (list, tuple)) and all(isinstance(x, ms.Tensor) for x in output)):
220+
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
221+
222+
output = list(output)
223+
224+
if len(output) != len(self.metadata):
225+
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
226+
227+
for i, cpm in enumerate(self.metadata):
228+
if cpm is None:
229+
continue
230+
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
231+
232+
return output[0] if is_tensor else tuple(output)
233+
234+
235+
class AllGatherFunction(ms.nn.Cell):
236+
def __init__(self, dim, group):
237+
super().__init__()
238+
self.dim = dim
239+
self.group = group
240+
self.world_size = mint.distributed.get_world_size(group)
241+
self.rank = mint.distributed.get_rank(group)
242+
243+
def construct(self, tensor):
244+
# return funcol.all_gather_tensor(tensor, dim, group=group)
245+
# mint.distributed.all_gather_into_tensor only support dim=0
246+
tensor_t = tensor.transpose(self.dim, 0) if self.dim != 0 else tensor
247+
248+
out_shape = list(tensor_t.shape)
249+
out_shape[0] *= self.world_size
250+
output = mint.zeros(out_shape, dtype=tensor_t.dtype)
251+
252+
mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=self.group)
253+
254+
if self.dim != 0:
255+
output = output.transpose(0, self.dim)
256+
257+
return output
258+
259+
def bprop(self, tensor, out, dout):
260+
grad_chunks = mint.chunk(dout, self.world_size, dim=self.dim)
261+
return (grad_chunks[self.rank],)
262+
263+
264+
class EquipartitionSharder:
265+
@classmethod
266+
def shard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor:
267+
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
268+
# because the alternate case has not yet been tested/required for any model.
269+
assert (
270+
tensor.shape[dim] % mint.distributed.get_world_size(mesh) == 0
271+
), "Tensor size along dimension to be sharded must be divisible by mesh size"
272+
273+
# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
274+
# return tensor.chunk(mint.distributed.get_world_size(mesh), dim=dim)[mesh.get_rank()]
275+
276+
return tensor.chunk(mint.distributed.get_world_size(mesh), dim=dim)[mint.distributed.get_rank(mesh)]
277+
278+
@classmethod
279+
def unshard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor:
280+
tensor = tensor.contiguous()
281+
tensor = AllGatherFunction(dim, mesh)(tensor)
282+
return tensor
283+
284+
285+
def _get_submodule_by_name(model: ms.nn.Cell, name: str) -> Union[ms.nn.Cell, List[ms.nn.Cell]]:
286+
if name.count("*") > 1:
287+
raise ValueError("Wildcard '*' can only be used once in the name")
288+
return _find_submodule_by_name(model, name)
289+
290+
291+
def _find_submodule_by_name(model: ms.nn.Cell, name: str) -> Union[ms.nn.Cell, List[ms.nn.Cell]]:
292+
if name == "":
293+
return model
294+
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
295+
if first_atom == "*":
296+
if not isinstance(model, ms.nn.CellList):
297+
raise ValueError("Wildcard '*' can only be used with ModuleList")
298+
submodules = []
299+
for submodule in model:
300+
subsubmodules = _find_submodule_by_name(submodule, remaining_name)
301+
if not isinstance(subsubmodules, list):
302+
subsubmodules = [subsubmodules]
303+
submodules.extend(subsubmodules)
304+
return submodules
305+
else:
306+
if hasattr(model, first_atom):
307+
submodule = getattr(model, first_atom)
308+
return _find_submodule_by_name(submodule, remaining_name)
309+
else:
310+
raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")

mindone/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ..utils import _LazyModule
2121

2222
_import_structure = {
23+
"_modeling_parallel": ["ContextParallelConfig", "ParallelConfig"],
2324
"adapter": ["MultiAdapter", "T2IAdapter"],
2425
"auto_model": ["AutoModel"],
2526
"autoencoders.autoencoder_asym_kl": ["AsymmetricAutoencoderKL"],
@@ -104,6 +105,7 @@
104105
}
105106

106107
if TYPE_CHECKING:
108+
from ._modeling_parallel import ContextParallelConfig, ParallelConfig
107109
from .adapter import MultiAdapter, T2IAdapter
108110
from .auto_model import AutoModel
109111
from .autoencoders import (

0 commit comments

Comments
 (0)