
Commit 0523739

q10 authored and facebook-github-bot committed

Fix Quantize tests (#2301)

Summary:
Pull Request resolved: #2301

- Fix Quantize tests to pass again

Reviewed By: jspark1105

Differential Revision: D53283205

fbshipit-source-id: f0c2c6846155222633617f0cf55805e1abd054e1

1 parent 3d7af4a · commit 0523739

File tree

11 files changed (+370, -251 lines)


fbgemm_gpu/test/quantize/__init__.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .common import (  # noqa F401
+    fused_rowwise_8bit_dequantize_reference,
+    fused_rowwise_8bit_dequantize_reference_half,
+    fused_rowwise_8bit_quantize_reference,
+    fused_rowwise_nbit_quantize_dequantize_reference,
+    fused_rowwise_nbit_quantize_reference,
+)
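With the re-export in place, downstream code can pull the reference helpers straight from the package. A minimal sketch, assuming the test package is importable as fbgemm_gpu.test.quantize in the installed layout:

import numpy as np

# assumed install layout: the test package is importable at this path
from fbgemm_gpu.test.quantize import fused_rowwise_8bit_quantize_reference

data = np.random.rand(2, 8).astype(np.float32)
fused = fused_rowwise_8bit_quantize_reference(data)
# per row: 8 quantized bytes, then 4 scale bytes, then 4 bias bytes
assert fused.shape == (2, 8 + 8)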

fbgemm_gpu/test/quantize/bfloat16_test.py

Lines changed: 1 addition & 12 deletions

@@ -13,18 +13,7 @@
 import torch
 from hypothesis import given, HealthCheck, settings
 
-
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+from . import common  # noqa E402
 
 
 class SparseNNOperatorsGPUTest(unittest.TestCase):
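The one-line replacement works because Python caches modules: the first `from . import common` anywhere in the test package executes common.py's top-level operator loading, and every later import is served from the cache. A stand-in demonstration of that guarantee, using the stdlib `struct` module in place of `common`:

import importlib
import sys

# importing a module twice runs its top-level code only once;
# the second call is answered from the sys.modules cache
a = importlib.import_module("struct")
b = importlib.import_module("struct")
assert a is b and "struct" in sys.modules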

fbgemm_gpu/test/quantize/common.py

Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math, struct
+from typing import Callable
+
+import fbgemm_gpu
+import numpy as np
+import torch
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+try:
+    # pyre-ignore[21]
+    from fbgemm_gpu import open_source  # noqa: F401
+
+except Exception:
+    if torch.version.hip:
+        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
+    else:
+        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+
+# Eigen/Python round 0.5 away from 0, Numpy rounds to even
+round_to_nearest: Callable[[np.ndarray], np.ndarray] = np.vectorize(round)
+
+
+def bytes_to_floats(byte_matrix: np.ndarray) -> np.ndarray:
+    floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float32)
+    for i, byte_values in enumerate(byte_matrix):
+        (floats[i],) = struct.unpack("f", bytearray(byte_values))
+    return floats
+
+
+def floats_to_bytes(floats: np.ndarray) -> np.ndarray:
+    byte_matrix = np.empty([np.shape(floats)[0], 4], dtype=np.uint8)
+    for i, value in enumerate(floats):
+        assert isinstance(value, np.float32), (value, floats)
+        as_bytes = struct.pack("f", value)
+        # In Python3 bytes will be a list of int, in Python2 a list of string
+        if isinstance(as_bytes[0], int):
+            byte_matrix[i] = list(as_bytes)
+        else:
+            byte_matrix[i] = list(map(ord, as_bytes))
+    return byte_matrix
+
+
+def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray:
+    floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float16)
+    for i, byte_values in enumerate(byte_matrix):
+        (floats[i],) = np.frombuffer(
+            memoryview(byte_values).tobytes(), dtype=np.float16
+        )
+    return floats
+
+
+def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray:
+    byte_matrix = np.empty([np.shape(floats)[0], 2], dtype=np.uint8)
+    for i, value in enumerate(floats):
+        assert isinstance(value, np.float16), (value, floats)
+        byte_matrix[i] = np.frombuffer(
+            memoryview(value.tobytes()).tobytes(), dtype=np.uint8
+        )
+    return byte_matrix
+
+
+def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray:
+    minimum = np.min(data, axis=-1, keepdims=True)
+    maximum = np.max(data, axis=-1, keepdims=True)
+    span = maximum - minimum
+    bias = minimum
+    scale = span / 255.0
+    inverse_scale = 255.0 / (span + 1e-8)
+    quantized_data = round_to_nearest((data - bias) * inverse_scale)
+    scale_bytes = floats_to_bytes(scale.reshape(-1))
+    scale_bytes = scale_bytes.reshape(data.shape[:-1] + (scale_bytes.shape[-1],))
+    bias_bytes = floats_to_bytes(bias.reshape(-1))
+    bias_bytes = bias_bytes.reshape(data.shape[:-1] + (bias_bytes.shape[-1],))
+    return np.concatenate([quantized_data, scale_bytes, bias_bytes], axis=-1)
+
+
+def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.ndarray:
+    scale = bytes_to_floats(fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4))
+    scale = scale.reshape(fused_quantized.shape[:-1] + (scale.shape[-1],))
+    bias = bytes_to_floats(fused_quantized[..., -4:].astype(np.uint8).reshape(-1, 4))
+    bias = bias.reshape(fused_quantized.shape[:-1] + (bias.shape[-1],))
+    quantized_data = fused_quantized[..., :-8]
+    return quantized_data * scale + bias
+
+
+def fused_rowwise_8bit_dequantize_reference_half(
+    fused_quantized: np.ndarray,
+) -> np.ndarray:
+    scale = bytes_to_half_floats(
+        fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4)
+    )
+    scale = scale.reshape(fused_quantized.shape[:-1] + (scale.shape[-1],))
+    bias = bytes_to_half_floats(
+        fused_quantized[..., -4:].astype(np.uint8).reshape(-1, 4)
+    )
+    bias = bias.reshape(fused_quantized.shape[:-1] + (bias.shape[-1],))
+    quantized_data = fused_quantized[..., :-8]
+    return quantized_data * scale + bias
+
+
+def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndarray:
+    minimum = np.min(data, axis=1).astype(np.float16).astype(np.float32)
+    maximum = np.max(data, axis=1)
+    span = maximum - minimum
+    qmax = (1 << bit) - 1
+    scale = (span / qmax).astype(np.float16).astype(np.float32)
+    bias = np.zeros(data.shape[0])
+    quantized_data = np.zeros(data.shape).astype(np.uint8)
+
+    for i in range(data.shape[0]):
+        bias[i] = minimum[i]
+        inverse_scale = 1.0 if scale[i] == 0.0 else 1 / scale[i]
+        if scale[i] == 0.0 or math.isinf(inverse_scale):
+            scale[i] = 1.0
+            inverse_scale = 1.0
+        quantized_data[i] = np.clip(
+            np.round((data[i, :] - minimum[i]) * inverse_scale), 0, qmax
+        )
+
+    # pack
+    assert 8 % bit == 0
+    num_elem_per_byte = 8 // bit
+    packed_dim = (data.shape[1] + num_elem_per_byte - 1) // num_elem_per_byte
+    packed_data = np.zeros([data.shape[0], packed_dim]).astype(np.uint8)
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            if j % num_elem_per_byte == 0:
+                packed_data[i, j // num_elem_per_byte] = quantized_data[i, j]
+            else:
+                packed_data[i, j // num_elem_per_byte] += quantized_data[i, j] << (
+                    (j % num_elem_per_byte) * bit
+                )
+
+    scale_bytes = half_floats_to_bytes(scale.astype(np.float16))
+    bias_bytes = half_floats_to_bytes(bias.astype(np.float16))
+    return np.concatenate([packed_data, scale_bytes, bias_bytes], axis=1)
+
+
+def fused_rowwise_nbit_quantize_dequantize_reference(
+    data: np.ndarray, bit: int
+) -> np.ndarray:
+    fused_quantized = fused_rowwise_nbit_quantize_reference(data, bit)
+    scale = bytes_to_half_floats(fused_quantized[:, -4:-2].astype(np.uint8)).astype(
+        np.float32
+    )
+    bias = bytes_to_half_floats(fused_quantized[:, -2:].astype(np.uint8)).astype(
+        np.float32
+    )
+    quantized_data = fused_quantized[:, :-4]
+
+    # unpack
+    packed_dim = fused_quantized.shape[1] - 4
+    assert 8 % bit == 0
+    num_elem_per_byte = 8 // bit
+    assert packed_dim == ((data.shape[1] + num_elem_per_byte - 1) // num_elem_per_byte)
+    unpacked_data = np.zeros(data.shape).astype(np.uint8)
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            unpacked_data[i, j] = (
+                quantized_data[i, j // num_elem_per_byte]
+                >> ((j % num_elem_per_byte) * bit)
+            ) & ((1 << bit) - 1)
+
+    return scale * unpacked_data + bias
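To make the fused 8-bit format concrete: each quantized row is the uint8 payload followed by a 4-byte float32 scale and a 4-byte float32 bias, and dequantization inverts the affine map. A minimal round-trip sketch using only the functions above, assuming the module is importable as fbgemm_gpu.test.quantize.common:

import numpy as np

# assumed import path for the module added in this commit
from fbgemm_gpu.test.quantize.common import (
    fused_rowwise_8bit_dequantize_reference,
    fused_rowwise_8bit_quantize_reference,
)

rows, cols = 4, 16
data = np.random.rand(rows, cols).astype(np.float32)

fused = fused_rowwise_8bit_quantize_reference(data)
assert fused.shape == (rows, cols + 8)  # payload + 4 scale bytes + 4 bias bytes

recovered = fused_rowwise_8bit_dequantize_reference(fused)
# rounding error is at most ~scale/2 per element, i.e. span/510
span = data.max(axis=1, keepdims=True) - data.min(axis=1, keepdims=True)
assert np.all(np.abs(recovered - data) <= span / 255.0 + 1e-6)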

fbgemm_gpu/test/quantize/fp8_rowwise_test.py

Lines changed: 5 additions & 16 deletions

@@ -15,24 +15,13 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import given, settings, Verbosity
 
+from . import common  # noqa E402
+from .common import open_source
 
-try:
+if open_source:
     # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
-    # pyre-ignore[21]
-    from test_utils import (  # noqa: F401
-        gpu_unavailable,
-        optests,
-        symint_vector_unsupported,
-    )
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+    from test_utils import gpu_unavailable, optests, symint_vector_unsupported
+else:
     from fbgemm_gpu.test.test_utils import (
         gpu_unavailable,
         optests,
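Isolated from the test file, the new gating idiom reads as below. This is a sketch that mirrors the detection already added in common.py, not anything new:

import fbgemm_gpu

# the OSS build stamps an `open_source` attribute onto the package;
# internal builds do not, so getattr with a default tells them apart
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # OSS checkout: the helper module sits next to the tests
    from test_utils import gpu_unavailable  # noqa: F401
else:
    # internal build: helpers ship inside the fbgemm_gpu package
    from fbgemm_gpu.test.test_utils import gpu_unavailable  # noqa: F401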

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py

Lines changed: 10 additions & 21 deletions

@@ -12,29 +12,18 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import assume, given, HealthCheck, settings
 
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
+from . import common  # noqa E402
+from .common import (
+    fused_rowwise_8bit_dequantize_reference,
+    fused_rowwise_8bit_quantize_reference,
+    open_source,
+)
 
+if open_source:
     # pyre-ignore[21]
-    from test_utils import (
-        fused_rowwise_8bit_dequantize_reference,
-        fused_rowwise_8bit_quantize_reference,
-        gpu_available,
-    )
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
-
-    from fbgemm_gpu.test.test_utils import (
-        fused_rowwise_8bit_dequantize_reference,
-        fused_rowwise_8bit_quantize_reference,
-        gpu_available,
-    )
+    from test_utils import gpu_available
+else:
+    from fbgemm_gpu.test.test_utils import gpu_available
 
 no_long_tests: bool = False
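In miniature, these tests compare the registered quantize operator against the NumPy reference. A hedged sketch; the op name torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized is assumed from fbgemm_gpu's quantize operators and is not part of this diff:

import numpy as np
import torch

from fbgemm_gpu.test.quantize.common import fused_rowwise_8bit_quantize_reference

data = np.random.rand(3, 32).astype(np.float32)

# operator under test (name assumed; registered by the sparse_ops libraries)
op_out = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(torch.from_numpy(data))

ref_out = fused_rowwise_8bit_quantize_reference(data)
# payloads should agree to within one quantization step (rounding-mode slack)
np.testing.assert_allclose(
    op_out.numpy()[:, :32].astype(np.float32),
    ref_out[:, :32].astype(np.float32),
    atol=1,
)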

fbgemm_gpu/test/quantize/fused_nbit_rowwise_test.py

Lines changed: 11 additions & 21 deletions

@@ -12,30 +12,20 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import assume, given, HealthCheck, settings
 
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
+from . import common  # noqa E402
+from .common import (
+    bytes_to_half_floats,
+    fused_rowwise_nbit_quantize_dequantize_reference,
+    fused_rowwise_nbit_quantize_reference,
+    open_source,
+)
 
+if open_source:
     # pyre-ignore[21]
-    from test_utils import (
-        bytes_to_half_floats,
-        fused_rowwise_nbit_quantize_dequantize_reference,
-        fused_rowwise_nbit_quantize_reference,
-        gpu_available,
-    )
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+    from test_utils import gpu_available
+else:
+    from fbgemm_gpu.test.test_utils import gpu_available
 
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
-    from fbgemm_gpu.test.test_utils import (
-        bytes_to_half_floats,
-        fused_rowwise_nbit_quantize_dequantize_reference,
-        fused_rowwise_nbit_quantize_reference,
-        gpu_available,
-    )
 
 no_long_tests: bool = False
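For intuition about the n-bit packing these tests cover, a minimal sketch with the reference helpers (4-bit, so two elements per byte; import path assumed as before):

import numpy as np

from fbgemm_gpu.test.quantize.common import (
    fused_rowwise_nbit_quantize_dequantize_reference,
    fused_rowwise_nbit_quantize_reference,
)

bit = 4
data = np.random.rand(2, 8).astype(np.float32)

fused = fused_rowwise_nbit_quantize_reference(data, bit)
# 8 values / 2 per byte = 4 packed bytes, plus 2-byte fp16 scale and bias
assert fused.shape == (2, 4 + 4)

recovered = fused_rowwise_nbit_quantize_dequantize_reference(data, bit)
# a 4-bit grid with fp16 scale/bias: expect only coarse agreement
assert np.all(np.abs(recovered - data) <= 0.2)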

fbgemm_gpu/test/quantize/hfp8_test.py

Lines changed: 1 addition & 12 deletions

@@ -12,18 +12,7 @@
 from hypothesis import given, HealthCheck, settings
 from torch import Tensor
 
-
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+from . import common  # noqa E402
 
 
 class TestHFP8QuantizationConversion(unittest.TestCase):

fbgemm_gpu/test/quantize/mixed_dim_int8_test.py

Lines changed: 4 additions & 11 deletions

@@ -12,20 +12,13 @@
 from fbgemm_gpu.split_embedding_configs import SparseType
 from hypothesis import given, HealthCheck, settings
 
+from . import common  # noqa E402
+from .common import open_source
 
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
+if open_source:
     # pyre-ignore[21]
     from test_utils import gpu_unavailable
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+else:
     from fbgemm_gpu.test.test_utils import gpu_unavailable

fbgemm_gpu/test/quantize/msfp_test.py

Lines changed: 4 additions & 11 deletions

@@ -10,20 +10,13 @@
 import torch
 from hypothesis import given, HealthCheck, settings
 
+from . import common  # noqa E402
+from .common import open_source
 
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-
+if open_source:
     # pyre-ignore[21]
     from test_utils import gpu_unavailable
-except Exception:
-    if torch.version.hip:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip")
-    else:
-        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+else:
     from fbgemm_gpu.test.test_utils import gpu_unavailable

fbgemm_gpu/test/tbe/cache_test.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 # pyre-ignore-all-errors[56]
 
 import unittest
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Optional, Tuple
 
 import hypothesis.strategies as st
 import numpy as np
