 import torch.nn as nn
 
 from .imports import is_hpu_available, is_transformer_engine_available
-from .operations import GatheredParameters
 
 
 # Do not import `transformer_engine` at package level to avoid potential issues
 
 
 def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
     """
-    Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
+    Converts model layers to Transformer Engine counterparts with reduced memory overhead.
+
+    This function performs a two-stage conversion process:
+    1. Creates TE module structure on meta device (no memory allocation)
+    2. Transfers weights efficiently to avoid peak memory usage during conversion
+
+    Args:
+        model: The model to convert
+        to_transformer_engine (bool): Whether to convert to TE (True) or from TE (False)
+        _convert_linear (bool): Whether to convert Linear layers
+        _convert_ln (bool): Whether to convert LayerNorm layers
+
+    Returns:
+        The converted model
+
+    Raises:
+        ImportError: If transformer_engine is not available
     """
     if not is_transformer_engine_available():
         raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
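The two-stage approach described in the docstring above (build the replacement module on the meta device, then copy the real weights in) can be sketched in isolation. The snippet below is an illustrative sketch only, not part of this patch: it uses a plain nn.Linear as a stand-in for a TE layer and assumes accelerate and a recent PyTorch are installed.

    import torch
    import torch.nn as nn

    from accelerate import init_empty_weights

    source = nn.Linear(32, 32)  # original layer whose weights we want to keep

    # Stage 1: build the replacement skeleton on the meta device (no allocation)
    with init_empty_weights():
        replacement = nn.Linear(32, 32)

    # Stage 2: materialize empty storage, then copy the original values in
    replacement.to_empty(device=source.weight.device)
    with torch.no_grad():
        replacement.weight.copy_(source.weight)
        replacement.bias.copy_(source.bias)

    assert torch.equal(replacement.weight, source.weight)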
@@ -39,55 +54,73 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
     else:
         import transformer_engine.pytorch as te
 
-    for name, module in model.named_children():
-        if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
-            has_bias = module.bias is not None
-            params_to_gather = [module.weight]
-            if has_bias:
-                params_to_gather.append(module.bias)
+    import torch
 
-            with GatheredParameters(params_to_gather, modifier_rank=0):
+    from accelerate import init_empty_weights
+
+    # Snapshot references to the original parameters before any modules are replaced;
+    # once `setattr` swaps a module out below, its weights are no longer reachable
+    # through `model.state_dict()`
+    original_state_dict = model.state_dict()
+
+    # Stage 1: create the TE module skeleton on the meta device (zero memory allocation)
+    with init_empty_weights():
+        for name, module in model.named_children():
+            if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
+                has_bias = module.bias is not None
+
+                # TE requires weight dimensions to be multiples of 16; skip layers that do not qualify
                 if any(p % 16 != 0 for p in module.weight.shape):
-                    return
+                    continue
+
+                # Create TE Linear module structure without allocating memory
                 te_module = te.Linear(
-                    module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
+                    module.in_features,
+                    module.out_features,
+                    bias=has_bias,
+                    params_dtype=module.weight.dtype,
+                    device="meta",
                 )
-                te_module.weight.copy_(module.weight)
-                if has_bias:
-                    te_module.bias.copy_(module.bias)
+                setattr(model, name, te_module)
 
+            elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
+                # Create TE LayerNorm module structure without allocating memory
+                te_module = te.LayerNorm(
+                    module.normalized_shape[0],
+                    eps=module.eps,
+                    params_dtype=module.weight.dtype,
+                )
                 setattr(model, name, te_module)
-        # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
-        elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
-            with GatheredParameters([module.weight, module.bias], modifier_rank=0):
-                te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
-                te_module.weight.copy_(module.weight)
-                te_module.bias.copy_(module.bias)
-
-            setattr(model, name, te_module)
-        elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
-            has_bias = module.bias is not None
-            new_module = nn.Linear(
-                module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
-            )
-            new_module.weight.copy_(module.weight)
-            if has_bias:
-                new_module.bias.copy_(module.bias)
-
-            setattr(model, name, new_module)
-        elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:
-            new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
-            new_module.weight.copy_(module.weight)
-            new_module.bias.copy_(module.bias)
-
-            setattr(model, name, new_module)
-        else:
-            convert_model(
-                module,
-                to_transformer_engine=to_transformer_engine,
-                _convert_linear=_convert_linear,
-                _convert_ln=_convert_ln,
-            )
+
+            else:
+                # Recursively convert child modules
+                convert_model(
+                    module,
+                    to_transformer_engine=to_transformer_engine,
+                    _convert_linear=_convert_linear,
+                    _convert_ln=_convert_ln,
+                )
+
+    # Stage 2: transfer the snapshotted weights into the meta-device TE modules
+    for name, module in model.named_modules():
+        if (
+            isinstance(module, (te.Linear, te.LayerNorm))
+            and hasattr(module, "weight")
+            and module.weight.device == torch.device("meta")
+        ):
+            # Locate the parameters captured before the original modules were swapped out
+            weight_key = f"{name}.weight"
+            bias_key = f"{name}.bias"
+
+            if weight_key in original_state_dict:
+                # Materialize real (uninitialized) storage for the meta parameters, then copy the values in
+                module.to_empty(device=original_state_dict[weight_key].device)
+                with torch.no_grad():
+                    module.weight.copy_(original_state_dict[weight_key])
+
+                    # Transfer the bias parameter if it exists
+                    if hasattr(module, "bias") and module.bias is not None and bias_key in original_state_dict:
+                        module.bias.copy_(original_state_dict[bias_key])
+
+    return model
 
 
 def has_transformer_engine_layers(model):
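For reference, a minimal usage sketch of the converted entry point. This is illustrative only: it assumes transformer_engine is installed and that these helpers are importable from accelerate.utils.transformer_engine, as in upstream accelerate.

    import torch
    import torch.nn as nn

    from accelerate.utils.transformer_engine import convert_model, has_transformer_engine_layers

    toy = nn.Sequential(
        nn.Linear(64, 64, dtype=torch.bfloat16),  # 64 is a multiple of 16, so the layer qualifies
        nn.LayerNorm(64, dtype=torch.bfloat16),
    )
    toy = convert_model(toy)
    print(has_transformer_engine_layers(toy))  # expected: True once TE layers have been swapped in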