@@ -105,6 +105,13 @@ def __init__(
                         device="cuda",
                         dtype=self.params_dtype))
 
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
         # Scaling factors for FP8 weights
         self.ws_scale = nn.Parameter(
             torch.ones(
@@ -115,12 +122,23 @@ def __init__(
                 self.num_total_experts, device="cuda", dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None
 
-        set_weight_attrs(self.ws, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2s, {
-            "weight_loader": self.weight_loader,
-        })
+        # Scaling factors for FP8 activations
+        need_act_scales = (self.use_fp8
+                           and quant_config.activation_scheme == "static")
+        self.as_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+        self.a2s_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+
+        if need_act_scales:
+            set_weight_attrs(self.as_scale, {
+                "weight_loader": self.weight_loader,
+            })
+            set_weight_attrs(self.a2s_scale, {
+                "weight_loader": self.weight_loader,
+            })
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
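For context, a minimal plain-PyTorch sketch of what a static per-tensor activation scale like `as_scale` is used for: quantizing activations to FP8 with a scale that was computed offline and loaded from the checkpoint, instead of measured at runtime. This is not vLLM's fused kernel; the function name and the `amax / fp8_max` scale convention are illustrative assumptions (requires torch >= 2.1 for float8 dtypes).

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def quantize_static(x: torch.Tensor, act_scale: torch.Tensor) -> torch.Tensor:
    # act_scale is a precomputed per-tensor scale (roughly amax / FP8_MAX),
    # loaded from the checkpoint rather than derived from x at runtime.
    x_scaled = (x / act_scale).clamp(-FP8_MAX, FP8_MAX)
    return x_scaled.to(torch.float8_e4m3fn)
```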
@@ -135,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name:
+            param_data[:] = param_data[:].max(loaded_weight)
 
     def process_weights_after_loading(self):
         if self.use_fp8:
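A toy, self-contained illustration of the `act_scale` branch above: since w1 and w3 map to a single shared `as_scale` parameter, the loader is called once per expert per projection on the same tensor and keeps the running elementwise maximum, so the stored scale is conservative for every activation tensor it covers. Values here are made up for the example.

```python
import torch

param = torch.zeros(1)  # stands in for self.as_scale
for loaded in [torch.tensor([0.12]), torch.tensor([0.31]), torch.tensor([0.07])]:
    param[:] = param[:].max(loaded)  # keep the largest scale seen so far
print(param)  # tensor([0.3100])
```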
@@ -162,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                         inplace=True,
                                         use_fp8=self.use_fp8,
                                         w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale)
+                                        w2_scale=self.w2s_scale,
+                                        a1_scale=self.as_scale,
+                                        a2_scale=self.a2s_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
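Note: `a1_scale` and `a2_scale` are wired to the `as_scale` and `a2s_scale` parameters created in `__init__`; presumably they scale the activations entering the first (w1/w3) and second (w2) expert matmuls. When the activation scheme is not "static" they remain None, and the fused kernel is assumed to derive activation scales on the fly.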
@@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         expert_params_mapping = [
+            # These are the weights for the experts
             # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
         ]
 
         params_dict = dict(self.named_parameters())
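A toy, self-contained sketch of how one checkpoint tensor name would be matched against the mapping above: the matching `(param_name, weight_name, expert_id)` entry rewrites the checkpoint name to the fused parameter's name, which is then looked up in `params_dict` and handed to that parameter's `weight_loader` together with the expert index. The layer prefix below is made up for the example.

```python
# Two representative entries from expert_params_mapping (expert 0 only).
expert_params_mapping = [
    ("ws", "experts.0.w1.weight", 0),
    ("as_scale", "experts.0.w1.act_scale", 0),
]

name = "model.layers.0.block_sparse_moe.experts.0.w1.act_scale"
for param_name, weight_name, expert_id in expert_params_mapping:
    if weight_name in name:
        target = name.replace(weight_name, param_name)
        print(target, expert_id)
        # -> model.layers.0.block_sparse_moe.as_scale 0
        break
```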