diff --git a/examples/phi-3.py b/examples/phi-3.py
index f59d0fb..51a368c 100644
--- a/examples/phi-3.py
+++ b/examples/phi-3.py
@@ -4,7 +4,7 @@
#
import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
+from transformers import AutoTokenizer, pipeline, TextStreamer
import intel_npu_acceleration_library as npu_lib
import warnings
@@ -13,7 +13,6 @@
model = npu_lib.NPUModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
torch_dtype="auto",
- trust_remote_code=True,
dtype=npu_lib.int4,
)
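A minimal sketch of how the updated example is expected to run end to end after this change. Only the from_pretrained call above is part of the diff; the tokenizer, streamer, and pipeline wiring below is an assumption that follows the stock transformers text-generation pattern.

    from transformers import AutoTokenizer, pipeline, TextStreamer
    import intel_npu_acceleration_library as npu_lib

    # Load Phi-3 through the library wrapper without trust_remote_code, so the
    # export path guarded in modelling.py (further down in this patch) stays usable.
    model = npu_lib.NPUModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct",
        torch_dtype="auto",
        dtype=npu_lib.int4,
    )
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    streamer = TextStreamer(tokenizer, skip_prompt=True)

    # Extra kwargs on a text-generation pipeline call are forwarded to generate().
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    pipe("What is an NPU?", max_new_tokens=64, streamer=streamer)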
diff --git a/intel_npu_acceleration_library/backend/qlinear.py b/intel_npu_acceleration_library/backend/qlinear.py
index 8093ecd..d5fcd25 100644
--- a/intel_npu_acceleration_library/backend/qlinear.py
+++ b/intel_npu_acceleration_library/backend/qlinear.py
@@ -29,16 +29,11 @@ def __init__(
device (str): Target device, default to "NPU".
dtype (np.dtype): weights datatype. Defaults to np.int8.
- Raises:
- RuntimeError: Quantized matmul requires input_channel to be a multiple of 8
"""
super().__init__(profile, device)
self.inC, self.outC = inC, outC
self.batch = batch
- if inC % 8 != 0:
- raise RuntimeError(
- "Quantized matmul requires input_channel to be a multiple of 8"
- )
+
input = self.parameter((self.batch, self.inC))
out = self.linear(input, outC, inC, bias=False, wt_dtype=dtype)
self.compile(out)
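A hedged sketch of what dropping the guard permits: the quantized matmul can now be built for input channel counts that are not multiples of 8. The keyword names (inC, outC, batch, dtype) are taken from the hunk above, and the constructor is assumed to belong to the QLinear class defined in this module; whether such shapes compile efficiently is left to the NPU backend.

    import numpy as np
    from intel_npu_acceleration_library.backend.qlinear import QLinear

    # 1000 input channels is not divisible by 8; previously the constructor
    # raised RuntimeError, now the graph goes straight to compile().
    op = QLinear(inC=1000, outC=256, batch=16, dtype=np.int8)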
diff --git a/intel_npu_acceleration_library/modelling.py b/intel_npu_acceleration_library/modelling.py
index 22c8d4d..e9510ec 100644
--- a/intel_npu_acceleration_library/modelling.py
+++ b/intel_npu_acceleration_library/modelling.py
@@ -82,6 +82,7 @@ def from_pretrained(
Raises:
RuntimeError: Invalid class
+ AttributeError: Cannot export model with trust_remote_code=True
Returns:
torch.nn.Module: compiled mode
@@ -103,6 +104,10 @@ def from_pretrained(
)
model = npu_lib.compile(model, dtype, training)
if export:
+ if kwargs.get("trust_remote_code", False):
+ raise AttributeError(
+ "Cannot export model with trust_remote_code=True. Please set trust_remote_code=False or export=False"
+ )
print(f"Exporting model {model_name_or_path} to {model_dir_path}")
os.makedirs(model_dir_path, exist_ok=True)
torch.save(model, model_path)
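From the caller's side, the new guard means a model that needs trust_remote_code has to opt out of the on-disk export, presumably because torch.save cannot round-trip remotely defined classes. A rough sketch with a hypothetical model id, assuming export is accepted as a keyword by from_pretrained as the hunk suggests:

    import intel_npu_acceleration_library as npu_lib

    # This combination now fails fast with AttributeError:
    #   npu_lib.NPUModelForCausalLM.from_pretrained(
    #       "org/custom-remote-model", trust_remote_code=True, export=True
    #   )

    # Either drop trust_remote_code (as the phi-3 example now does) or skip the export:
    model = npu_lib.NPUModelForCausalLM.from_pretrained(
        "org/custom-remote-model",
        dtype=npu_lib.int4,
        trust_remote_code=True,
        export=False,
    )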
diff --git a/intel_npu_acceleration_library/nn/linear.py b/intel_npu_acceleration_library/nn/linear.py
index f29a108..8d7acbb 100644
--- a/intel_npu_acceleration_library/nn/linear.py
+++ b/intel_npu_acceleration_library/nn/linear.py
@@ -80,7 +80,7 @@ def fromTensor(
dtype (torch.dtype): the desired datatype
Raises:
- RuntimeError: Quantized Linear requires input_channel to be a multiple of 8
+ RuntimeError: dtype not supported
Returns:
Union[Linear, QuantizedLinear]: A NPU linear layer
@@ -95,10 +95,6 @@ def fromTensor(
weights_quant = compress_to_i4(weights_quant)
return QuantizedLinear(weights_quant, scale, bias)
elif dtype == torch.int8:
- if weight.shape[-1] % 8 != 0:
- raise RuntimeError(
- "Quantized Linear requires input_channel to be a multiple of 8"
- )
weights_quant, scale = quantize_tensor(weight)
return QuantizedLinear(weights_quant, scale, bias)
else:
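The matching relaxation on the torch side: for int8, fromTensor now quantizes the weights regardless of the input channel count instead of raising. A small sketch, assuming fromTensor hangs off the Linear class in this module and takes (weight, bias, dtype) as shown in the hunk:

    import torch
    from intel_npu_acceleration_library.nn.linear import Linear

    # 256 output channels x 1000 input channels; 1000 % 8 != 0 used to raise.
    weight = torch.randn(256, 1000)
    layer = Linear.fromTensor(weight, bias=None, dtype=torch.int8)
    # Returns a QuantizedLinear wrapping the int8 weights and their scale.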