@@ -105,6 +105,13 @@ def __init__(
                         device="cuda",
                         dtype=self.params_dtype))
 
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
         # Scaling factors for FP8 weights
         self.ws_scale = nn.Parameter(
             torch.ones(
@@ -115,12 +122,23 @@ def __init__(
                 self.num_total_experts, device="cuda", dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None
 
-        set_weight_attrs(self.ws, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2s, {
-            "weight_loader": self.weight_loader,
-        })
+        # Scaling factors for FP8 activations
+        need_act_scales = (self.use_fp8
+                           and quant_config.activation_scheme == "static")
+        self.as_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+        self.a2s_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+
+        if need_act_scales:
+            set_weight_attrs(self.as_scale, {
+                "weight_loader": self.weight_loader,
+            })
+            set_weight_attrs(self.a2s_scale, {
+                "weight_loader": self.weight_loader,
+            })
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
@@ -135,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name:
+            param_data[:] = param_data[:].max(loaded_weight)
 
     def process_weights_after_loading(self):
         if self.use_fp8:
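For context on the added act_scale branch above: with a static activation scheme each expert in the checkpoint can carry its own activation scale, but the layer keeps only one shared scalar per activation, so the loader folds the per-expert values together with a running element-wise maximum. A minimal, self-contained sketch of that merge (variable names and example values are illustrative, not from the PR):

# Illustrative sketch only -- mirrors `param_data[:] = param_data[:].max(loaded_weight)`.
import torch

shared_scale = torch.zeros(1, dtype=torch.float32)   # like self.as_scale at init
per_expert_scales = [torch.tensor([0.021]),          # hypothetical per-expert
                     torch.tensor([0.034]),          # act_scale values from a
                     torch.tensor([0.019])]          # quantized checkpoint

for loaded_scale in per_expert_scales:
    # Keep the largest scale seen so far, so a single scalar safely
    # covers the activation range of every expert.
    shared_scale[:] = shared_scale[:].max(loaded_scale)

print(shared_scale)  # tensor([0.0340])

Taking the maximum is the conservative choice: it can only widen the quantization range, never clip an expert whose activations exceed the shared scale.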
@@ -162,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                         inplace=True,
                                         use_fp8=self.use_fp8,
                                         w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale)
+                                        w2_scale=self.w2s_scale,
+                                        a1_scale=self.as_scale,
+                                        a2_scale=self.a2s_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         expert_params_mapping = [
+            # These are the weights for the experts
             # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
         ]
 
         params_dict = dict(self.named_parameters())
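The new mapping entries are consumed by the same matching logic as the expert weights: any checkpoint tensor whose name contains experts.<id>.<w>.act_scale is renamed to the fused parameter ("as_scale" or "a2s_scale") and routed through that parameter's weight_loader. A rough paraphrase of how such a dispatch loop looks in vLLM's load_weights (sketch only, not copied from this PR; params_dict and weights are the locals of the surrounding method):

# Sketch of the dispatch that uses expert_params_mapping; illustrative only.
for name, loaded_weight in weights:
    for param_name, weight_name, expert_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # e.g. "model.layers.0.block_sparse_moe.experts.3.w1.act_scale"
        #   -> "model.layers.0.block_sparse_moe.as_scale"
        name = name.replace(weight_name, param_name)
        param = params_dict[name]
        # MixtralMoE.weight_loader handles both the expert weight shards
        # and the act_scale max-merge shown earlier.
        param.weight_loader(param, loaded_weight, weight_name, expert_id)
        break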