From b34d8597bff813f5ae88c2319c4f9460872c3ce1 Mon Sep 17 00:00:00 2001 From: Alessandro Palla Date: Wed, 29 May 2024 18:11:35 +0200 Subject: [PATCH] Add int4 support (#32) * Add int4 support * Fix dtypes * Add dtypes test * Add dtype to library * Faster i8 to i4 compression * hotfix * Update the profile-llm script * Add library * fix script * Update readme * Add neural compressor and demo * Use neural compressor as the default method * hotfix * Quantize only quantized models * Add tests * fix issue #27 --- CMakeLists.txt | 22 +++- README.md | 4 +- examples/phi-2.py | 5 +- examples/phi-3-nc.py | 50 +++++++ .../conversion.h | 13 ++ intel_npu_acceleration_library/__init__.py | 3 +- .../backend/bindings.py | 2 + .../backend/compression.py | 24 ++++ .../backend/qlinear.py | 4 +- .../backend/qmatmul.py | 4 +- .../backend/runtime.py | 18 ++- intel_npu_acceleration_library/compiler.py | 37 ++++-- intel_npu_acceleration_library/dtypes.py | 68 ++++++++++ intel_npu_acceleration_library/nn/linear.py | 12 +- .../quantization.py | 123 +++++++++++++++++- requirements.txt | 3 +- script/profile_llm.py | 14 +- script/profile_matmul.py | 34 +++-- src/bindings.cpp | 6 + test/python/test_compile.py | 15 ++- test/python/test_dtypes.py | 20 +++ test/python/test_quantization.py | 4 - 22 files changed, 422 insertions(+), 63 deletions(-) create mode 100644 examples/phi-3-nc.py create mode 100644 intel_npu_acceleration_library/backend/compression.py create mode 100644 intel_npu_acceleration_library/dtypes.py create mode 100644 test/python/test_dtypes.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 5874955..bc7732c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,12 +37,19 @@ function(get_linux_lsb_release_information) set(LSB_RELEASE_VERSION "${LSB_RELEASE_VERSION}" PARENT_SCOPE) endfunction() -set(OV_VERSION_SHORT "2024.1") -set(OV_VERSION "2024.1.0.15008.f4afc983258_x86_64") +set(OV_VERSION_SHORT "nightly") +set(OV_VERSION "2024.3.0.dev20240524_x86_64") +set(OV_STORAGE_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages") +set(OV_NIGHTLY_COMMIT "2024.3.0-15502-66093834e38") if (WIN32) if(NOT OV_LIBRARY_URL) - set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/${OV_VERSION_SHORT}/windows/w_openvino_toolkit_windows_${OV_VERSION}.zip") + if (${OV_VERSION_SHORT} STREQUAL "nightly") + set(OV_PLATFORM "${OV_NIGHTLY_COMMIT}") + else() + set(OV_PLATFORM "windows") + endif() + set(OV_LIBRARY_URL "${OV_STORAGE_URL}/${OV_VERSION_SHORT}/${OV_PLATFORM}/w_openvino_toolkit_windows_${OV_VERSION}.zip") endif() elseif(UNIX) if(NOT OV_LIBRARY_URL) @@ -50,7 +57,13 @@ elseif(UNIX) if (LSB_RELEASE_ID STREQUAL "Ubuntu") if (${LSB_RELEASE_VERSION} STREQUAL "18.04" OR ${LSB_RELEASE_VERSION} STREQUAL "20.04" OR ${LSB_RELEASE_VERSION} STREQUAL "22.04") string(REPLACE ".04" "" LSB_RELEASE_VERSION_SHORT ${LSB_RELEASE_VERSION}) - set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/${OV_VERSION_SHORT}/linux/l_openvino_toolkit_ubuntu${LSB_RELEASE_VERSION_SHORT}_${OV_VERSION}.tgz") + if (${OV_VERSION_SHORT} STREQUAL "nightly") + set(OV_PLATFORM "${OV_NIGHTLY_COMMIT}") + else() + set(OV_PLATFORM "linux") + endif() + + set(OV_LIBRARY_URL "${OV_STORAGE_URL}/${OV_VERSION_SHORT}/${OV_PLATFORM}/l_openvino_toolkit_ubuntu${LSB_RELEASE_VERSION_SHORT}_${OV_VERSION}.tgz") else() message(FATAL_ERROR "Ubuntu version ${LSB_RELEASE_VERSION} is unsupported") endif() @@ -63,6 +76,7 @@ else() message(FATAL_ERROR "Unsupported architecture") endif () 
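Note on the nightly packages: with the defaults above, the archive is now fetched from a commit-named folder under the nightly tree instead of the usual windows/linux folder of a release. A small Python sketch of the same URL resolution, using the variable values set in this CMakeLists.txt (the resulting archive name is derived from those values and is not verified against the server):

# Sketch of the OV_LIBRARY_URL resolution performed by the CMake branch above (Windows case).
OV_VERSION_SHORT = "nightly"
OV_VERSION = "2024.3.0.dev20240524_x86_64"
OV_STORAGE_URL = "https://storage.openvinotoolkit.org/repositories/openvino/packages"
OV_NIGHTLY_COMMIT = "2024.3.0-15502-66093834e38"

# Nightly builds live under a commit-named folder; releases keep the platform folder.
platform = OV_NIGHTLY_COMMIT if OV_VERSION_SHORT == "nightly" else "windows"
url = f"{OV_STORAGE_URL}/{OV_VERSION_SHORT}/{platform}/w_openvino_toolkit_windows_{OV_VERSION}.zip"
print(url)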
+message(STATUS "OpenVINO library URL: ${OV_LIBRARY_URL}") FetchContent_Declare( openvino diff --git a/README.md b/README.md index de7d917..da91a85 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,8 @@ Some useful links In our quest to significantly improve the library's performance, we are directing our efforts toward implementing a range of key features, including: - [x] **8-bit quantization** -- [ ] **4-bit Quantization and GPTQ** -- [ ] **NPU-Native mixed precision inference** +- [x] **4-bit Quantization and GPTQ** +- [x] **NPU-Native mixed precision inference** - [x] **Float16 support** - [ ] **BFloat16 (Brain Floating Point Format)** - [x] **`torch.compile` support** diff --git a/examples/phi-2.py b/examples/phi-2.py index a358082..16c2ac6 100644 --- a/examples/phi-2.py +++ b/examples/phi-2.py @@ -7,8 +7,7 @@ from langchain.chains import LLMChain from langchain.llms import HuggingFacePipeline from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer -import intel_npu_acceleration_library -import torch +import intel_npu_acceleration_library as npu_lib model_id = "microsoft/Phi-2" @@ -16,7 +15,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True) streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) -npu_model = intel_npu_acceleration_library.compile(model, dtype=torch.float16) +npu_model = npu_lib.compile(model, dtype=npu_lib.int4) pipe = pipeline( "text-generation", diff --git a/examples/phi-3-nc.py b/examples/phi-3-nc.py new file mode 100644 index 0000000..52b8b8b --- /dev/null +++ b/examples/phi-3-nc.py @@ -0,0 +1,50 @@ +# +# Copyright © 2024 Intel Corporation +# SPDX-License-Identifier: Apache 2.0 +# + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer +import intel_npu_acceleration_library as npu_lib +import warnings + +torch.random.manual_seed(0) + +model = AutoModelForCausalLM.from_pretrained( + "microsoft/Phi-3-mini-4k-instruct", + torch_dtype="auto", + trust_remote_code=True, +) + +model = npu_lib.compile(model, dtype=npu_lib.int4) +tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") +streamer = TextStreamer(tokenizer, skip_prompt=True) + +messages = [ + { + "role": "system", + "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.", + }, + { + "role": "user", + "content": "Can you provide ways to eat combinations of bananas and dragonfruits?", + }, +] + +pipe = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, +) + +generation_args = { + "max_new_tokens": 500, + "return_full_text": False, + "temperature": 0.0, + "do_sample": False, + "streamer": streamer, +} + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pipe(messages, **generation_args) diff --git a/include/intel_npu_acceleration_library/conversion.h b/include/intel_npu_acceleration_library/conversion.h index 46619db..66ab25d 100644 --- a/include/intel_npu_acceleration_library/conversion.h +++ b/include/intel_npu_acceleration_library/conversion.h @@ -13,6 +13,19 @@ namespace intel_npu_acceleration_library { +/** + * @brief Compress a int8 vector to I4 format. 
+ * + * @param src pointer to the source int8 buffer + * @param dst pointer to the destination uint8 buffer + * @param size size of the src and dst buffers + */ +void compressToI4(const int8_t* src, uint8_t* dst, size_t size) { + for (size_t i = 0; i < size / 2; i++) { + dst[i] = (src[2 * i] & 0x0F) | ((src[2 * i + 1] & 0x0F) << 4); + } +} + /** * @brief Convert a int8 vector to fp16 given a scalar scale. * diff --git a/intel_npu_acceleration_library/__init__.py b/intel_npu_acceleration_library/__init__.py index 5e8a37b..ffe5e92 100644 --- a/intel_npu_acceleration_library/__init__.py +++ b/intel_npu_acceleration_library/__init__.py @@ -4,6 +4,7 @@ # from .compiler import compile +from .dtypes import int4, int8, float16 -__all__ = ["compile"] +__all__ = ["compile", "int4", "int8", "float16"] diff --git a/intel_npu_acceleration_library/backend/bindings.py b/intel_npu_acceleration_library/backend/bindings.py index 67fc8c9..9e9dc91 100644 --- a/intel_npu_acceleration_library/backend/bindings.py +++ b/intel_npu_acceleration_library/backend/bindings.py @@ -79,6 +79,8 @@ def init_common(lib: ctypes.CDLL): lib.isNPUAvailable.restype = ctypes.c_bool + lib.compressToI4.argtypes = [c_i8_array, c_u8_array, ctypes.c_int] + def init_network_factory(lib: ctypes.CDLL): """Initialize Netowrk factory bindings. diff --git a/intel_npu_acceleration_library/backend/compression.py b/intel_npu_acceleration_library/backend/compression.py new file mode 100644 index 0000000..6550c04 --- /dev/null +++ b/intel_npu_acceleration_library/backend/compression.py @@ -0,0 +1,24 @@ +# +# Copyright © 2024 Intel Corporation +# SPDX-License-Identifier: Apache 2.0 +# + +from intel_npu_acceleration_library.backend.bindings import lib as backend_lib +import numpy as np + + +def compress_to_i4(weights: np.ndarray) -> np.ndarray: + """Compress a int8 array to int4. + + Args: + weights (np.ndarray): input array + + Returns: + np.ndarray: compressed array + """ + compressed_weights = np.zeros( + (weights.shape[0], weights.shape[1] // 2), dtype=np.uint8 + ) + + backend_lib.compressToI4(weights, compressed_weights, np.prod(weights.shape)) + return compressed_weights diff --git a/intel_npu_acceleration_library/backend/qlinear.py b/intel_npu_acceleration_library/backend/qlinear.py index 634d2c7..cf5bf66 100644 --- a/intel_npu_acceleration_library/backend/qlinear.py +++ b/intel_npu_acceleration_library/backend/qlinear.py @@ -17,6 +17,7 @@ def __init__( batch: int, profile: bool = False, device: str = "NPU", + dtype: np.dtype = np.int8, ): """Initialize the QLinear class. @@ -26,6 +27,7 @@ def __init__( batch (int): batch profile (bool): Enable/Disable profiling. Defaults to False. device (str): Target device, default to "NPU". + dtype (np.dtype): weights datatype. Defaults to np.int8. 
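The compression path above packs two signed 4-bit values into each byte: element 2*i lands in the low nibble and element 2*i + 1 in the high nibble, which is why the runtime later doubles the input-channel count for torch.uint8 weights. A NumPy-only sketch of the packing plus a hypothetical unpack helper (not part of the library, shown only to verify the layout):

import numpy as np

def pack_i4(weights: np.ndarray) -> np.ndarray:
    # Same layout as compressToI4: element 2*i -> low nibble, element 2*i + 1 -> high nibble.
    # Expects an int8 array with an even number of columns.
    flat = weights.reshape(-1).view(np.uint8)
    low = flat[0::2] & 0x0F
    high = (flat[1::2] & 0x0F) << 4
    return (low | high).reshape(weights.shape[0], weights.shape[1] // 2)

def unpack_i4(packed: np.ndarray) -> np.ndarray:
    # Hypothetical inverse: sign-extend each nibble back to int8.
    low = (packed & 0x0F).astype(np.int8)
    high = (packed >> 4).astype(np.int8)
    low = np.where(low >= 8, low - 16, low)
    high = np.where(high >= 8, high - 16, high)
    out = np.empty((packed.shape[0], packed.shape[1] * 2), dtype=np.int8)
    out[:, 0::2] = low
    out[:, 1::2] = high
    return out

w = np.random.randint(-8, 8, (4, 16), dtype=np.int8)
assert np.array_equal(unpack_i4(pack_i4(w)), w)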
Raises: RuntimeError: Quantized matmul requires input_channel to be a multiple of 8 @@ -35,7 +37,7 @@ def __init__( raise RuntimeError( "Quantized matmul requires input_channel to be a multiple of 8" ) - out = self.linear(self.input, outC, inC, bias=False, wt_dtype=np.int8) + out = self.linear(self.input, outC, inC, bias=False, wt_dtype=dtype) self.compile(out) def run( diff --git a/intel_npu_acceleration_library/backend/qmatmul.py b/intel_npu_acceleration_library/backend/qmatmul.py index 3159128..c4e5502 100644 --- a/intel_npu_acceleration_library/backend/qmatmul.py +++ b/intel_npu_acceleration_library/backend/qmatmul.py @@ -17,6 +17,7 @@ def __init__( batch: int, profile: bool = False, device: str = "NPU", + dtype: np.dtype = np.int8, ): """Initialize the QMatmul class. @@ -26,9 +27,10 @@ def __init__( batch (int): batch profile (bool): Enable/Disable profiling. Defaults to False. device (str): Target device, default to "NPU". + dtype (np.dtype): weights datatype. Defaults to np.int8. """ super().__init__(inC, outC, batch, profile, device) - out = self.linear(self.input, outC, inC, bias=False, wt_dtype=np.int8) + out = self.linear(self.input, outC, inC, bias=False, wt_dtype=dtype) self.compile(out) def run(self, X: np.ndarray, W: np.ndarray, scale: np.ndarray) -> np.ndarray: diff --git a/intel_npu_acceleration_library/backend/runtime.py b/intel_npu_acceleration_library/backend/runtime.py index 8dd8bc1..d33313a 100644 --- a/intel_npu_acceleration_library/backend/runtime.py +++ b/intel_npu_acceleration_library/backend/runtime.py @@ -8,6 +8,7 @@ from intel_npu_acceleration_library.backend import NNFactory from torch.profiler import record_function from typing import Optional, List, Any, Dict, Deque +from functools import partial from collections import deque import numpy as np import torch @@ -46,6 +47,10 @@ def run_matmul( outC, inC = weights.shape[-2:] + if weights.dtype == torch.uint8: + # In case is Int4 we need to double the input channels because weights are compressed + inC *= 2 + # Set tensors as contiguous in memory x = set_contiguous(x) weights = set_contiguous(weights) @@ -53,11 +58,16 @@ def run_matmul( if weights.dtype.is_floating_point: op_class = Linear if op_id is not None else MatMul + op_class_name = op_class.__name__ + create_op = partial(op_class) op_args = [weights.to(torch.float16).numpy()] - elif weights.dtype == torch.int8: + elif weights.dtype in (torch.int8, torch.uint8): if scale is None: raise RuntimeError("Quantized weights require a not null scale") op_class = QLinear if op_id is not None else QMatMul + op_class_name = op_class.__name__ + np_dtype = np.int8 if weights.dtype == torch.int8 else np.uint8 + create_op = partial(op_class, dtype=np_dtype) if scale is None: raise RuntimeError( f"Quantized matmul (weights dtype == {weights.dtype}) requires scale (scale = {scale})" @@ -90,13 +100,13 @@ def run_matmul( else: batch = real_batch - key = f"{str(op_class.__name__)}_{batch}_{inC}_x_{outC}_{inC}_{x_np.dtype}" + key = f"{str(op_class_name)}_{batch}_{inC}_x_{outC}_{inC}_{x_np.dtype}" models = _model_cache.get(key, None) if models is None: - _model_cache[key] = deque([op_class(inC, outC, batch)]) + _model_cache[key] = deque([create_op(inC, outC, batch)]) elif len(models) < 1: - _model_cache[key].append(op_class(inC, outC, batch)) + _model_cache[key].append(create_op(inC, outC, batch)) else: _model_cache[key].rotate(1) diff --git a/intel_npu_acceleration_library/compiler.py b/intel_npu_acceleration_library/compiler.py index 43b307d..128b6c7 100644 --- 
a/intel_npu_acceleration_library/compiler.py +++ b/intel_npu_acceleration_library/compiler.py @@ -4,20 +4,14 @@ # from intel_npu_acceleration_library.optimizations import horizontal_fusion_linear +from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention +from transformers.models.gemma.modeling_gemma import GemmaMLP, GemmaAttention +from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear +from intel_npu_acceleration_library.quantization import quantize_model +from intel_npu_acceleration_library.dtypes import int8, int4 +import intel_npu_acceleration_library.nn as nn from torch._dynamo import register_backend from typing import Union, Callable, Any - -try: - from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention - from transformers.models.gemma.modeling_gemma import GemmaMLP, GemmaAttention - - is_transformers_available = True -except ModuleNotFoundError: - # Transformer library is not installed - is_transformers_available = False - - -import intel_npu_acceleration_library.nn as nn from typing import List import torch @@ -38,7 +32,7 @@ def compile( Returns: torch.nn.Module: compiled NPU nn.Module """ - if not (dtype.is_floating_point or dtype == torch.int8): + if not (dtype.is_floating_point or dtype in (int8, int4)): raise RuntimeError( f"intel-npu-acceleration-library library do not support yet the requeste datatype: {dtype}" ) @@ -48,6 +42,9 @@ def compile( # General optimizations apply_horizontal_fusion(model) optimize_llama_attention(model, dtype) + if dtype in (int8, int4): + # Quantize model + model = quantize_model(model, dtype) # Model lowering to NPU ops lower_linear(model, dtype) @@ -102,6 +99,9 @@ def lower_linear( layer (torch.nn.Module): Original torch.nn.Linear module dtype (torch.dtype): Target datatype + Raises: + RuntimeError: unsupported quantization bits + Returns: Union[torch.nn.Module, None]: Return the new NPU operator or None """ @@ -109,6 +109,17 @@ def lower_linear( return nn.Linear.fromTorch(layer, dtype) if isinstance(layer, torch.nn.Conv2d): return nn.Conv2d.fromTorch(layer, dtype) + if isinstance(layer, WeightOnlyLinear): + if layer.bits == 4: + return nn.QuantizedLinear( + layer.qweight.to(torch.uint8), layer.scales, layer.bias + ) + elif layer.bits == 8: + return nn.QuantizedLinear( + layer.qweight.view(torch.int8), layer.scales, layer.bias + ) + else: + raise RuntimeError(f"Unsupported quantization bits: {layer.bits}") return None diff --git a/intel_npu_acceleration_library/dtypes.py b/intel_npu_acceleration_library/dtypes.py new file mode 100644 index 0000000..1082fda --- /dev/null +++ b/intel_npu_acceleration_library/dtypes.py @@ -0,0 +1,68 @@ +# +# Copyright © 2024 Intel Corporation +# SPDX-License-Identifier: Apache 2.0 +# + +from dataclasses import dataclass +from typing import Union +import torch + + +@dataclass(frozen=True) +class NPUDtype: + """Represents a custom data type for NPUs (Neural Processing Units). + + Attrs: + name: str: The name of the data type. + bits: int: The number of bits used to represent the data type. + min: int: The minimum value that can be represented by the data type. + max: int: The maximum value that can be represented by the data type. + torch_dtype: torch.dtype: The corresponding torch data type. + is_floating_point: bool: True if the data type is floating-point, False otherwise. 
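With the compiler changes above, 4-bit compilation is requested through the new dtype objects rather than a torch.dtype, and the model is first run through the neural-compressor weight-only quantizer before lowering. A minimal usage sketch (the toy model is hypothetical; the call pattern follows the phi examples and unit tests in this patch, and actually running it needs an Intel NPU):

import torch
import intel_npu_acceleration_library as npu_lib

# Hypothetical toy model; input channels must stay a multiple of 8.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 256, bias=False),
).eval()

# int8/int4 requests now go through quantize_model() and end up as QuantizedLinear
# layers; for int4 the weights are stored uint8-packed (two values per byte).
npu_model = npu_lib.compile(model, dtype=npu_lib.int4)

with torch.no_grad():
    y = npu_model(torch.rand(8, 256, dtype=torch.float16))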
+ """ + + name: str + bits: int + min: int + max: int + torch_dtype: torch.dtype + + @property + def is_floating_point(self) -> bool: + """ + Check if the data type is a floating-point type. + + Returns: + bool: True if the data type is floating-point, False otherwise. + """ + return self.torch_dtype.is_floating_point + + def __eq__(self, value: Union["NPUDtype", torch.dtype]) -> bool: + """ + Compare the NPUDtype object with another NPUDtype or torch.dtype object. + + Args: + value (Union["NPUDtype", torch.dtype]): The object to compare with. + + Returns: + bool: True if the objects are equal, False otherwise. + """ + if isinstance(value, torch.dtype): + if value.is_floating_point: + info = torch.finfo(value) + else: + info = torch.iinfo(value) + return ( + self.bits == info.bits + and self.max == info.max + and self.min == info.min + and self.torch_dtype == value + ) + else: + return super().__eq__(value) + + +float16 = NPUDtype("fp16", 16, -65504, 65504, torch.float16) +bfloat16 = NPUDtype("bfloat16", 16, -65504, 65504, torch.float16) +int4 = NPUDtype("int4", 4, -8, 7, torch.int8) +int8 = NPUDtype("int8", 8, -128, 127, torch.int8) diff --git a/intel_npu_acceleration_library/nn/linear.py b/intel_npu_acceleration_library/nn/linear.py index 52f1749..f29a108 100644 --- a/intel_npu_acceleration_library/nn/linear.py +++ b/intel_npu_acceleration_library/nn/linear.py @@ -3,9 +3,10 @@ # SPDX-License-Identifier: Apache 2.0 # -from intel_npu_acceleration_library.quantization import quantize_tensor +from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4 from intel_npu_acceleration_library.nn.autograd import AutogradMatMul from intel_npu_acceleration_library.backend import run_matmul +from intel_npu_acceleration_library.dtypes import NPUDtype from typing import Optional, Union import torch import uuid @@ -88,6 +89,11 @@ def fromTensor( if bias is None: return Linear(weight.to(dtype), None) return Linear(weight.to(dtype), bias.to(dtype)) + elif isinstance(dtype, NPUDtype): + weights_quant, scale = quantize_tensor(weight, (dtype.min, dtype.max)) + if dtype.bits == 4: + weights_quant = compress_to_i4(weights_quant) + return QuantizedLinear(weights_quant, scale, bias) elif dtype == torch.int8: if weight.shape[-1] % 8 != 0: raise RuntimeError( @@ -123,9 +129,9 @@ def __init__( super().__init__() self.weight = weight - if self.weight.dtype != torch.int8: + if self.weight.dtype not in (torch.int8, torch.uint8): raise RuntimeError( - f"Quantized weight must be in torch.int8 dtype instead of {self.weight.dtype}" + f"Quantized weight must be in torch.(u)int8 dtype instead of {self.weight.dtype}" ) self.scale = scale self.outC, self.inC = self.weight.shape diff --git a/intel_npu_acceleration_library/quantization.py b/intel_npu_acceleration_library/quantization.py index dca3d00..656c1c3 100644 --- a/intel_npu_acceleration_library/quantization.py +++ b/intel_npu_acceleration_library/quantization.py @@ -2,7 +2,13 @@ # Copyright © 2024 Intel Corporation # SPDX-License-Identifier: Apache 2.0 # +import intel_npu_acceleration_library.backend.compression as compression +from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion +from intel_npu_acceleration_library.dtypes import int8, int4 +from intel_npu_acceleration_library.dtypes import NPUDtype +from neural_compressor.quantization import fit from typing import Tuple +import logging import torch @@ -55,11 +61,114 @@ def compress_to_i4(weights: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The compressed 
tensor with 4-bit representation. """ - compressed_weights = torch.zeros( - (weights.shape[0], weights.shape[1] // 2), dtype=torch.uint8 + return torch.tensor(compression.compress_to_i4(weights.numpy())) + + +def quantize_fit( + model: torch.nn.Module, weights_dtype: str, algorithm: str = "RTN" +) -> torch.nn.Module: + """Quantize a model with a given configuration. + + Args: + model (torch.nn.Module): The model to quantize + weights_dtype (str): The datatype for the weights + algorithm (str, optional): The quantization algorithm. Defaults to "RTN". + + Raises: + RuntimeError: Quantization error: unsupported datatype + + Returns: + torch.nn.Module: The quantized model + """ + if weights_dtype == "int4": + bits = 4 + elif weights_dtype == "int8": + bits = 8 + else: + raise RuntimeError(f"Quantization error: unsupported datatype {weights_dtype}") + + conf = PostTrainingQuantConfig( + approach="weight_only", + tuning_criterion=TuningCriterion(timeout=100000), + op_type_dict={ + ".*": { # match all ops + "weight": { + "dtype": weights_dtype, + "bits": bits, + "group_size": -1, + "scheme": "sym", + "algorithm": algorithm, + }, + "activation": { + "dtype": "fp16", + }, + } + }, ) - for i in range(weights.shape[1] // 2): - compressed_weights[:, i] = (weights[:, 2 * i] & 0x0F) | ( - ((weights[:, 2 * i + 1] & 0x0F) << 4) & 0xF0 - ) - return compressed_weights + + return fit(model=model, conf=conf) + + +def quantize_i8_model( + model: torch.nn.Module, algorithm: str = "RTN" +) -> torch.nn.Module: + """Quantize a model to 8-bit representation. + + Args: + model (torch.nn.Module): The model to quantize + algorithm (str, optional): The quantization algorithm. Defaults to "RTN". + + Returns: + torch.nn.Module: The quantized model + """ + quantized_model = quantize_fit(model, "int8", algorithm) + + return quantized_model.export_compressed_model( + scale_dtype=torch.float16, use_optimum_format=False + ) + + +def quantize_i4_model( + model: torch.nn.Module, algorithm: str = "RTN" +) -> torch.nn.Module: + """Quantize a model to 4-bit representation. + + Args: + model (torch.nn.Module): The model to quantize + algorithm (str, optional): The quantization algorithm. Defaults to "RTN". + + Returns: + torch.nn.Module: The quantized model + """ + quantized_model = quantize_fit(model, "int4", algorithm) + + return quantized_model.export_compressed_model( + compression_dtype=torch.int8, + scale_dtype=torch.float16, + use_optimum_format=False, + ) + + +def quantize_model(model: torch.nn.Module, dtype: NPUDtype) -> torch.nn.Module: + """Quantize a model. 
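The "RTN" algorithm selected above is plain round-to-nearest, symmetric weight-only quantization; per weight tensor it amounts to something like the sketch below (the math only, not the neural-compressor implementation; the per-output-channel scale is what the sym / group_size=-1 settings are understood to select):

import torch

def rtn_quantize(weight: torch.Tensor, bits: int = 4):
    # Symmetric round-to-nearest with one scale per output channel.
    # qmin/qmax are -8/7 for int4 and -128/127 for int8.
    qmax = 2 ** (bits - 1) - 1
    scale = weight.abs().max(dim=-1, keepdim=True).values / qmax
    q = torch.clamp(torch.round(weight / scale), -qmax - 1, qmax).to(torch.int8)
    return q, scale.to(torch.float16)

w = torch.randn(512, 256)
q, scale = rtn_quantize(w, bits=4)
w_hat = q.to(torch.float16) * scale  # dequantized weights the quantized matmul approximates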
+ + Args: + model (torch.nn.Module): The model to quantize + dtype (NPUDtype): The desired datatype + + Raises: + RuntimeError: Quantization error: unsupported datatype + + Returns: + torch.nn.Module: The quantized model + """ + # Silence neural compressor logger + logger = logging.getLogger("neural_compressor") + logger.setLevel(logging.ERROR) + + if dtype == int4: + return quantize_i4_model(model) + elif dtype == int8: + return quantize_i8_model(model) + else: + raise RuntimeError(f"Quantization error: unsupported datatype {dtype}") diff --git a/requirements.txt b/requirements.txt index c7ec603..d22de26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy torch -transformers>=4.39.3 \ No newline at end of file +transformers>=4.39.3 +neural-compressor \ No newline at end of file diff --git a/script/profile_llm.py b/script/profile_llm.py index c6278ee..02a48a7 100644 --- a/script/profile_llm.py +++ b/script/profile_llm.py @@ -5,6 +5,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM from intel_npu_acceleration_library.nn.llm import generate_with_static_shape +from intel_npu_acceleration_library.dtypes import float16, int8, int4 + from torch.profiler import profile, ProfilerActivity import intel_npu_acceleration_library import argparse @@ -40,9 +42,11 @@ def main( ) if dtype == "float16": - dtype = torch.float16 + dtype = float16 elif dtype == "int8": - dtype = torch.int8 + dtype = int8 + elif dtype == "int4": + dtype = int4 else: raise RuntimeError(f"Invalid dtype: {dtype}") @@ -128,7 +132,7 @@ def define_and_parse_args(): parser.add_argument( "--dtype", default="float16", - choices=["float16", "int8"], + choices=["float16", "int8", "int4"], help="Select the target dtype (default: %(default)s)", ) @@ -144,7 +148,9 @@ def define_and_parse_args(): if __name__ == "__main__": args = define_and_parse_args() - print(f"Profiling {args.model} with context size {args.context_size}") + print( + f"Profiling {args.model} with context size {args.context_size} and dtype {args.dtype}" + ) if args.n_threads: print(f"Setting number of pytorch thread to {args.n_threads}") torch.set_num_threads(args.n_threads) diff --git a/script/profile_matmul.py b/script/profile_matmul.py index 2d0328a..2628e3e 100644 --- a/script/profile_matmul.py +++ b/script/profile_matmul.py @@ -3,8 +3,10 @@ # SPDX-License-Identifier: Apache 2.0 # -from intel_npu_acceleration_library.quantization import quantize_tensor +from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4 +from intel_npu_acceleration_library.dtypes import int4 from intel_npu_acceleration_library.backend import Linear, QLinear +from functools import partial import numpy as np import argparse import torch @@ -22,7 +24,7 @@ def print_profile_data(hwp_data, data): ) -def profile(inC, outC, batch, quantized=False, n_iters=500, skip_first=10): +def profile(inC, outC, batch, dtype, n_iters=500, skip_first=10): data = [] mac = inC * outC * batch memcpy = (inC + outC) * batch @@ -30,13 +32,20 @@ def profile(inC, outC, batch, quantized=False, n_iters=500, skip_first=10): X = np.random.uniform(-1, 1, (batch, inC)).astype(np.float16) W = np.random.uniform(-1, 1, (outC, inC)).astype(np.float16) - if quantized: - matmul_csl = QLinear + if dtype == "float16": + matmul_csl = Linear + args = [W] + elif dtype == "int8": weights, scale = quantize_tensor(torch.tensor(W)) + matmul_csl = partial(QLinear, dtype=np.int8) + args = [weights.numpy(), scale.numpy()] + elif dtype == "int4": + weights, scale = 
quantize_tensor(torch.tensor(W), (int4.min, int4.max)) + weights = compress_to_i4(weights) + matmul_csl = partial(QLinear, dtype=np.uint8) args = [weights.numpy(), scale.numpy()] else: - matmul_csl = Linear - args = [W] + raise RuntimeError(f"Invalid dtype: {dtype}") args.append("0000") @@ -56,7 +65,7 @@ def profile(inC, outC, batch, quantized=False, n_iters=500, skip_first=10): memcpy=memcpy, mac=mac, runtime=hwp_runtime, - dtype=W.dtype, + dtype=dtype, ) for idx in range(n_iters): @@ -94,13 +103,16 @@ def define_and_parse_args(): required=True, help="MatMul output channels", ) - parser.add_argument("--quantize", "-q", action="store_true", help="Quantize") + parser.add_argument( + "--dtype", + default="float16", + choices=["float16", "int8", "int4"], + help="Select the target dtype (default: %(default)s)", + ) return parser.parse_args() if __name__ == "__main__": args = define_and_parse_args() - profile( - args.input_channels, args.output_channels, args.batch, quantized=args.quantize - ) + profile(args.input_channels, args.output_channels, args.batch, dtype=args.dtype) diff --git a/src/bindings.cpp b/src/bindings.cpp index ddb4c46..db61b5b 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -12,6 +12,12 @@ intel_npu_acceleration_library_DLL_API bool isNPUAvailable() { return intel_npu_acceleration_library::_isNPUAvailable(core); } +// ######################## Compression ######################## + +intel_npu_acceleration_library_DLL_API void compressToI4(const int8_t* src, uint8_t* dst, size_t size) { + intel_npu_acceleration_library::compressToI4(src, dst, size); +} + // ######################### Parameters ######################### intel_npu_acceleration_library_DLL_API intel_npu_acceleration_library::Parameters* createParameters() { diff --git a/test/python/test_compile.py b/test/python/test_compile.py index a654336..e00493d 100644 --- a/test/python/test_compile.py +++ b/test/python/test_compile.py @@ -4,6 +4,7 @@ # from intel_npu_acceleration_library.compiler import compile +from intel_npu_acceleration_library.dtypes import int4 from sklearn.metrics import r2_score import intel_npu_acceleration_library from packaging.version import Version @@ -30,7 +31,7 @@ def forward(self, x): x = 128 * (torch.rand((16, 32), dtype=torch.float16) - 0.5) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8, int4]) def test_compilation(dtype): model = NN() @@ -48,7 +49,10 @@ def test_compilation(dtype): else intel_npu_acceleration_library.nn.QuantizedLinear ) assert isinstance(layer, expected_cls) - assert layer.weight.dtype == dtype + if dtype == int4: + assert layer.weight.dtype == torch.uint8 + else: + assert layer.weight.dtype == dtype if layer.bias is not None: if dtype.is_floating_point: assert layer.bias.dtype == dtype @@ -62,7 +66,10 @@ def test_compilation(dtype): y2 = compiled_model(x).detach() t2 = time.perf_counter() - assert 1 - r2_score(y_ref.numpy(), y1.numpy()) < 0.01 + if dtype == int4: + assert 1 - r2_score(y_ref.numpy(), y1.numpy()) < 0.05 + else: + assert 1 - r2_score(y_ref.numpy(), y1.numpy()) < 0.01 assert torch.allclose(y1, y2) @@ -105,7 +112,7 @@ def test_compile_training(dtype): assert layer.training == True -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8, int4]) def test_compile_inference(dtype): model = NN() diff --git a/test/python/test_dtypes.py 
b/test/python/test_dtypes.py new file mode 100644 index 0000000..9a6c5b6 --- /dev/null +++ b/test/python/test_dtypes.py @@ -0,0 +1,20 @@ +# +# Copyright © 2024 Intel Corporation +# SPDX-License-Identifier: Apache 2.0 +# + +import pytest +from intel_npu_acceleration_library.dtypes import float16, bfloat16, int4, int8 + + +@pytest.fixture +def npu_dtypes(): + return [float16, bfloat16, int4, int8] + + +def test_NPUDtype_is_floating_point(npu_dtypes): + for dtype in npu_dtypes: + if dtype in (int4, int8): + assert dtype.is_floating_point == False + else: + assert dtype.is_floating_point == True diff --git a/test/python/test_quantization.py b/test/python/test_quantization.py index 73b5324..af77c70 100644 --- a/test/python/test_quantization.py +++ b/test/python/test_quantization.py @@ -101,10 +101,6 @@ def test_compiled_quantized(batch, inC, outC): @pytest.mark.parametrize("outC", [256, 512]) def test_i4_quantization(batch, inC, outC): - pytest.skip( - "Test is not working until next openvino release 2024.2 since it lacks support for i4 quantization in the inference engine API" - ) - module = intel_npu_acceleration_library.backend.NNFactory(inC, outC, batch) assert module
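End to end, the low-level int4 path exercised by the updated profile_matmul.py script looks roughly like this: quantize into the int4 range, pack two values per byte, then hand the uint8 weights and the scales to a QLinear built with dtype=np.uint8. A sketch following the script's argument pattern (it assumes an NPU is available; the trailing "0000" is the op id string the script appends):

from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
from intel_npu_acceleration_library.backend import QLinear
from intel_npu_acceleration_library.dtypes import int4
import numpy as np
import torch

inC, outC, batch = 256, 512, 16

X = np.random.uniform(-1, 1, (batch, inC)).astype(np.float16)
W = np.random.uniform(-1, 1, (outC, inC)).astype(np.float16)

# Quantize to [-8, 7] and pack two int4 values per byte (uint8 storage).
weights, scale = quantize_tensor(torch.tensor(W), (int4.min, int4.max))
weights = compress_to_i4(weights)

# uint8 weights select the int4 kernel; inC stays the uncompressed channel count.
mm = QLinear(inC, outC, batch, dtype=np.uint8)
args = [weights.numpy(), scale.numpy(), "0000"]
result = mm.run(X, *args)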