@@ -86,6 +86,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
 
     part_size_n = layer.output_size_per_partition
     part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
 
     if size_k_first:
         assert layer.weight.shape == (part_size_k, part_size_n)
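The hunk above hoists the block-size lookup into a local and swaps direct attribute access for `getattr` with a `None` default. A minimal sketch of why that matters, using a plain `torch.nn.Linear` as a hypothetical stand-in for a layer that was quantized without block-wise scales (the layer here is illustrative, not from the diff):

```python
import torch

layer = torch.nn.Linear(16, 16)

# Direct access raises AttributeError when the layer never set the
# attribute:
#     layer.weight_block_size
# The getattr form falls back to None, steering later code onto the
# tensor-wise / channel-wise path instead of crashing.
weight_block_size = getattr(layer, "weight_block_size", None)
assert weight_block_size is None
```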
@@ -119,14 +120,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         scales = layer.weight_scale_inv.to(layer.orig_dtype)
         del layer.weight_scale_inv
 
-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
 
     # marlin kernel only support channel-wise and group-wise quantization
     # we need to convert the scales
-    if layer.weight_block_size is None:
+    if weight_block_size is None:
         if scales.nelement() == 1:
             # tensor-wise quantization -> channel-wise quantization
             # (1, 1) =>(repeat)=> (1, size_n)
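For the `weight_block_size is None` branch, the in-code comment describes repeating a single scalar scale across the output dimension. A small sketch of that shape change, under the assumption that `repeat_interleave` is the repeat in question (the diff only shows the shape comment):

```python
import torch

size_n = 8                       # hypothetical partition width
scales = torch.tensor([[0.5]])   # tensor-wise scale, shape (1, 1)

# (1, 1) =>(repeat)=> (1, size_n): one copy of the scale per channel
channelwise = scales.repeat_interleave(size_n, dim=1)
assert channelwise.shape == (1, size_n)
```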
@@ -149,7 +147,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
         # =>(repeat)=> (size_k // block_size[1], size_n)
         if not size_k_first:
             scales = scales.T.contiguous()
-        block_n = layer.weight_block_size[0]
+        block_n = weight_block_size[0]
         scales = scales.repeat_interleave(block_n, 1)
         # size_n may not divisible by block_size[0]
         scales = scales[:, :part_size_n]
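The block-wise to group-wise conversion above can be checked in isolation. A sketch with made-up sizes; only the `repeat_interleave` / slice pattern comes from the diff:

```python
import torch

block_n, part_size_n = 128, 200        # hypothetical block / partition sizes
num_groups_k = 4                       # size_k // block_size[1]
n_blocks = -(-part_size_n // block_n)  # ceil(200 / 128) == 2

scales = torch.rand(num_groups_k, n_blocks)
scales = scales.repeat_interleave(block_n, 1)  # (4, 2) -> (4, 256)
scales = scales[:, :part_size_n]               # trim the overhang, since
                                               # block_n need not divide
                                               # size_n -> (4, 200)
assert scales.shape == (num_groups_k, part_size_n)
```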
@@ -173,6 +171,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
     e = layer.num_experts
     k = layer.hidden_size
     n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
 
     # WORKSPACE
     device = layer.w13_weight.device
@@ -213,10 +212,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
 
     # WEIGHT SCALES
     # Permute scales
-    if layer.weight_block_size is None:
-        group_size = -1
-    else:
-        group_size = layer.weight_block_size[1]
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
 
     for name in ["w13", "w2"]:
         if name + "_weight_scale" in dir(layer):
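Behavior of the refactored one-liner, for reference; the reading of `-1` as "no grouping / channel-wise" downstream in the marlin path is my assumption from context:

```python
# -1 selects the channel-wise path downstream; otherwise block_size[1]
# (the size_k dimension of a block) becomes the marlin group size.
for weight_block_size in (None, [128, 128]):
    group_size = -1 if weight_block_size is None else weight_block_size[1]
    print(weight_block_size, "->", group_size)
# None -> -1
# [128, 128] -> 128
```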
@@ -236,7 +232,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
 
         # marlin kernel only support channel-wise and group-wise quantization
         # we need to convert the scales
-        if layer.weight_block_size is None:
+        if weight_block_size is None:
             if scales.nelement() == e:
                 # tensor-wise quantization -> channel-wise quantization
                 # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
@@ -259,7 +255,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
             # =>(repeat)=> (e, size_k // block_size[1], size_n)
             if not size_k_first:
                 scales = scales.permute(0, 2, 1)
-            block_n = layer.weight_block_size[0]
+            block_n = weight_block_size[0]
             scales = scales.repeat_interleave(block_n, 2)
             # size_n may not divisible by block_size[0]
             scales = scales[..., :size_n].contiguous()
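The MoE variant is the same expansion lifted to a leading expert dimension: repeat along dim 2, then trim to `size_n`. A sketch with hypothetical sizes:

```python
import torch

e, num_groups_k = 2, 4            # experts, size_k // block_size[1]
block_n, size_n = 128, 200        # hypothetical block / partition sizes
n_blocks = -(-size_n // block_n)  # ceil(200 / 128) == 2

scales = torch.rand(e, num_groups_k, n_blocks)
scales = scales.repeat_interleave(block_n, 2)  # (2, 4, 2) -> (2, 4, 256)
scales = scales[..., :size_n].contiguous()     # trim -> (2, 4, 200)
assert scales.shape == (e, num_groups_k, size_n)
```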