@@ -22,6 +22,7 @@
     marlin_import_exception,
 )
 from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear
+from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear
 from gptqmodel.utils.marlin import marlin_make_workspace_new


@@ -30,6 +31,7 @@
 log = LogBar.shared()

 DEVICE = torch.device("cuda:0")
+CPU_DEVICE = torch.device("cpu")

 GREEN = "\033[32m"
 RED = "\033[31m"
@@ -50,6 +52,7 @@ class TestAwqKernelOutput(unittest.TestCase):
         (BACKEND.GEMM, torch.float16, 0.004),
         # (BACKEND.GEMM, torch.bfloat16, 0.05),
         (BACKEND.MARLIN, torch.float16, 0.006),
+        (BACKEND.TORCH_FUSED_AWQ, torch.float16, 0.004),
         # (BACKEND.MARLIN, torch.bfloat16, 0.05),
     ]

@@ -92,6 +95,16 @@ def setUpClass(cls) -> None:
             qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu
         )

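+        # Build the torch-fused AWQ module; if the fused kernel is unavailable
+        # on this host, record a skip reason instead of failing class setup.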
+        try:
+            cls.modules[BACKEND.TORCH_FUSED_AWQ] = cls._build_torch_fused_awq_module(
+                qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu
+            )
+        except Exception as exc:
+            cls.backend_skip_reason[BACKEND.TORCH_FUSED_AWQ] = (
+                f"Torch fused AWQ kernel unavailable: {exc}"
+            )
+            cls.modules[BACKEND.TORCH_FUSED_AWQ] = None
+
         base_inputs = cls._generate_inputs()
         cls.inputs: Dict[torch.dtype, List[torch.Tensor]] = {}
         cls.reference_outputs: Dict[torch.dtype, List[torch.Tensor]] = {}
@@ -247,6 +260,35 @@ def _build_torch_awq_module(
         module.post_init()
         return module

+    @classmethod
+    def _build_torch_fused_awq_module(
+        cls,
+        qweight_cpu: torch.Tensor,
+        qzeros_cpu: torch.Tensor,
+        scales_cpu: torch.Tensor,
+        bias_cpu: torch.Tensor,
+    ) -> TorchFusedAwqQuantLinear:
+        module = TorchFusedAwqQuantLinear(
+            bits=cls.BITS,
+            group_size=cls.GROUP_SIZE,
+            sym=True,
+            desc_act=False,
+            in_features=cls.in_features,
+            out_features=cls.out_features,
+            bias=True,
+            adapter=None,
+            register_buffers=True,
+        ).to(CPU_DEVICE)
+
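+        # Load the prepacked quantized tensors into the module's registered
+        # buffers; scales and bias are cast to float16 to match the kernel dtype.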
+        module.qweight.copy_(qweight_cpu.to(CPU_DEVICE))
+        module.qzeros.copy_(qzeros_cpu.to(CPU_DEVICE))
+        module.scales.copy_(scales_cpu.to(torch.float16).to(CPU_DEVICE))
+        module.bias.copy_(bias_cpu.to(torch.float16).to(CPU_DEVICE))
+
+        module.eval()
+        module.post_init()
+        return module
+
     @classmethod
     def _generate_inputs(cls) -> List[torch.Tensor]:
         large_shapes = [(4, 32), (2, 64), (1, 96)]
@@ -288,19 +330,37 @@ def _forward(
         *,
         compute_dtype: Optional[torch.dtype] = None,
         output_dtype: Optional[torch.dtype] = None,
+        target_device: Optional[torch.device] = None,
     ) -> List[torch.Tensor]:
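+        # Resolve the device from the module itself so inputs follow CPU-hosted
+        # backends (the torch-fused AWQ module lives on CPU) as well as CUDA ones.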
+        if target_device is None:
+            target_device = cls._infer_module_device(module)
         outputs: List[torch.Tensor] = []
         with torch.inference_mode():
             for tensor in inputs:
                 local_tensor = tensor
-                if compute_dtype is not None and tensor.dtype != compute_dtype:
-                    local_tensor = tensor.to(dtype=compute_dtype)
+                if local_tensor.device != target_device:
+                    local_tensor = local_tensor.to(device=target_device)
+                if compute_dtype is not None and local_tensor.dtype != compute_dtype:
+                    local_tensor = local_tensor.to(dtype=compute_dtype)
                 result = module(local_tensor)
                 if output_dtype is not None and result.dtype != output_dtype:
                     result = result.to(dtype=output_dtype)
                 outputs.append(result.detach().cpu())
         return outputs

+    @staticmethod
+    def _infer_module_device(module: torch.nn.Module) -> torch.device:
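+        # Prefer the first parameter's device, fall back to the first buffer's,
+        # and default to CPU when the module owns no tensors at all.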
+        try:
+            tensor = next(module.parameters())
+            return tensor.device
+        except StopIteration:
+            pass
+        try:
+            tensor = next(module.buffers())
+            return tensor.device
+        except StopIteration:
+            return torch.device("cpu")
+
     def _maybe_skip_backend(self, backend: BACKEND) -> None:
         reason = self.backend_skip_reason.get(backend)
         if reason: