Commit 4b9b447

support RTN on Gaudi2 and make UTs auto detect device (#1811)
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent 8dac9f2 commit 4b9b447

15 files changed: 165 additions, 132 deletions
neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 
 from neural_compressor.common import logger
 from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
-from neural_compressor.torch.utils import get_device
+from neural_compressor.torch.utils import get_accelerator
 
 
 class HalfPrecisionConverter:
@@ -40,7 +40,7 @@ def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs):
             configs_mapping (Dict): config class for mix-precision.
         """
         self.configs_mapping = configs_mapping
-        self.device = get_device()
+        self.device = get_accelerator().current_device_name()
 
     def convert(self, model: torch.nn.Module):
         """Convert to FP16 or BF16 model.

neural_compressor/torch/algorithms/weight_only/awq.py

Lines changed: 4 additions & 2 deletions
@@ -20,7 +20,7 @@
 import torch
 
 from neural_compressor.torch.algorithms import Quantizer
-from neural_compressor.torch.utils import get_device, logger
+from neural_compressor.torch.utils import get_accelerator, logger
 
 from .modules import MulLinear
 from .utility import (
@@ -124,13 +124,15 @@ def __init__(
         weight_config={},
         total_block_args=[],
         total_block_kwargs=[],
+        device="auto",
     ):
 
         self.example_inputs = example_inputs
         self.model = model
         if example_inputs is None:
             assert dataloader is not None, "datalaoder or example_inputs is required."
             self.example_inputs = get_example_input(dataloader)
+        self.device = device
         self._move_model_and_data_to_device()
         self.total_block_args = total_block_args
         self.total_block_kwargs = total_block_kwargs
@@ -146,7 +148,7 @@ def __init__(
 
     def _move_model_and_data_to_device(self):
         # Put the model and example_inputs into target device
-        device = get_device()
+        device = get_accelerator(self.device).current_device_name()
         self.model.to(device)
         self.example_inputs = self.example_inputs.to(device)
 
neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 import torch.nn as nn
 from tqdm import tqdm
 
-from neural_compressor.torch.utils import fetch_module, get_device, is_transformers_imported, logger, set_module
+from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .modules import WeightOnlyLinear
@@ -258,7 +258,7 @@ def __init__(
         self.check_layer_config()
 
         # device
-        self.device = get_device(kwargs.pop("device", "auto"))
+        self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
         self.model.to(self.device)
         self.is_ready = False
 
neural_compressor/torch/algorithms/weight_only/modules.py

Lines changed: 61 additions & 74 deletions
@@ -23,7 +23,7 @@
 from torch.autograd import Function
 from torch.nn import functional as F
 
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import accelerator, logger
 
 from .utility import quant_tensor
 
@@ -174,9 +174,9 @@ def __init__(
 
     def pack(self, int_weight, scale, zp, bias, g_idx=None):
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -197,124 +197,111 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
         assert scale.shape == self.scales.shape, f"{scale.shape} != {self.scales.shape} Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
+            int_weight = int_weight.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
 
         # pack weight
-        for j in range(target_shape[1]):
-            start = self.n_pack * j
-            end = self.n_pack * (j + 1)
-            tmp = int_weight[:, start:end].type(self.compression_dtype)
-            for e in range(tmp.shape[1]):
-                tmp[:, e] &= mask
-                tmp[:, e] = tmp[:, e] << (self.bits * e)
-                self.qweight[:, j] |= tmp[:, e]
+        self.qweight.copy_(self.pack_tensor(int_weight))
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.t_().contiguous()
+            self.qweight = self.qweight.T.contiguous()
 
         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                self.qzeros = self.qzeros.t_().contiguous()
+                zp = zp.T.contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
-            target_shape = self.qzeros.shape
-            for j in range(target_shape[1]):
-                start = self.n_pack * j
-                end = self.n_pack * (j + 1)
-                tmp = zp[:, start:end].type(self.compression_dtype)
-                for e in range(tmp.shape[1]):
-                    tmp[:, e] &= mask
-                    tmp[:, e] = tmp[:, e] << (self.bits * e)
-                    self.qzeros[:, j] |= tmp[:, e]
+            self.qzeros.copy_(self.pack_tensor(zp))
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.t_().contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()
 
     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
-        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
+        scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight
 
         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
         if self.g_idx is None:
             # used for recovering fp32_weight
             self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device)
-        if hasattr(self, "qzeros"):
-            weight_dtype = torch.uint8
-        else:
-            weight_dtype = torch.int8
         # unpack weight
-        weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
-            qweight = qweight.t_().contiguous()
-        origin_shape = weight.shape
-        target_shape = qweight.shape
-        for j in range(target_shape[1]):
-            for e in range(self.n_pack):
-                index = j * self.n_pack + e
-                if index >= origin_shape[1]:
-                    continue
-                tmp = qweight[:, j]
-                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                tmp = tmp >> self.compress_bits - self.bits
-                if weight_dtype == torch.uint8:
-                    tmp &= mask  # remove sign bit
-                weight[:, index] = tmp.type(weight_dtype)
+            qweight = qweight.T.contiguous()
+        weight = self.unpack_tensor(qweight)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
+            weight = weight.T.contiguous()
+        weight = weight[: self.out_features, : self.in_features]  # avoid oversize
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
                 new_weight += torch.where(weight == k, v, 0)
             weight = new_weight
         # unpack zero_point
         if hasattr(self, "qzeros"):
-            zp_dtype = self.compression_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.T.contiguous() if self.use_optimum_format else self.qzeros
            if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                qzeros = qzeros.t_().contiguous()
-            origin_shape = zp.shape
-            target_shape = qzeros.shape
-            for j in range(target_shape[1]):
-                for e in range(self.n_pack):
-                    index = j * self.n_pack + e
-                    if index >= origin_shape[1]:
-                        continue
-                    tmp = qzeros[:, j]
-                    tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                    tmp = tmp >> self.compress_bits - self.bits
-                    tmp &= mask
-                    zp[:, index] = tmp.type(zp_dtype)
+                qzeros = qzeros.T.contiguous()
+            zp = self.unpack_tensor(qzeros)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
+                zp = zp.T.contiguous()
+            zp = zp[: scales.shape[0], : scales.shape[1]]  # avoid oversize
             if self.use_optimum_format:
                 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                 zp += 1
                 zp = torch.where(zp > (2**self.bits - 1), 0, zp)
             # recover fp32 weight with int_weight, scale, and zero_point
             for idx in range(self.in_features):
-                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
+                fp32_weight[:, idx] = (torch.subtract(weight[:, idx], zp[:, self.g_idx[idx]]).to(torch.int8)) * scales[
+                    :, self.g_idx[idx]
+                ]
         else:
             # recover fp32 weight with int_weight, scale
             for idx in range(self.in_features):
                 fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]]
         return fp32_weight
 
+    def pack_tensor(self, raw_tensor):
+        target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_tensor[:, start:end].type(self.compression_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = tmp[:, e] << (self.bits * e)
+                packed_tensor[:, j] |= tmp[:, e]
+        accelerator.synchronize()
+        return packed_tensor
+
+    def unpack_tensor(self, packed_tensor):
+        target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
+        target_len = packed_tensor.shape[1] * self.n_pack
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_tensor[:, j]
+                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
+                tmp = tmp >> self.compress_bits - self.bits
+                if target_dtype == torch.uint8:
+                    tmp &= mask  # remove sign bit
+                unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
+        accelerator.synchronize()
+        return unpacked_tensor
+
     def forward(self, input):
         if not hasattr(self, "weight"):
             weight = self.recover()
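
To make the pack()/recover() refactor easier to follow, here is a standalone sketch of the bit-packing scheme that the new pack_tensor()/unpack_tensor() helpers implement: n_pack low-bit integers are masked, shifted into disjoint bit slots of one compression_dtype element, and recovered with a shift-up/arithmetic-shift-down pair. The names (bits, n_pack, compression_dtype) mirror the diff, but this is a simplified illustration, not the WeightOnlyLinear implementation:

```python
import math
import torch

bits = 4
compression_dtype = torch.int32
compress_bits = 32
n_pack = compress_bits // bits                      # 8 int4 values per int32 element

int_weight = torch.randint(0, 2**bits, (2, 16), dtype=torch.int32)

# pack: mask each value to `bits` bits, shift it into its slot, OR into the target column
mask = torch.tensor(2**bits - 1, dtype=compression_dtype)
packed = torch.zeros(int_weight.shape[0], math.ceil(int_weight.shape[1] / n_pack), dtype=compression_dtype)
for j in range(packed.shape[1]):
    block = int_weight[:, j * n_pack : (j + 1) * n_pack].type(compression_dtype) & mask
    for e in range(block.shape[1]):
        packed[:, j] |= block[:, e] << (bits * e)

# unpack: shift the slot to the top bits, shift back down, then re-mask to drop sign extension
unpacked = torch.zeros(packed.shape[0], packed.shape[1] * n_pack, dtype=compression_dtype)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        tmp = packed[:, j] << (compress_bits - bits * (e + 1))
        tmp = tmp >> (compress_bits - bits)
        unpacked[:, j * n_pack + e] = tmp & mask

assert torch.equal(unpacked, int_weight)            # round-trip is lossless
```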

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 6 additions & 7 deletions
@@ -24,7 +24,7 @@
 import torch
 
 from neural_compressor.torch.algorithms import Quantizer
-from neural_compressor.torch.utils import get_device, is_transformers_imported, logger, set_module
+from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
 
 from .utility import cast_fp8, quant_tensor, search_clip
 
@@ -90,7 +90,7 @@ def convert(
             model: fake quantized torch module
         """
         weight_config = self.quant_config
-        device = get_device(kwargs.pop("device", "auto"))
+        device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
 
         # Put model on device explicitly
         # TODO: refine it later, Put module on device one by one instead of the whole model
@@ -165,9 +165,9 @@ def convert(
             else:
                 transpose = group_dim == 0
             if transpose:
-                weight = m.weight.t_().contiguous()
+                weight = m.weight.detach().T.contiguous()
             else:
-                weight = m.weight
+                weight = m.weight.detach()
             if use_mse_search:
                 quantile = search_clip(m, bits, group_size, scheme, dtype, use_full_range)
             if export_compressed_model:
@@ -189,8 +189,8 @@ def convert(
                 in_features = m.in_features
                 out_features = m.out_features
             elif is_transformers_imported() and isinstance(m, transformers.Conv1D):
-                in_features = m.weight.shape[1]
-                out_features = m.weight.shape[0]
+                in_features = m.weight.shape[0]
+                out_features = m.weight.shape[1]
                 int_weight = int_weight.t_().contiguous()
                 scale = scale.t_().contiguous()
                 zp = zp.t_().contiguous() if zp is not None else zp
@@ -227,6 +227,5 @@ def convert(
             # for only group_dim is 0 or only `transformers.Conv1D`,
            # we need to transpose the quantized tensor and module's weight back
            weight = weight.t_().contiguous()
-            m.weight.t_().contiguous()
            m.weight.data.copy_(weight)
        return model
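
The in_features/out_features swap for transformers.Conv1D above reflects that Conv1D stores its weight transposed relative to torch.nn.Linear, so shape[0] is the input dimension and shape[1] the output dimension. A small check of that layout (assumes the transformers package is installed; the import path below is the usual one and is not taken from this diff):

```python
import torch
from transformers.pytorch_utils import Conv1D

conv = Conv1D(nf=8, nx=4)             # out_features=8, in_features=4
print(conv.weight.shape)              # torch.Size([4, 8]) -> (in_features, out_features)

linear = torch.nn.Linear(in_features=4, out_features=8)
print(linear.weight.shape)            # torch.Size([8, 4]) -> (out_features, in_features)
```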

neural_compressor/torch/algorithms/weight_only/teq.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 import torch
 
 from neural_compressor.torch.algorithms.base_algorithm import Quantizer
-from neural_compressor.torch.utils import get_device, is_transformers_imported, logger
+from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger
 
 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
@@ -63,7 +63,7 @@ def _post_init(self):
     def _get_device(self):
         """Get the model device
         :return:Model device."""
-        device = get_device()
+        device = get_accelerator().current_device_name()
         return device
 
     def _get_dtype(self):

neural_compressor/torch/algorithms/weight_only/utility.py

Lines changed: 10 additions & 7 deletions
@@ -16,7 +16,7 @@
 
 import torch
 
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import accelerator, device_synchronize, logger
 
 __all__ = [
     "FLOAT_MAPPING",
@@ -205,12 +205,12 @@ def qdq_weight_sym(weight, bits=4, quantile=1.0, return_int=False, full_range=Fa
     wmax = torch.max(torch.abs(max_val), torch.abs(min_val))
     wmax = wmax * quantile
     tmp = wmax == 0
-    wmax[tmp] = +1
+    wmax[tmp] = torch.tensor(1, dtype=wmax.dtype, device=wmax.device)
     if full_range:
         # use -8, 8 to make sure amax is not changed after fake quant
         scale = wmax / (-minq)
-        tmp = scale * flip_flag.int()
-        scale -= 2 * tmp  # set negative scale with flip_flag
+        # set negative scale with flip_flag
+        scale = torch.where(flip_flag, -scale, scale)
     else:
         scale = wmax / maxq
     scale.unsqueeze_(dim=-1)
@@ -248,6 +248,7 @@ def qdq_weight_actor(weight, bits, scheme, quantile=1.0, dtype="int", return_int
         return qdq_weight_asym(weight, bits, quantile, return_int, **kwargs)
 
 
+@device_synchronize
 def quant_tensor(
     weight,
     bits=4,
@@ -343,10 +344,12 @@ def quant_tensor(
         )
         if return_int or quant_scale:
             weight2, scale2, zp2 = weight2
-            orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
+            weight = torch.cat([weight1, weight2], dim=1)
             scale = torch.cat([scale1, scale2], dim=1)
             zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1)
-            q_state = (weight, scale, zp)
+            accelerator.synchronize()
+            orig_weight.copy_(weight)
+            return orig_weight, scale, zp
         else:
             orig_weight.copy_(torch.cat([weight1, weight2], dim=1))
             return orig_weight
@@ -444,7 +447,7 @@ def search_clip(m, bits=4, group_size=32, scheme="asym", dtype="int", enable_ful
             full_range=enable_full_range,
             quantile=ratio,
         )
-        loss = (org_weight - m.weight.data).float().pow(2).mean().item()
+        loss = (org_weight - m.weight.data).float().pow(2).mean()
         m.weight.data.copy_(org_weight)
         history.append(loss)
         is_best = loss < best_error
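
For the qdq_weight_sym change above, the torch.where form is numerically equivalent to the old subtract-twice formulation whenever flip_flag is a boolean mask (in the surrounding function it selects the channels that should get a negative scale). A quick sanity check:

```python
import torch

scale = torch.tensor([0.5, 1.0, 2.0])
flip_flag = torch.tensor([True, False, True])   # channels whose scale should be negated

old = scale - 2 * (scale * flip_flag.int())     # previous formulation, written without in-place ops
new = torch.where(flip_flag, -scale, scale)     # form used in this commit

assert torch.equal(old, new)                    # both give tensor([-0.5, 1.0, -2.0])
```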

0 commit comments
