from collections.abc import Callable
from typing import Any, Optional

+import numpy as np
import torch
-import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoE,
-    FusedMoEConfig,
-    FusedMoEMethodBase,
-)
from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
    FusedMoEQuantConfig,
-    int4_w4a16_moe_quant_config,
-    int8_w8a16_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    apply_rtn_marlin_linear,
+    marlin_make_workspace_new,
+)
+from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
3637"""By default, use 8 bit as target precision, but it can be
4142overridden by setting the RTN_GROUP_SIZE envvar
4243"""
4344GROUP_SIZE = os .getenv ("RTN_GROUP_SIZE" , "128" )
45+ """Global Marlin workspace shared by all modules
46+ """
47+ workspace = None
4448
4549
4650class RTNConfig (QuantizationConfig ):
@@ -60,6 +64,10 @@ def __init__(
6064 f"supported for RTN, but got { self .weight_bits } bits."
6165 )
6266
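+        # scalar_types.uint8b128 / uint4b8 are unsigned types with a symmetric
+        # bias (128 / 8), matching RTN's zero-point-free quantization.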
+        self.quant_type = (
+            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
+        )
+
    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
@@ -221,24 +229,32 @@ def create_weights(
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        fix_weights(layer, "weight")
+        """Repack weights and scales for Marlin kernels."""
+        weight_bits = self.quant_config.weight_bits
+
+        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)
+
+        replace_parameter(layer, "weight", weight)
+        replace_parameter(layer, "scale", scale)
+
+        init_workspace(layer.weight.device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
-        qweight = layer.weight
-        scale = layer.scale
-
-        weight = rtn_dequantize(qweight, scale)
-        out = F.linear(x, weight)
-        del weight
-        if bias is not None:
-            out.add_(bias)
-
-        return out
+        return apply_rtn_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.scale,
+            workspace=workspace,
+            quant_type=self.quant_config.quant_type,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            bias=bias,
+        )


class RTNMoEMethod(FusedMoEMethodBase):
@@ -315,28 +331,27 @@ def create_weights(
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits
-        fix_weights(layer, "w13_weight", weight_bits == 4)
-        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+        w13_weight, w13_scale = repack_weights(
+            layer.w13_weight, layer.w13_scale, weight_bits
+        )
+        replace_parameter(layer, "w13_weight", w13_weight)
+        replace_parameter(layer, "w13_scale", w13_scale)
+
+        w2_weight, w2_scale = repack_weights(
+            layer.w2_weight, layer.w2_scale, weight_bits
+        )
+        replace_parameter(layer, "w2_weight", w2_weight)
+        replace_parameter(layer, "w2_scale", w2_scale)
+
+        init_workspace(layer.w13_weight.device)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
-        weight_bits = self.quant_config.weight_bits
-        group_size = self.quant_config.group_size
-        assert weight_bits == 4 or weight_bits == 8
-        config_builder = (
-            int4_w4a16_moe_quant_config
-            if weight_bits == 4
-            else int8_w8a16_moe_quant_config
-        )
-        return config_builder(
-            w1_scale=layer.w13_scale,
-            w2_scale=layer.w2_scale,
-            w1_zp=None,
-            w2_zp=None,
-            block_shape=[0, group_size],
-        )
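+        # The Marlin MoE path consumes w13_scale / w2_scale directly in
+        # apply() below, so no generic FusedMoEQuantConfig is needed.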
+        return None

    def apply(
        self,
@@ -366,8 +381,6 @@ def apply(
        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

-        from vllm.model_executor.layers.fused_moe import fused_experts
-
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -383,18 +396,22 @@ def apply(
            indices_type=self.topk_indices_dtype,
        )

-        return fused_experts(
+        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
+            getattr(layer, "w13_bias", None),
+            getattr(layer, "w2_bias", None),
+            layer.w13_scale,
+            layer.w2_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            quant_type_id=self.quant_config.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
-            quant_config=self.moe_quant_config,
+            workspace=workspace,
        )


@@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return input_deq


-def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False):
-    """torch.compile does not know how to deal with a Parameter subclass
-    (aka RTNParameter). As we don't really need RTNParameters for the
-    forward pass, we replace them with equivalent instances of Parameters.
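+# Permutation tables for Marlin's tiled weight layout (the construction below
+# follows the reference Marlin repacking code): a per-tile permutation for
+# weights plus scale permutations for the grouped and per-channel cases.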
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm_arr = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
+    perm_tensor = torch.from_numpy(perm_arr)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm_tensor, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
+
+
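+# pack_for_marlin reorders a (batch, n, k) quantized tensor into Marlin's
+# tiled layout: transpose to (k, n), shuffle rows to match the grouped scales,
+# tile in 16x16 blocks, apply _perm, and (for 4-bit) pack two values per byte.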
+def pack_for_marlin(weight, scale, qbits):
+    batch = weight.shape[0]
+
+    n = weight.size(1)
+    k = weight.size(2)
+    groupsize = k // scale.size(2)
+
+    tile = 16
+    s = scale.permute(0, 2, 1)  # transpose
+    w = weight.permute(0, 2, 1)  # transpose
+    if groupsize != k:
+        w = w.reshape((batch, -1, groupsize, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, groupsize, -1))
+        s = s.reshape((batch, 1, -1))
+
+    if groupsize != k:
+        w = w.reshape((batch, groupsize, -1, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, k, n)).contiguous()
+        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
+    else:
+        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
+    s = s.reshape((batch, -1, n)).contiguous()
+    w = w.reshape((batch, k // tile, tile, n // tile, tile))
+    w = w.permute((0, 1, 3, 2, 4))
+    w = w.reshape((batch, k // tile, n * tile))
+    res = w
+    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
+    if qbits == 4:
+        q = torch.zeros(
+            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
+        )
+        for i in range(2):
+            q |= res[:, :, i::2] << 4 * i
+        q = q.reshape(batch, -1, n).contiguous()
+    else:
+        q = res.clone()
+        q[:, :, 2::8] = res[:, :, 4::8]
+        q[:, :, 3::8] = res[:, :, 5::8]
+        q[:, :, 4::8] = res[:, :, 2::8]
+        q[:, :, 5::8] = res[:, :, 3::8]
+        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()
+
+    return q, s
+
+
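+# repack_8bit_into_32bit folds four consecutive 8-bit values into one int32,
+# little-endian: byte i lands at bits [8*i, 8*i+8), so bytes
+# [0x01, 0x02, 0x03, 0x04] become 0x04030201.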
+def repack_8bit_into_32bit(input):
+    output = torch.zeros(
+        (input.shape[0], input.shape[1], input.shape[2] // 4),
+        dtype=torch.int32,
+        device=input.device,
+    )
+    for i in range(4):
+        output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i
+
+    return output
+
+
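+# repack_weights is the entry point used by process_weights_after_loading:
+# it accepts 2-D (linear) or 3-D (MoE, one slice per expert) tensors, unpacks
+# two 4-bit values per byte when needed, Marlin-packs them, and finally folds
+# bytes into int32 words.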
+def repack_weights(qweight, scale, weight_bits):
+    batch_present = len(qweight.shape) == 3
+    if not batch_present:
+        qweight = qweight.unsqueeze(0)
+        scale = scale.unsqueeze(0)
+
+    if weight_bits == 4:
+        """Unpack two 4-bit values from each byte.
+        """
+        qweight_unpacked = torch.empty(
+            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
+            dtype=torch.uint8,
+            device=qweight.device,
+        )
+        for i in range(2):
+            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
+                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
+            )
+    else:
+        qweight_unpacked = qweight
+
+    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
+    """Marlin kernels expect tensors in int32 format in a certain shape
    """
-    old_weight = getattr(layer, param_name)
-    assert isinstance(old_weight, RTNParameter)
-    data = old_weight.data.data
+    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
+    qweight_reshaped = qweight_repacked.reshape(
+        qweight.shape[0], qweight.shape[2] // 16, -1
+    )
+    if not batch_present:
+        qweight_reshaped = qweight_reshaped.squeeze(0)
+        scale_packed = scale_packed.squeeze(0)
+
+    return qweight_reshaped, scale_packed
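+
+# A minimal usage sketch (mirroring process_weights_after_loading above):
+#   weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)
+#   replace_parameter(layer, "weight", weight)
+#   replace_parameter(layer, "scale", scale)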

-    delattr(layer, param_name)

-    if reshape:
-        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
-    new_weight = Parameter(data=data, requires_grad=False)
-    layer.register_parameter(param_name, new_weight)
+def init_workspace(device):
+    global workspace
+    if workspace is None:
+        workspace = marlin_make_workspace_new(device, 4)