@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError
@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
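
For orientation, a standalone sketch (not part of the diff, with made-up sizes) of the per-rank expert weight shapes the two torch.empty calls above allocate:

import torch

# Illustrative sizes only; real values come from the model config and TP degree.
num_experts, hidden_size, intermediate_size_per_partition = 8, 4096, 1792

# Fused gate_up_proj: w1 and w3 stacked along dim 1, hence the factor of 2.
w13_weight = torch.empty(num_experts,
                         2 * intermediate_size_per_partition,
                         hidden_size,
                         dtype=torch.bfloat16)

# down_proj: the per-rank intermediate size sits on the input (last) dim.
w2_weight = torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition,
                        dtype=torch.bfloat16)

assert w13_weight.shape == (8, 2 * 1792, 4096)
assert w2_weight.shape == (8, 4096, 1792)
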
@@ -289,13 +291,20 @@ def __init__(
         self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
+        moe_quant_params = {
+            "num_experts": num_experts,
+            "hidden_size": hidden_size,
+            "intermediate_size_per_partition":
+            self.intermediate_size_per_partition,
+            "params_dtype": params_dtype,
+            "weight_loader": self.weight_loader,
+        }
+        # need full intermediate size pre-sharding for WNA16 act order
+        if (self.quant_method.__class__.__name__ ==
+                "CompressedTensorsWNA16MoEMethod"):
+            moe_quant_params["intermediate_size_full"] = intermediate_size
+
+        self.quant_method.create_weights(layer=self, **moe_quant_params)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
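
The kwargs dict above keeps the base create_weights signature unchanged while letting one quant method receive an extra key. A minimal sketch of how such a method could pick the key up, assuming it arrives via **extra_weight_attrs; the real CompressedTensorsWNA16MoEMethod may declare intermediate_size_full as an explicit parameter instead:

import torch


class SketchWNA16MoEMethod:
    """Hypothetical stand-in for illustration, not the vLLM class."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # Full (pre-sharding) intermediate size: with act order, g_idx and
        # group scales are indexed against the unsharded input dimension.
        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full", intermediate_size_per_partition)
        ...
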
@@ -312,19 +321,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
+    def _load_model_weight_or_group_weight_scale(self,
+                                                 shard_dim: int,
                                                  expert_data: torch.Tensor,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int):
-        # Load grouped weight scales for group quantization
-        # or model weights
+                                                 tp_rank: int,
+                                                 load_full_w2: bool = False):
+        """
+        Load grouped weight scales for group quantization or model weights.
+        :param shard_dim: dimension to shard
+        :param expert_data: parameter for a particular expert
+        :param shard_id: either w1, w2, or w3
+        :param loaded_weight: checkpoint weight to load into the param
+        :param tp_rank: tensor parallel rank
+        :param load_full_w2: whether to load the full (unsharded) w2 weight
+        """
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            # In the case of actorder/g_idx, we do not partition the w2
+            # scales, as indicated by the `load_full` argument, for all tp cases
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
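
As the later weight_loader hunk shows, the new flag reaches this method via getattr(param, "load_full_w2", False). A hedged sketch (assumed shapes and wiring, not the actual quant-method code) of how a w2 group-scale parameter could be tagged so it is loaded unsharded on every rank:

import torch

from vllm.model_executor.utils import set_weight_attrs

num_experts, hidden_size, num_groups = 8, 4096, 32  # illustrative only

w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts, hidden_size,
                                                 num_groups),
                                     requires_grad=False)
# Tag the parameter; weight_loader() then passes load_full_w2=True and
# _load_w2() skips the tensor-parallel narrow().
set_weight_attrs(w2_weight_scale, {"load_full_w2": True})
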
@@ -364,15 +384,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
         expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
         expert_data.copy_(loaded_weight)
 
-    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
 
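
A self-contained illustration (toy sizes, not vLLM code) of the branch added above: with load_full=False the checkpoint tensor is narrowed to this rank's shard before the copy; with load_full=True the full tensor is copied whole and must already match the parameter's shape:

import torch

tp_rank, tp_size = 1, 2
full = torch.arange(8.0).reshape(2, 4)        # checkpoint tensor, shard_dim=1
expert_data = torch.empty(2, 4 // tp_size)    # this rank's slice of the param

load_full = False
shard_dim = 1
shard_size = expert_data.shape[shard_dim]
loaded_weight = full
if not load_full:
    loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                         shard_size)
expert_data.copy_(loaded_weight)              # rank 1 receives columns 2..3
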
@@ -387,8 +413,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                     shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
 
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)
@@ -416,19 +441,19 @@ def weight_loader(self, param: torch.nn.Parameter,
         ]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
-        # dimension intermediate_size is used.
+        # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
         tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
-        # should be whatever dimension intermediate_size is
+        # should be whatever dimension intermediate_size_per_partition is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -480,7 +505,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=tp_rank,
+                    load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,