Commit de56d7c

q10 authored and facebook-github-bot committed

Fix Quantize tests (#2301)

Summary: Fix Quantize tests to pass again

Reviewed By: jspark1105

Differential Revision: D53283205

1 parent 50ecd52 commit de56d7c

File tree

11 files changed: +225 -257 lines

fbgemm_gpu/test/quantize/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .common import (  # noqa F401
+    fused_rowwise_8bit_dequantize_reference,
+    fused_rowwise_8bit_dequantize_reference_half,
+    fused_rowwise_8bit_quantize_reference,
+    fused_rowwise_nbit_quantize_dequantize_reference,
+    fused_rowwise_nbit_quantize_reference,
+)
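
With this __init__.py in place, the reference implementations live in one spot and are re-exported at the package level. A minimal usage sketch, assuming an environment where the fbgemm_gpu test tree is importable (the absolute import path below is our assumption, not something the diff shows):

    # Hypothetical usage; the package path assumes an installed fbgemm_gpu test tree.
    from fbgemm_gpu.test.quantize import fused_rowwise_8bit_quantize_reference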

fbgemm_gpu/test/quantize/bfloat16_test.py

Lines changed: 1 addition & 12 deletions
@@ -13,18 +13,7 @@
 import torch
 from hypothesis import given, HealthCheck, settings

-
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+from . import common  # noqa E402


 class SparseNNOperatorsGPUTest(unittest.TestCase):

fbgemm_gpu/test/quantize/common.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import struct
+from typing import Callable
+
+import fbgemm_gpu
+import numpy as np
+import torch
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+try:
+    # pyre-ignore[21]
+    from fbgemm_gpu import open_source  # noqa: F401
+
+except Exception:
+    if torch.version.hip:
+        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
+    else:
+        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+
+# Eigen/Python round 0.5 away from 0, Numpy rounds to even
+round_to_nearest: Callable[[np.ndarray], np.ndarray] = np.vectorize(round)
+
+
+def bytes_to_floats(byte_matrix: np.ndarray) -> np.ndarray:
+    floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float32)
+    for i, byte_values in enumerate(byte_matrix):
+        (floats[i],) = struct.unpack("f", bytearray(byte_values))
+    return floats
+
+
+def floats_to_bytes(floats: np.ndarray) -> np.ndarray:
+    byte_matrix = np.empty([np.shape(floats)[0], 4], dtype=np.uint8)
+    for i, value in enumerate(floats):
+        assert isinstance(value, np.float32), (value, floats)
+        as_bytes = struct.pack("f", value)
+        # In Python3 bytes will be a list of int, in Python2 a list of string
+        if isinstance(as_bytes[0], int):
+            byte_matrix[i] = list(as_bytes)
+        else:
+            byte_matrix[i] = list(map(ord, as_bytes))
+    return byte_matrix
+
+
+def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray:
+    floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float16)
+    for i, byte_values in enumerate(byte_matrix):
+        (floats[i],) = np.frombuffer(
+            memoryview(byte_values).tobytes(), dtype=np.float16
+        )
+    return floats
+
+
+def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray:
+    byte_matrix = np.empty([np.shape(floats)[0], 2], dtype=np.uint8)
+    for i, value in enumerate(floats):
+        assert isinstance(value, np.float16), (value, floats)
+        byte_matrix[i] = np.frombuffer(
+            memoryview(value.tobytes()).tobytes(), dtype=np.uint8
+        )
+    return byte_matrix
+
+
+def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray:
+    minimum = np.min(data, axis=-1, keepdims=True)
+    maximum = np.max(data, axis=-1, keepdims=True)
+    span = maximum - minimum
+    bias = minimum
+    scale = span / 255.0
+    inverse_scale = 255.0 / (span + 1e-8)
+    quantized_data = round_to_nearest((data - bias) * inverse_scale)
+    scale_bytes = floats_to_bytes(scale.reshape(-1))
+    scale_bytes = scale_bytes.reshape(data.shape[:-1] + (scale_bytes.shape[-1],))
+    bias_bytes = floats_to_bytes(bias.reshape(-1))
+    bias_bytes = bias_bytes.reshape(data.shape[:-1] + (bias_bytes.shape[-1],))
+    return np.concatenate([quantized_data, scale_bytes, bias_bytes], axis=-1)
+
+
+def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.ndarray:
+    scale = bytes_to_floats(fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4))
+    scale = scale.reshape(fused_quantized.shape[:-1] + (scale.shape[-1],))
+    bias = bytes_to_floats(fused_quantized[..., -4:].astype(np.uint8).reshape(-1, 4))
+    bias = bias.reshape(fused_quantized.shape[:-1] + (bias.shape[-1],))
+    quantized_data = fused_quantized[..., :-8]
+    return quantized_data * scale + bias
+
+
+def fused_rowwise_8bit_dequantize_reference_half(
+    fused_quantized: np.ndarray,
+) -> np.ndarray:
+    scale = bytes_to_half_floats(
+        fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4)
+    )
+    scale = scale.reshape(fused_quantized.shape[:-1] + (scale.shape[-1],))
+    bias = bytes_to_half_floats(
+        fused_quantized[..., -4:].astype(np.uint8).reshape(-1, 4)
+    )
+    bias = bias.reshape(fused_quantized.shape[:-1] + (bias.shape[-1],))
+    quantized_data = fused_quantized[..., :-8]
+    return quantized_data * scale + bias
+
+
+def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndarray:
+    minimum = np.min(data, axis=1).astype(np.float16).astype(np.float32)
+    maximum = np.max(data, axis=1)
+    span = maximum - minimum
+    qmax = (1 << bit) - 1
+    scale = (span / qmax).astype(np.float16).astype(np.float32)
+    bias = np.zeros(data.shape[0])
+    quantized_data = np.zeros(data.shape).astype(np.uint8)
+
+    for i in range(data.shape[0]):
+        bias[i] = minimum[i]
+        inverse_scale = 1.0 if scale[i] == 0.0 else 1 / scale[i]
+        if scale[i] == 0.0 or math.isinf(inverse_scale):
+            scale[i] = 1.0
+            inverse_scale = 1.0
+        quantized_data[i] = np.clip(
+            np.round((data[i, :] - minimum[i]) * inverse_scale), 0, qmax
+        )
+
+    # pack
+    assert 8 % bit == 0
+    num_elem_per_byte = 8 // bit
+    packed_dim = (data.shape[1] + num_elem_per_byte - 1) // num_elem_per_byte
+    packed_data = np.zeros([data.shape[0], packed_dim]).astype(np.uint8)
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            if j % num_elem_per_byte == 0:
+                packed_data[i, j // num_elem_per_byte] = quantized_data[i, j]
+            else:
+                packed_data[i, j // num_elem_per_byte] += quantized_data[i, j] << (
+                    (j % num_elem_per_byte) * bit
+                )
+
+    scale_bytes = half_floats_to_bytes(scale.astype(np.float16))
+    bias_bytes = half_floats_to_bytes(bias.astype(np.float16))
+    return np.concatenate([packed_data, scale_bytes, bias_bytes], axis=1)
+
+
+def fused_rowwise_nbit_quantize_dequantize_reference(
+    data: np.ndarray, bit: int
+) -> np.ndarray:
+    fused_quantized = fused_rowwise_nbit_quantize_reference(data, bit)
+    scale = bytes_to_half_floats(fused_quantized[:, -4:-2].astype(np.uint8)).astype(
+        np.float32
+    )
+    bias = bytes_to_half_floats(fused_quantized[:, -2:].astype(np.uint8)).astype(
+        np.float32
+    )
+    quantized_data = fused_quantized[:, :-4]
+
+    # unpack
+    packed_dim = fused_quantized.shape[1] - 4
+    assert 8 % bit == 0
+    num_elem_per_byte = 8 // bit
+    assert packed_dim == ((data.shape[1] + num_elem_per_byte - 1) // num_elem_per_byte)
+    unpacked_data = np.zeros(data.shape).astype(np.uint8)
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            unpacked_data[i, j] = (
+                quantized_data[i, j // num_elem_per_byte]
+                >> ((j % num_elem_per_byte) * bit)
+            ) & ((1 << bit) - 1)

+    return scale * unpacked_data + bias
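
Taken together, common.py now owns both the op-library loading side effect and the numpy reference implementations the tests compare against. A minimal sketch of the 8-bit round trip these helpers encode, assuming the fbgemm_gpu test tree is importable; the error bound of half a quantization step (span / 255 / 2) is our own reasoning, not part of the diff:

    import numpy as np

    # Path is an assumption: it presumes an installed fbgemm_gpu test tree.
    from fbgemm_gpu.test.quantize.common import (
        fused_rowwise_8bit_dequantize_reference,
        fused_rowwise_8bit_quantize_reference,
    )

    rows = np.random.rand(4, 16).astype(np.float32)
    # Each fused row is: 16 quantized values, 4-byte fp32 scale, 4-byte fp32 bias.
    fused = fused_rowwise_8bit_quantize_reference(rows)
    assert fused.shape == (4, 16 + 8)
    recovered = fused_rowwise_8bit_dequantize_reference(fused)
    # Worst-case rounding error is half a quantization step per row.
    span = rows.max(axis=1, keepdims=True) - rows.min(axis=1, keepdims=True)
    assert np.all(np.abs(recovered - rows) <= span / 255.0 / 2 + 1e-6)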

fbgemm_gpu/test/quantize/fp8_rowwise_test.py

Lines changed: 5 additions & 16 deletions
@@ -15,24 +15,13 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import given, settings, Verbosity

+from . import common  # noqa E402
+from .common import open_source

-try:
+if open_source:
     # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
-    # pyre-ignore[21]
-    from test_utils import (  # noqa: F401
-        gpu_unavailable,
-        optests,
-        symint_vector_unsupported,
-    )
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+    from test_utils import gpu_unavailable, optests, symint_vector_unsupported
+else:
     from fbgemm_gpu.test.test_utils import (
         gpu_unavailable,
         optests,
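
The gating above is the pattern every remaining test file converges on: `from . import common` loads the sparse_ops libraries as a side effect of common.py's try/except, and the `open_source` flag then selects between the open-source `test_utils` module and the internal `fbgemm_gpu.test.test_utils` package.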

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py

Lines changed: 10 additions & 21 deletions
@@ -12,29 +12,18 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import assume, given, HealthCheck, settings

-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
+from . import common  # noqa E402
+from .common import (
+    fused_rowwise_8bit_dequantize_reference,
+    fused_rowwise_8bit_quantize_reference,
+    open_source,
+)

+if open_source:
     # pyre-ignore[21]
-    from test_utils import (
-        fused_rowwise_8bit_dequantize_reference,
-        fused_rowwise_8bit_quantize_reference,
-        gpu_available,
-    )
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
-
-    from fbgemm_gpu.test.test_utils import (
-        fused_rowwise_8bit_dequantize_reference,
-        fused_rowwise_8bit_quantize_reference,
-        gpu_available,
-    )
+    from test_utils import gpu_available
+else:
+    from fbgemm_gpu.test.test_utils import gpu_available

 no_long_tests: bool = False

fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py

Lines changed: 11 additions & 21 deletions
@@ -12,30 +12,20 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import assume, given, HealthCheck, settings

-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
+from . import common  # noqa E402
+from .common import (
+    bytes_to_half_floats,
+    fused_rowwise_nbit_quantize_dequantize_reference,
+    fused_rowwise_nbit_quantize_reference,
+    open_source,
+)

+if open_source:
     # pyre-ignore[21]
-    from test_utils import (
-        bytes_to_half_floats,
-        fused_rowwise_nbit_quantize_dequantize_reference,
-        fused_rowwise_nbit_quantize_reference,
-        gpu_available,
-    )
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+    from test_utils import gpu_available
+else:
+    from fbgemm_gpu.test.test_utils import gpu_available

-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
-    from fbgemm_gpu.test.test_utils import (
-        bytes_to_half_floats,
-        fused_rowwise_nbit_quantize_dequantize_reference,
-        fused_rowwise_nbit_quantize_reference,
-        gpu_available,
-    )

 no_long_tests: bool = False

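The n-bit helpers round-trip the same way. A hedged sketch exercising the 4-bit path, under the same importability assumption as above; the tolerance is illustrative, sized to the fp16-rounded scale and bias:

    import numpy as np

    # Path is an assumption: it presumes an installed fbgemm_gpu test tree.
    from fbgemm_gpu.test.quantize.common import (
        fused_rowwise_nbit_quantize_dequantize_reference,
    )

    data = np.random.rand(2, 8).astype(np.float32)
    # bit=4 packs two values per byte; each row also carries an fp16 scale and bias.
    recovered = fused_rowwise_nbit_quantize_dequantize_reference(data, bit=4)
    assert recovered.shape == data.shape
    # Allow roughly one 4-bit quantization step (span / 15) plus fp16 rounding slack.
    span = data.max(axis=1, keepdims=True) - data.min(axis=1, keepdims=True)
    assert np.all(np.abs(recovered - data) <= span / 15.0 + 1e-3)
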
fbgemm_gpu/test/quantize/hfp8_test.py

Lines changed: 1 addition & 12 deletions
@@ -12,18 +12,7 @@
 from hypothesis import given, HealthCheck, settings
 from torch import Tensor

-
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+from . import common  # noqa E402


 class TestHFP8QuantizationConversion(unittest.TestCase):

fbgemm_gpu/test/quantize/mixed_dim_int8_test.py

Lines changed: 4 additions & 11 deletions
@@ -12,20 +12,13 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import given, HealthCheck, settings

+from . import common  # noqa E402
+from .common import open_source

-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
+if open_source:
     # pyre-ignore[21]
     from test_utils import gpu_unavailable
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+else:
     from fbgemm_gpu.test.test_utils import gpu_unavailable

fbgemm_gpu/test/quantize/msfp_test.py

Lines changed: 4 additions & 11 deletions
@@ -10,20 +10,13 @@
 import torch
 from hypothesis import given, HealthCheck, settings

+from . import common  # noqa E402
+from .common import open_source

-
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
+if open_source:
     # pyre-ignore[21]
     from test_utils import gpu_unavailable
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+else:
     from fbgemm_gpu.test.test_utils import gpu_unavailable
