[3x] support automatic host2device on RTN and GPTQ #1894

Merged: 1 commit merged on Jul 3, 2024
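This PR lets RTN and GPTQ quantize a model that lives on the CPU without the caller moving it to the accelerator first: modules are moved host-to-device internally, and the quantized model is handed back on its original device. Below is a minimal usage sketch of the 3x torch API, mirroring the tests added in this PR; the tiny checkpoint name and the random calibration input are placeholders, not part of the change.

```python
import torch
from transformers import AutoModelForCausalLM

from neural_compressor.torch.quantization import convert, get_default_rtn_config, prepare

# Model stays on the CPU; the quantizer handles host2device movement internally
# and returns the quantized model on the CPU again.
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"  # placeholder tiny model
)
example_inputs = torch.randint(0, 100, (1, 16))  # placeholder input ids

model = prepare(model, get_default_rtn_config())
model = convert(model)  # RTN needs no calibration run between prepare and convert

print(model(example_inputs)[0].device)  # device(type='cpu'), same as before quantization
```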
10 changes: 9 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -27,7 +27,13 @@
import torch.nn as nn
from tqdm import tqdm

-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
+from neural_compressor.torch.utils import (
+get_accelerator,
+get_model_device,
+is_transformers_imported,
+logger,
+set_module,
+)
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .modules import WeightOnlyLinear
@@ -995,6 +1001,7 @@ def prepare(
if use_layer_wise: # pragma: no cover
assert model_path is not None, "model_path should not be None when use layer wise mode"

+self.model_device = get_model_device(model)  # return model on the same device
self.gptq_quantizer = RAWGPTQuantizer(
model,
weight_config=self.quant_config,
@@ -1013,6 +1020,7 @@ def convert(self, model, *args, **kwargs):
self.gptq_quantizer.model = model
self.gptq_quantizer.remove_prepare_for_calibration()
q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+q_model = q_model.to(self.model_device)
q_model.gptq_config = gptq_config
logger.info("GPTQ quantizing done.")
return q_model
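The gptq.py change is a record-and-restore pattern: `prepare` remembers which device the incoming model lives on via `get_model_device`, and `convert` moves the quantized model back there after `execute_quantization` has run on the accelerator. A stripped-down sketch of that pattern follows; the quantizer class is hypothetical and stands in for the real RAWGPTQuantizer wiring.

```python
import torch


def first_param_device(model: torch.nn.Module) -> str:
    # Same heuristic as get_model_device below: device type of the first parameter.
    for _, p in model.named_parameters():
        return p.data.device.type
    return "cpu"  # assumption made here for parameter-free models


class RoundTripQuantizer:
    """Illustrative only: record the host device in prepare(), restore it in convert()."""

    def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
        self.model_device = first_param_device(model)  # e.g. "cpu"
        # ... attach calibration hooks, move blocks to the accelerator as needed ...
        return model

    def convert(self, model: torch.nn.Module) -> torch.nn.Module:
        q_model = model  # ... run the actual weight-only quantization here ...
        return q_model.to(self.model_device)  # hand the model back where it came from
```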
12 changes: 6 additions & 6 deletions neural_compressor/torch/algorithms/weight_only/modules.py
@@ -270,8 +270,8 @@ def recover(self):

def pack_tensor_with_torch(self, raw_tensor):
target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
-packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
-mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(raw_tensor.device)
+mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(raw_tensor.device)
for j in range(packed_tensor.shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
@@ -286,8 +286,8 @@ def pack_tensor_with_torch(self, raw_tensor):
def unpack_tensor_with_torch(self, packed_tensor):
target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
target_len = packed_tensor.shape[1] * self.n_pack
-unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
-mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(packed_tensor.device)
+mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(packed_tensor.device)
for j in range(packed_tensor.shape[1]):
for e in range(self.n_pack):
index = j * self.n_pack + e
@@ -338,13 +338,13 @@ def unpack_tensor_with_numpy(self, packed_tensor):
return unpacked_tensor

def pack_tensor(self, raw_tensor):
if "cuda" in self.device:
if "cuda" in raw_tensor.device.type:
return self.pack_tensor_with_torch(raw_tensor)
else:
return self.pack_tensor_with_numpy(raw_tensor)

def unpack_tensor(self, packed_tensor):
if "cuda" in self.device:
if "cuda" in packed_tensor.device.type:
return self.unpack_tensor_with_torch(packed_tensor)
else:
return self.unpack_tensor_with_numpy(packed_tensor)
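In modules.py, packing and unpacking now dispatch on the device of the tensor being processed instead of the layer's stored `self.device`, so a CPU tensor always takes the numpy path and a CUDA tensor stays on the GPU with the torch path. A simplified sketch of that dispatch for 4-bit values follows; it assumes a 2-D tensor with an even number of columns and is not the WeightOnlyLinear implementation.

```python
import numpy as np
import torch


def pack_int4(raw_tensor: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit values into uint8, keyed on raw_tensor.device like pack_tensor."""
    if "cuda" in raw_tensor.device.type:
        # Torch path: the packed result stays on the same CUDA device.
        t = raw_tensor.to(torch.uint8) & 0xF
        return t[:, 0::2] | (t[:, 1::2] << 4)
    # Numpy path for CPU tensors.
    a = raw_tensor.detach().numpy().astype(np.uint8) & 0xF
    return torch.from_numpy(a[:, 0::2] | (a[:, 1::2] << 4))


# Either call works regardless of what device the layer itself was configured for:
# pack_int4(torch.randint(0, 16, (4, 8)))                 # CPU tensor -> numpy path
# pack_int4(torch.randint(0, 16, (4, 8), device="cuda"))  # CUDA tensor -> torch path
```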
12 changes: 8 additions & 4 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -28,6 +28,7 @@
from neural_compressor.torch.utils import (
get_accelerator,
get_attr,
+get_model_device,
is_transformers_imported,
logger,
set_attr,
@@ -99,10 +100,7 @@ def convert(
"""
weight_config = self.quant_config
device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
-
-# Put model on device explicitly
-# TODO: refine it later, Put module on device one by one instead of the whole model
-model.to(device)
+model_device = get_model_device(model)  # return model on the same device

# for transformers model. If lm_head is tied from embedding, we deepcopy it.
if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False):
@@ -132,6 +130,8 @@
dtype = weight_config[name].get("dtype", "int")
if dtype == "fp32":
continue
+# Move modules to the accelerator device layer-by-layer
+m.to(device)
### FP8 cast part
if dtype in ["fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"]:
logger.debug("Cast module {} to FP8 using qdq mode, no scaling".format(name))
@@ -223,4 +223,8 @@
return new_module
else:
set_module(model, name, new_module)
+# Move modules back to the model device layer-by-layer
+m.to(model_device)
+new_module.to(model_device)
+model.to(model_device)
return model
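The rtn.py change replaces the up-front `model.to(device)` with per-module movement: each module is pushed to the accelerator right before it is quantized and pulled back to the original model device afterwards, so only one layer occupies accelerator memory at a time. A minimal sketch of that loop follows; `quantize_module` is a hypothetical callable standing in for the RTN weight-only math, and module swapping into the parent is elided.

```python
import torch


def rtn_like_layerwise(model: torch.nn.Module, device: str, quantize_module) -> torch.nn.Module:
    """Layer-by-layer host2device sketch: the whole model never has to fit on `device`."""
    model_device = next(model.parameters()).device  # device the caller handed us
    for name, m in model.named_modules():
        if not isinstance(m, torch.nn.Linear):
            continue
        m.to(device)                     # host -> accelerator for this module only
        new_module = quantize_module(m)  # placeholder for the actual RTN quantization
        m.to(model_device)               # move the original module back
        new_module.to(model_device)      # quantized replacement lives on the host again
        # a full implementation would also swap new_module into its parent module here
    return model.to(model_device)
```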
13 changes: 13 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -265,3 +265,16 @@ def dump_model_op_stats(mode, tune_cfg):
output_data.append(field_results)

Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
+
+
+def get_model_device(model: torch.nn.Module):
+"""Get the device.
+
+Args:
+model (torch.nn.Module): the input model.
+
+Returns:
+device (str): a string.
+"""
+for n, p in model.named_parameters():
+return p.data.device.type # p.data.device == device(type='cpu')
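`get_model_device` simply reports the device type of the first named parameter, which is enough for the single-device models these flows handle; a model sharded across devices would only report its first shard. A quick usage check, using the same import path the gptq.py and rtn.py hunks above rely on:

```python
import torch

from neural_compressor.torch.utils import get_model_device

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
assert get_model_device(model) == "cpu"

if torch.cuda.is_available():
    assert get_model_device(model.to("cuda")) == "cuda"
```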
18 changes: 16 additions & 2 deletions test/3x/torch/quantization/weight_only/test_gptq.py
@@ -36,6 +36,20 @@ def setup_class(self):
def teardown_class(self):
shutil.rmtree("saved_results", ignore_errors=True)

+@pytest.mark.skipif(device == "cpu", reason="no available accelerator")
+def test_auto_host2device(self):
+# if model is on CPU, we move it to device layer-by-layer for acceleration,
+# and then move it back to CPU after quantization.
+model = copy.deepcopy(self.tiny_gptj).to("cpu")
+example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
+quant_config = get_default_gptq_config()
+model = prepare(model, quant_config)
+run_fn(model)
+model = convert(model)
+gptq_label = model(example_inputs)[0]
+gptq_atol = (gptq_label - self.label.to("cpu")).amax()
+assert gptq_atol < 0.06, "GPTQ should have low atol."
+
def test_accuracy_improvement(self):
# test_default_rtn_config
model = copy.deepcopy(self.tiny_gptj)
@@ -215,9 +229,9 @@ def test_conv1d(self):
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
-model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
+model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2").to(device)
text = "Replace me by any text you'd like."
-encoded_input = tokenizer(text, return_tensors="pt")
+encoded_input = tokenizer(text, return_tensors="pt").to(device)

def run_fn_conv1d(model):
model(**encoded_input)
13 changes: 13 additions & 0 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -352,3 +352,16 @@ def mock_is_transformers_imported():
model = convert(model)
out = model(self.example_inputs)[0]
assert torch.allclose(out, self.label, atol=1e-1), "Accuracy gap atol > 0.1 is unexpected."
+
+@pytest.mark.skipif(device == "cpu", reason="no available accelerator")
+def test_auto_host2device(self):
+# if model is on CPU, we move it to device layer-by-layer for acceleration,
+# and then move it back to CPU after quantization.
+model = copy.deepcopy(self.tiny_gptj).to("cpu")
+example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
+quant_config = get_default_rtn_config()
+model = prepare(model, quant_config)
+model = convert(model)
+rtn_label = model(example_inputs)[0]
+rtn_atol = (rtn_label - self.label.to("cpu")).amax()
+assert rtn_atol < 0.08, "RTN should have low atol."