This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

lazy cutlass_gemm_dq import #221

Merged · 2 commits · May 1, 2024
Changes from 1 commit
lazy cutlass_gemm_dq import
Varun Sundar Rabindranath committed May 1, 2024
commit 83da8719b06a6bd0d8cb82e1991bb73e5e913557
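Background, for readers landing on this commit: per the in-diff comment below, importing the compressed-tensors schemes previously pulled in cutlass at module-import time, so the import itself failed on machines without CUDA. A quick smoke test for the intended post-change behavior (a sketch, assuming a vllm checkout containing this commit; no GPU required):

```python
# On a CPU-only machine this import should now succeed, because the
# cutlass-backed kernel is no longer imported until apply_weights runs.
import importlib

schemes = importlib.import_module(
    "vllm.model_executor.layers.quantization.compressed_tensors.schemes")
print("import OK:", schemes.__name__)
```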
@@ -6,7 +6,8 @@
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsUnquantized)
+    CompressedTensorsScheme, CompressedTensorsUnquantized,
+    CompressedTensorsW8A8StaticTensor)
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -80,10 +81,9 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict):
         is_symmetric = weight_symmetric and input_symmetric
 
         if is_8_bits and is_tensor and is_symmetric and \
-                torch.cuda.is_available():
-            # CompressedTensorsW8A8StaticTensor only supports CUDA path for now.
-            from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (  # noqa: E501
-                CompressedTensorsW8A8StaticTensor)
+                torch.cuda.is_available():
+            # CompressedTensorsW8A8StaticTensor only supports CUDA path for
+            # now.
             return CompressedTensorsW8A8StaticTensor(
                 fake_quant=self.fake_quant)
         raise NotImplementedError(
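With the cutlass import deferred (see the later hunks), the scheme class can live in the module-level import above, and the guard in _get_schema reduces to a plain capability check. A minimal sketch of that selection logic, with a hypothetical function name and a string standing in for the scheme class:

```python
import torch

def get_scheme_sketch(num_bits: int, strategy: str, symmetric: bool) -> str:
    # Mirrors the guard above: pick the CUDA-only W8A8 scheme only when the
    # quantization config matches and a CUDA device is actually present.
    if num_bits == 8 and strategy == "tensor" and symmetric and \
            torch.cuda.is_available():
        return "CompressedTensorsW8A8StaticTensor"
    raise NotImplementedError(
        "No compressed-tensors scheme for this configuration")
```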
@@ -1,7 +1,6 @@
from typing import Any, Dict, Optional, Tuple, Union

import cutlass
import cutlass.epilogue
import torch
from cutlass import Tensor as FakeTensor

@@ -4,8 +4,6 @@
 from torch.nn import Parameter
 
 from vllm._C import ops
-from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import (  # noqa: E501
-    cutlass_gemm_dq)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.utils import set_weight_attrs
@@ -115,6 +113,12 @@ def create_weights(self, layer: torch.nn.Module,
         set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+
+        # Lazy import so we don't fail on cutlass imports on non-CUDA
+        # machines.
+        from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import (  # noqa: E501
+            cutlass_gemm_dq)
+
         weight = layer.weight
         weight_scale = layer.weight_scale
         act_scale = layer.input_scale
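The hunk above is the core of the fix: the cutlass-backed kernel is imported inside apply_weights rather than at module scope, so the dependency is resolved only when the CUDA path actually executes. A self-contained toy version of the pattern, with the standard-library math module standing in for the cutlass_gemm module:

```python
class LazyImportSketch:
    """Toy illustration of the lazy-import pattern used in apply_weights."""

    def apply(self, x: float) -> float:
        # Deferred import: resolved on first call, not when this module is
        # imported, so "import <this module>" works even where the heavy
        # dependency is unavailable. "math" stands in for cutlass_gemm.
        from math import sqrt
        return sqrt(x)
```

The trade-off is that a missing dependency now surfaces as an ImportError at first call rather than at startup, which is the behavior wanted here: CPU-only deployments never hit this code path.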