Commit c3bfab1

te convert model lazy loading
1 parent a16d2bb commit c3bfab1

File tree

2 files changed: +101 -49 lines changed
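
For context, "lazy loading" in the commit title refers to building modules on PyTorch's meta device, where parameters carry shape and dtype but no backing storage. A minimal illustrative sketch (not part of this commit; assumes a recent PyTorch with `device="meta"` support):

import torch
import torch.nn as nn

# A meta-device module has full metadata (shapes, dtypes) but no real storage,
# so constructing it allocates essentially no memory.
lazy_linear = nn.Linear(4096, 4096, device="meta")
assert lazy_linear.weight.device == torch.device("meta")
assert lazy_linear.weight.shape == (4096, 4096)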

src/accelerate/utils/transformer_engine.py

Lines changed: 78 additions & 45 deletions
@@ -17,15 +17,30 @@
 import torch.nn as nn
 
 from .imports import is_hpu_available, is_transformer_engine_available
-from .operations import GatheredParameters
 
 
 # Do not import `transformer_engine` at package level to avoid potential issues
 
 
 def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
     """
-    Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
+    Converts model layers to Transformer Engine counterparts with reduced memory overhead.
+
+    This function performs a two-stage conversion process:
+    1. Creates TE module structure on meta device (no memory allocation)
+    2. Transfers weights efficiently to avoid peak memory usage during conversion
+
+    Args:
+        model: The model to convert
+        to_transformer_engine (bool): Whether to convert to TE (True) or from TE (False)
+        _convert_linear (bool): Whether to convert Linear layers
+        _convert_ln (bool): Whether to convert LayerNorm layers
+
+    Returns:
+        The converted model
+
+    Raises:
+        ImportError: If transformer_engine is not available
     """
     if not is_transformer_engine_available():
         raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
@@ -39,55 +54,73 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
     else:
         import transformer_engine.pytorch as te
 
-    for name, module in model.named_children():
-        if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
-            has_bias = module.bias is not None
-            params_to_gather = [module.weight]
-            if has_bias:
-                params_to_gather.append(module.bias)
+    import torch
 
-            with GatheredParameters(params_to_gather, modifier_rank=0):
+    from accelerate import init_empty_weights
+
+    # Stage 1: Create TE module skeleton on meta device (zero memory allocation)
+    with init_empty_weights():
+        for name, module in model.named_children():
+            if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
+                has_bias = module.bias is not None
+
+                # TE requires weight dimensions to be multiples of 16 for optimal performance
                 if any(p % 16 != 0 for p in module.weight.shape):
-                    return
+                    continue
+
+                # Create TE Linear module structure without allocating memory
                 te_module = te.Linear(
-                    module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
+                    module.in_features,
+                    module.out_features,
+                    bias=has_bias,
+                    params_dtype=module.weight.dtype,
+                    device="meta",
                 )
-                te_module.weight.copy_(module.weight)
-                if has_bias:
-                    te_module.bias.copy_(module.bias)
+                setattr(model, name, te_module)
 
+            elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
+                # Create TE LayerNorm module structure without allocating memory
+                te_module = te.LayerNorm(
+                    module.normalized_shape[0],
+                    eps=module.eps,
+                    params_dtype=module.weight.dtype,
+                )
                 setattr(model, name, te_module)
-        # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
-        elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
-            with GatheredParameters([module.weight, module.bias], modifier_rank=0):
-                te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
-                te_module.weight.copy_(module.weight)
-                te_module.bias.copy_(module.bias)
-
-            setattr(model, name, te_module)
-        elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
-            has_bias = module.bias is not None
-            new_module = nn.Linear(
-                module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
-            )
-            new_module.weight.copy_(module.weight)
-            if has_bias:
-                new_module.bias.copy_(module.bias)
-
-            setattr(model, name, new_module)
-        elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:
-            new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
-            new_module.weight.copy_(module.weight)
-            new_module.bias.copy_(module.bias)
-
-            setattr(model, name, new_module)
-        else:
-            convert_model(
-                module,
-                to_transformer_engine=to_transformer_engine,
-                _convert_linear=_convert_linear,
-                _convert_ln=_convert_ln,
-            )
+
+            else:
+                # Recursively convert child modules
+                convert_model(
+                    module,
+                    to_transformer_engine=to_transformer_engine,
+                    _convert_linear=_convert_linear,
+                    _convert_ln=_convert_ln,
+                )
+
+    # Efficiently transfer weights from original to TE modules
+    for name, module in model.named_modules():
+        if (
+            isinstance(module, te.Linear)
+            and hasattr(module, "weight")
+            and module.weight.device == torch.device("meta")
+        ):
+            # Locate corresponding weight parameters in the model's state dict
+            weight_key = f"{name}.weight"
+            bias_key = f"{name}.bias"
+
+            state_dict = model.state_dict()
+
+            # Transfer weight parameter with memory-efficient copying
+            if weight_key in state_dict:
+                with torch.no_grad():
+                    # For very large weights, consider chunked transfer to reduce peak memory
+                    module.weight.copy_(state_dict[weight_key])
+
+            # Transfer bias parameter if it exists
+            if hasattr(module, "bias") and module.bias is not None and bias_key in state_dict:
+                with torch.no_grad():
+                    module.bias.copy_(state_dict[bias_key])
+
+    return model
 
 
 def has_transformer_engine_layers(model):
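
For reference, a minimal sketch (not part of the commit) of how the reworked `convert_model` is meant to be called; the toy model below is illustrative and assumes `transformer_engine` is installed:

import torch.nn as nn

from accelerate.utils.transformer_engine import convert_model, has_transformer_engine_layers

# Toy model whose Linear dimensions are multiples of 16, as TE expects.
model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 16))

# Stage 1 builds te.Linear skeletons on the meta device; stage 2 copies the
# original weights in from the state dict, keeping peak memory low.
model = convert_model(model, to_transformer_engine=True)
assert has_transformer_engine_layers(model)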

tests/test_fp8.py

Lines changed: 23 additions & 4 deletions
@@ -40,8 +40,23 @@
 )
 
 
+def check_gpu_memory_usage_is_low():
+    import pynvml
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+    # Check if GPU memory usage is low (e.g., less than 100MB)
+    assert info.used < 100 * 1024 * 1024, f"GPU memory usage is too high: {info.used / (1024 * 1024)} MB"
+    pynvml.nvmlShutdown()
+
+
 def can_convert_te_model():
-    accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
+    accelerator_kwargs = {
+        "mixed_precision": "fp8",
+        "kwargs_handlers": [FP8RecipeKwargs(backend="TE")],
+        "device_placement": False,
+    }
     accelerator = Accelerator(**accelerator_kwargs)
     dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
     model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
@@ -53,9 +68,9 @@ def can_convert_te_model():
 
 
 def maintain_proper_deepspeed_config(expected_version):
-    assert AcceleratorState().deepspeed_plugin.zero_stage == expected_version, (
-        f"Expected zero stage {expected_version} but got {AcceleratorState().deepspeed_plugin.zero_stage}"
-    )
+    assert (
+        AcceleratorState().deepspeed_plugin.zero_stage == expected_version
+    ), f"Expected zero stage {expected_version} but got {AcceleratorState().deepspeed_plugin.zero_stage}"
 
 
 def can_convert_ao_model():
@@ -168,6 +183,10 @@ def test_can_prepare_model_multi_accelerator_deepspeed(self):
 
 
 if __name__ == "__main__":
+    # import debugpy
+    # debugpy.listen(("localhost", 5678))
+    # print("Waiting for debugger attach...")
+    # debugpy.wait_for_client()
     # TE suite
     if is_transformer_engine_available():
         can_convert_te_model()
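
One plausible way to combine the two additions above in the test's `__main__` suite (illustrative only, not part of this commit; assumes a CUDA device and `pynvml` installed):

# Run the TE conversion test with device placement disabled, then verify that
# the lazy (meta-device) conversion kept GPU memory usage low.
if is_transformer_engine_available():
    can_convert_te_model()
    check_gpu_memory_usage_is_low()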
