Add int4 support (intel#32)
* Add int4 support

* Fix dtypes

* Add dtypes test

* Add dtype to library

* Faster i8 to i4 compression

* hotfix

* Update the profile-llm script

* Add library

* fix script

* Update readme

* Add neural compressor and demo

* Use neural compressor as the default method

* hotfix

* Quantize only quantized models

* Add tests

* fix issue intel#27
alessandropalla authored May 29, 2024
1 parent 5294a5c commit b34d859
Showing 22 changed files with 422 additions and 63 deletions.
22 changes: 18 additions & 4 deletions CMakeLists.txt
@@ -37,20 +37,33 @@ function(get_linux_lsb_release_information)
set(LSB_RELEASE_VERSION "${LSB_RELEASE_VERSION}" PARENT_SCOPE)
endfunction()

set(OV_VERSION_SHORT "2024.1")
set(OV_VERSION "2024.1.0.15008.f4afc983258_x86_64")
set(OV_VERSION_SHORT "nightly")
set(OV_VERSION "2024.3.0.dev20240524_x86_64")
set(OV_STORAGE_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages")
set(OV_NIGHTLY_COMMIT "2024.3.0-15502-66093834e38")

if (WIN32)
if(NOT OV_LIBRARY_URL)
set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/${OV_VERSION_SHORT}/windows/w_openvino_toolkit_windows_${OV_VERSION}.zip")
if (${OV_VERSION_SHORT} STREQUAL "nightly")
set(OV_PLATFORM "${OV_NIGHTLY_COMMIT}")
else()
set(OV_PLATFORM "windows")
endif()
set(OV_LIBRARY_URL "${OV_STORAGE_URL}/${OV_VERSION_SHORT}/${OV_PLATFORM}/w_openvino_toolkit_windows_${OV_VERSION}.zip")
endif()
elseif(UNIX)
if(NOT OV_LIBRARY_URL)
get_linux_lsb_release_information()
if (LSB_RELEASE_ID STREQUAL "Ubuntu")
if (${LSB_RELEASE_VERSION} STREQUAL "18.04" OR ${LSB_RELEASE_VERSION} STREQUAL "20.04" OR ${LSB_RELEASE_VERSION} STREQUAL "22.04")
string(REPLACE ".04" "" LSB_RELEASE_VERSION_SHORT ${LSB_RELEASE_VERSION})
set(OV_LIBRARY_URL "https://storage.openvinotoolkit.org/repositories/openvino/packages/${OV_VERSION_SHORT}/linux/l_openvino_toolkit_ubuntu${LSB_RELEASE_VERSION_SHORT}_${OV_VERSION}.tgz")
if (${OV_VERSION_SHORT} STREQUAL "nightly")
set(OV_PLATFORM "${OV_NIGHTLY_COMMIT}")
else()
set(OV_PLATFORM "linux")
endif()

set(OV_LIBRARY_URL "${OV_STORAGE_URL}/${OV_VERSION_SHORT}/${OV_PLATFORM}/l_openvino_toolkit_ubuntu${LSB_RELEASE_VERSION_SHORT}_${OV_VERSION}.tgz")
else()
message(FATAL_ERROR "Ubuntu version ${LSB_RELEASE_VERSION} is unsupported")
endif()
@@ -63,6 +76,7 @@ else()
message(FATAL_ERROR "Unsupported architecture")
endif ()

message(STATUS "OpenVINO library URL: ${OV_LIBRARY_URL}")

FetchContent_Declare(
openvino
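Editor's note: for illustration only, here is a small Python sketch (not part of the commit) of the download-URL selection that the CMake change implements: nightly archives are resolved under the pinned nightly commit folder, release archives under a per-OS folder. The helper name `openvino_url` and the `ubuntu22` filename default are assumptions made for this sketch.

```python
# Hedged sketch of the OV_LIBRARY_URL construction above; openvino_url and the
# ubuntu22 default are illustrative, not part of the library or its build files.
OV_STORAGE_URL = "https://storage.openvinotoolkit.org/repositories/openvino/packages"
OV_NIGHTLY_COMMIT = "2024.3.0-15502-66093834e38"


def openvino_url(version_short: str, version: str, target_os: str) -> str:
    # Nightly packages live under the pinned commit folder; releases under the OS folder.
    platform = OV_NIGHTLY_COMMIT if version_short == "nightly" else target_os
    if target_os == "windows":
        return f"{OV_STORAGE_URL}/{version_short}/{platform}/w_openvino_toolkit_windows_{version}.zip"
    return f"{OV_STORAGE_URL}/{version_short}/{platform}/l_openvino_toolkit_ubuntu22_{version}.tgz"


print(openvino_url("nightly", "2024.3.0.dev20240524_x86_64", "windows"))
```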
4 changes: 2 additions & 2 deletions README.md
@@ -25,8 +25,8 @@ Some useful links
In our quest to significantly improve the library's performance, we are directing our efforts toward implementing a range of key features, including:

- [x] **8-bit quantization**
- [ ] **4-bit Quantization and GPTQ**
- [ ] **NPU-Native mixed precision inference**
- [x] **4-bit Quantization and GPTQ**
- [x] **NPU-Native mixed precision inference**
- [x] **Float16 support**
- [ ] **BFloat16 (Brain Floating Point Format)**
- [x] **`torch.compile` support**
5 changes: 2 additions & 3 deletions examples/phi-2.py
@@ -7,16 +7,15 @@
from langchain.chains import LLMChain
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
import intel_npu_acceleration_library
import torch
import intel_npu_acceleration_library as npu_lib

model_id = "microsoft/Phi-2"

model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

npu_model = intel_npu_acceleration_library.compile(model, dtype=torch.float16)
npu_model = npu_lib.compile(model, dtype=npu_lib.int4)

pipe = pipeline(
"text-generation",
50 changes: 50 additions & 0 deletions examples/phi-3-nc.py
@@ -0,0 +1,50 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
import intel_npu_acceleration_library as npu_lib
import warnings

torch.random.manual_seed(0)

model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
torch_dtype="auto",
trust_remote_code=True,
)

model = npu_lib.compile(model, dtype=npu_lib.int4)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
streamer = TextStreamer(tokenizer, skip_prompt=True)

messages = [
{
"role": "system",
"content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
},
{
"role": "user",
"content": "Can you provide ways to eat combinations of bananas and dragonfruits?",
},
]

pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)

generation_args = {
"max_new_tokens": 500,
"return_full_text": False,
"temperature": 0.0,
"do_sample": False,
"streamer": streamer,
}

with warnings.catch_warnings():
warnings.simplefilter("ignore")
pipe(messages, **generation_args)
13 changes: 13 additions & 0 deletions include/intel_npu_acceleration_library/conversion.h
@@ -13,6 +13,19 @@

namespace intel_npu_acceleration_library {

/**
* @brief Compress a int8 vector to I4 format.
*
* @param src pointer to the source int8 buffer
* @param dst pointer to the destination uint8 buffer
* @param size size of the src and dst buffers
*/
void compressToI4(const int8_t* src, uint8_t* dst, size_t size) {
for (size_t i = 0; i < size / 2; i++) {
dst[i] = (src[2 * i] & 0x0F) | ((src[2 * i + 1] & 0x0F) << 4);
}
}

/**
* @brief Convert a int8 vector to fp16 given a scalar scale.
*
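Editor's note: the packing scheme used by `compressToI4` above can be reproduced in NumPy, which may help when inspecting compressed weights: each output byte holds element 2*i in its low nibble and element 2*i + 1 in its high nibble. The function below is a reference sketch, not a library API.

```python
# Reference NumPy sketch of the nibble packing performed by compressToI4;
# pack_i4_reference is illustrative and not a library function.
import numpy as np


def pack_i4_reference(src: np.ndarray) -> np.ndarray:
    # src holds int4 values stored as int8 in [-8, 7]; its length must be even.
    assert src.dtype == np.int8 and src.size % 2 == 0
    low = src[0::2].astype(np.uint8) & 0x0F    # element 2*i   -> low nibble
    high = src[1::2].astype(np.uint8) & 0x0F   # element 2*i+1 -> high nibble
    return (low | (high << 4)).astype(np.uint8)


print(pack_i4_reference(np.array([-8, 7, 1, -1], dtype=np.int8)))  # [120 241] == [0x78, 0xF1]
```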
3 changes: 2 additions & 1 deletion intel_npu_acceleration_library/__init__.py
@@ -4,6 +4,7 @@
#

from .compiler import compile
from .dtypes import int4, int8, float16


__all__ = ["compile"]
__all__ = ["compile", "int4", "int8", "float16"]
2 changes: 2 additions & 0 deletions intel_npu_acceleration_library/backend/bindings.py
@@ -79,6 +79,8 @@ def init_common(lib: ctypes.CDLL):

lib.isNPUAvailable.restype = ctypes.c_bool

lib.compressToI4.argtypes = [c_i8_array, c_u8_array, ctypes.c_int]


def init_network_factory(lib: ctypes.CDLL):
"""Initialize Netowrk factory bindings.
24 changes: 24 additions & 0 deletions intel_npu_acceleration_library/backend/compression.py
@@ -0,0 +1,24 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
import numpy as np


def compress_to_i4(weights: np.ndarray) -> np.ndarray:
"""Compress a int8 array to int4.
Args:
weights (np.ndarray): input array
Returns:
np.ndarray: compressed array
"""
compressed_weights = np.zeros(
(weights.shape[0], weights.shape[1] // 2), dtype=np.uint8
)

backend_lib.compressToI4(weights, compressed_weights, np.prod(weights.shape))
return compressed_weights
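Editor's note: a possible usage of the new `compress_to_i4` wrapper, assuming the native backend library is built and loadable on the machine; shapes are illustrative. The second dimension is halved because two int4 values are packed into each byte.

```python
# Hedged usage sketch for compress_to_i4; assumes the compiled
# intel_npu_acceleration_library backend can be loaded on this machine.
import numpy as np
from intel_npu_acceleration_library.backend.compression import compress_to_i4

weights = np.random.randint(-8, 8, size=(128, 512), dtype=np.int8)  # int4 range stored as int8
packed = compress_to_i4(weights)
print(weights.shape, packed.shape, packed.dtype)  # (128, 512) (128, 256) uint8
```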
4 changes: 3 additions & 1 deletion intel_npu_acceleration_library/backend/qlinear.py
@@ -17,6 +17,7 @@ def __init__(
batch: int,
profile: bool = False,
device: str = "NPU",
dtype: np.dtype = np.int8,
):
"""Initialize the QLinear class.
@@ -26,6 +27,7 @@
batch (int): batch
profile (bool): Enable/Disable profiling. Defaults to False.
device (str): Target device, default to "NPU".
dtype (np.dtype): weights datatype. Defaults to np.int8.
Raises:
RuntimeError: Quantized matmul requires input_channel to be a multiple of 8
@@ -35,7 +37,7 @@
raise RuntimeError(
"Quantized matmul requires input_channel to be a multiple of 8"
)
out = self.linear(self.input, outC, inC, bias=False, wt_dtype=np.int8)
out = self.linear(self.input, outC, inC, bias=False, wt_dtype=dtype)
self.compile(out)

def run(
4 changes: 3 additions & 1 deletion intel_npu_acceleration_library/backend/qmatmul.py
@@ -17,6 +17,7 @@ def __init__(
batch: int,
profile: bool = False,
device: str = "NPU",
dtype: np.dtype = np.int8,
):
"""Initialize the QMatmul class.
@@ -26,9 +27,10 @@
batch (int): batch
profile (bool): Enable/Disable profiling. Defaults to False.
device (str): Target device, default to "NPU".
dtype (np.dtype): weights datatype. Defaults to np.int8.
"""
super().__init__(inC, outC, batch, profile, device)
out = self.linear(self.input, outC, inC, bias=False, wt_dtype=np.int8)
out = self.linear(self.input, outC, inC, bias=False, wt_dtype=dtype)
self.compile(out)

def run(self, X: np.ndarray, W: np.ndarray, scale: np.ndarray) -> np.ndarray:
18 changes: 14 additions & 4 deletions intel_npu_acceleration_library/backend/runtime.py
@@ -8,6 +8,7 @@
from intel_npu_acceleration_library.backend import NNFactory
from torch.profiler import record_function
from typing import Optional, List, Any, Dict, Deque
from functools import partial
from collections import deque
import numpy as np
import torch
@@ -46,18 +47,27 @@ def run_matmul(

outC, inC = weights.shape[-2:]

if weights.dtype == torch.uint8:
# In case is Int4 we need to double the input channels because weights are compressed
inC *= 2

# Set tensors as contiguous in memory
x = set_contiguous(x)
weights = set_contiguous(weights)
weights = weights.view([-1, weights.shape[-1]])

if weights.dtype.is_floating_point:
op_class = Linear if op_id is not None else MatMul
op_class_name = op_class.__name__
create_op = partial(op_class)
op_args = [weights.to(torch.float16).numpy()]
elif weights.dtype == torch.int8:
elif weights.dtype in (torch.int8, torch.uint8):
if scale is None:
raise RuntimeError("Quantized weights require a not null scale")
op_class = QLinear if op_id is not None else QMatMul
op_class_name = op_class.__name__
np_dtype = np.int8 if weights.dtype == torch.int8 else np.uint8
create_op = partial(op_class, dtype=np_dtype)
if scale is None:
raise RuntimeError(
f"Quantized matmul (weights dtype == {weights.dtype}) requires scale (scale = {scale})"
@@ -90,13 +100,13 @@ def run_matmul(
else:
batch = real_batch

key = f"{str(op_class.__name__)}_{batch}_{inC}_x_{outC}_{inC}_{x_np.dtype}"
key = f"{str(op_class_name)}_{batch}_{inC}_x_{outC}_{inC}_{x_np.dtype}"
models = _model_cache.get(key, None)

if models is None:
_model_cache[key] = deque([op_class(inC, outC, batch)])
_model_cache[key] = deque([create_op(inC, outC, batch)])
elif len(models) < 1:
_model_cache[key].append(op_class(inC, outC, batch))
_model_cache[key].append(create_op(inC, outC, batch))
else:
_model_cache[key].rotate(1)

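Editor's note: the runtime change binds the weight dtype with `functools.partial` so the cached op factories can still be constructed with just `(inC, outC, batch)` once the shapes are known. Below is a stand-alone sketch of that pattern using a stand-in class, not the library's QLinear/QMatMul.

```python
# Stand-alone sketch of the functools.partial pattern used in run_matmul;
# FakeQMatMul is a stand-in, not the library's QMatMul.
from functools import partial


class FakeQMatMul:
    def __init__(self, inC: int, outC: int, batch: int, dtype: str = "int8"):
        self.inC, self.outC, self.batch, self.dtype = inC, outC, batch, dtype


# Bind the dtype now, supply the shapes later when the cached op is built.
create_op = partial(FakeQMatMul, dtype="uint8")
op = create_op(512, 256, 1)
print(op.inC, op.outC, op.batch, op.dtype)  # 512 256 1 uint8
```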
37 changes: 24 additions & 13 deletions intel_npu_acceleration_library/compiler.py
@@ -4,20 +4,14 @@
#

from intel_npu_acceleration_library.optimizations import horizontal_fusion_linear
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
from transformers.models.gemma.modeling_gemma import GemmaMLP, GemmaAttention
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
from intel_npu_acceleration_library.quantization import quantize_model
from intel_npu_acceleration_library.dtypes import int8, int4
import intel_npu_acceleration_library.nn as nn
from torch._dynamo import register_backend
from typing import Union, Callable, Any

try:
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
from transformers.models.gemma.modeling_gemma import GemmaMLP, GemmaAttention

is_transformers_available = True
except ModuleNotFoundError:
# Transformer library is not installed
is_transformers_available = False


import intel_npu_acceleration_library.nn as nn
from typing import List
import torch

@@ -38,7 +32,7 @@ def compile(
Returns:
torch.nn.Module: compiled NPU nn.Module
"""
if not (dtype.is_floating_point or dtype == torch.int8):
if not (dtype.is_floating_point or dtype in (int8, int4)):
raise RuntimeError(
f"intel-npu-acceleration-library library do not support yet the requeste datatype: {dtype}"
)
@@ -48,6 +42,9 @@
# General optimizations
apply_horizontal_fusion(model)
optimize_llama_attention(model, dtype)
if dtype in (int8, int4):
# Quantize model
model = quantize_model(model, dtype)

# Model lowering to NPU ops
lower_linear(model, dtype)
@@ -102,13 +99,27 @@
layer (torch.nn.Module): Original torch.nn.Linear module
dtype (torch.dtype): Target datatype
Raises:
RuntimeError: unsupported quantization bits
Returns:
Union[torch.nn.Module, None]: Return the new NPU operator or None
"""
if isinstance(layer, torch.nn.Linear):
return nn.Linear.fromTorch(layer, dtype)
if isinstance(layer, torch.nn.Conv2d):
return nn.Conv2d.fromTorch(layer, dtype)
if isinstance(layer, WeightOnlyLinear):
if layer.bits == 4:
return nn.QuantizedLinear(
layer.qweight.to(torch.uint8), layer.scales, layer.bias
)
elif layer.bits == 8:
return nn.QuantizedLinear(
layer.qweight.view(torch.int8), layer.scales, layer.bias
)
else:
raise RuntimeError(f"Unsupported quantization bits: {layer.bits}")
return None


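Editor's note: tying the pieces together, a minimal end-to-end sketch of the new int4 path exposed by `compile()`. It assumes the library and its neural-compressor dependency are installed and an Intel NPU driver is present; the toy model is illustrative.

```python
# Minimal sketch of compiling a toy model with the new int4 dtype; assumes
# intel_npu_acceleration_library (with neural-compressor) is installed and an
# Intel NPU is available on the machine.
import torch
import intel_npu_acceleration_library as npu_lib

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 128),
).eval()

npu_model = npu_lib.compile(model, dtype=npu_lib.int4)

with torch.no_grad():
    out = npu_model(torch.rand(1, 512))
print(out.shape)  # torch.Size([1, 128])
```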
(Diffs for the remaining changed files are not shown here.)
