From d28a4f68c81fb69404e8f2eddc7f5b9bbbd3c4bf Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 24 Apr 2024 00:40:42 -0700 Subject: [PATCH] Fix quantization of all 0s (#1028) --- mlx/ops.cpp | 5 ++++- python/tests/test_quantized.py | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 34903f107b..6f1b0b7967 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3300,7 +3300,10 @@ std::tuple quantize( reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s); + array delta = maximum( + divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s), + array(1e-7, w.dtype()), + s); array scales = squeeze(delta, -1, s); array biases = squeeze(w_min, -1, s); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 60e036d69c..6e30bac5ef 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -18,6 +18,14 @@ def test_quantize_dequantize(self): eps = 1e-6 self.assertTrue((errors <= (scales[..., None] + eps)).all()) + # test quantize/dequantize 0s + a = mx.zeros((256, 512)) + for gs in [32, 64, 128]: + for b in [2, 4, 8]: + w_q, scales, biases = mx.quantize(a, gs, b) + a_hat = mx.dequantize(w_q, scales, biases, gs, b) + self.assertTrue(mx.all(a_hat == 0)) + def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key)