diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index d7e52385..9e4a816f 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -37,4 +37,4 @@ def add(A, B, *, alpha=1): elif isinstance(B, torch.Tensor): return add_func_scalar_tensor(A, B, alpha) else: - return A + B * alpha + return torch.tensor(A + B * alpha) diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index f28ce164..2cf0eb2c 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -7,12 +7,12 @@ from ..utils import pointwise_dynamic try: - from triton.language.extra.cuda.libdevice import div_rd, div_rz, trunc + from triton.language.extra.cuda.libdevice import div_rn, div_rz, fmod, trunc except ImportError: try: - from triton.language.math import div_rd, div_rz, trunc + from triton.language.math import div_rn, div_rz, fmod, trunc except ImportError: - from triton.language.libdevice import div_rd, div_rz, trunc + from triton.language.libdevice import div_rn, div_rz, fmod, trunc @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @@ -43,7 +43,7 @@ def true_divide(A, B): return true_div_func_scalar_tensor(A, B) else: # Both scalar - return A / B + return torch.tensor(A / B) @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @@ -74,7 +74,7 @@ def trunc_divide(A, B): return trunc_div_func_scalar_tensor(A, B) else: # Both scalar - return A / B + return torch.tensor(A / B) @triton.jit @@ -98,13 +98,45 @@ def _int_floordiv(x, y): return tl.where(c1 & c2, x // y - 1, x // y) +# TO be consistent with python, numpy and torch, we have to implement it in the +# following way. +# CPython +# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 +# numpy +# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532 +# torch +# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23 +@triton.jit +def _float_floordiv(x, y): + # NOTE: fmod's sign is the same as the dividend + remainder = fmod(x, y) + imperfect = remainder != 0.0 + different_sign = (x < 0) ^ (y < 0) + + # NOTE: we have to use div_rn explicitly here + q = div_rn(x - remainder, y) + q = tl.where(imperfect & different_sign, q - 1, q) + + floor_q = tl.math.floor(q) + c = q - floor_q > 0.5 + floor_q = tl.where(c, floor_q + 1.0, floor_q) + + q_is_zeros = q == 0.0 + floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q) + + is_div_by_zero = y == 0.0 + float_division = x / y + out = tl.where(is_div_by_zero, float_division, floor_q) + return out + + @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def floor_div_func(x, y): if x.type.scalar.is_int() & x.type.scalar.is_int(): return _int_floordiv(x, y) else: - return tl.math.floor(div_rd(x, y)) + return _float_floordiv(x, y) @pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @@ -113,7 +145,7 @@ def floor_div_func_tensor_scalar(x, y): if x.type.scalar.is_int() & x.type.scalar.is_int(): return _int_floordiv(x, y) else: - return tl.math.floor(div_rd(x, y)) + return _float_floordiv(x, y) @pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) @@ -122,7 +154,7 @@ def floor_div_func_scalar_tensor(x, y): if x.type.scalar.is_int() & x.type.scalar.is_int(): return _int_floordiv(x, y) else: - return tl.math.floor(div_rd(x, y)) + return _float_floordiv(x, y) def floor_divide(A, B): @@ -135,7 +167,7 @@ def floor_divide(A, B): return floor_div_func_scalar_tensor(A, B) else: # Both scalar - return A // B + return torch.tensor(A // B) def div_mode(A, B, rounding_mode=None): @@ -186,4 +218,4 @@ def remainder(A, B): return rem_st(A, B) else: # Both scalar - return A % B + return torch.tensor(A % B) diff --git a/src/flag_gems/ops/mul.py b/src/flag_gems/ops/mul.py index d33d1f83..9d793b6d 100644 --- a/src/flag_gems/ops/mul.py +++ b/src/flag_gems/ops/mul.py @@ -28,4 +28,4 @@ def mul(A, B): return mul_func_scalar(B, A) else: # Both scalar - return A * B + return torch.tensor(A * B) diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index c62faf05..f38986f4 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -38,4 +38,4 @@ def sub(A, B, *, alpha=1): return sub_func_scalar_tensor(A, B, alpha) else: # Both scalar - return A - B * alpha + return torch.tensor(A - B * alpha) diff --git a/src/flag_gems/ops/topk.py b/src/flag_gems/ops/topk.py index 53d6350d..43995b17 100644 --- a/src/flag_gems/ops/topk.py +++ b/src/flag_gems/ops/topk.py @@ -261,12 +261,12 @@ def topk(x, k, dim=-1, largest=True, sorted=True): stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype) stage1_out_idx = torch.empty( - batch_size * chunk_num * k, device=x.device, dtype=torch.int32 + batch_size * chunk_num * k, device=x.device, dtype=torch.int64 ) out_shape = x.shape[:-1] + (k,) stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype) - stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int32) + stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64) with torch.cuda.device(x.device): topk_stage1_kernel[ diff --git a/src/flag_gems/ops/unique.py b/src/flag_gems/ops/unique.py index d0533752..caadeae0 100644 --- a/src/flag_gems/ops/unique.py +++ b/src/flag_gems/ops/unique.py @@ -373,7 +373,7 @@ def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): # allocate tensor if return_counts: local_unique = None - origin_idx = torch.empty_like(sorted_data, dtype=torch.int32) + origin_idx = torch.empty_like(sorted_data, dtype=torch.int64) idx = torch.empty_like(origin_idx) else: local_unique = torch.empty_like(sorted_data) @@ -381,7 +381,7 @@ def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): idx = None counts = None tile_sum = torch.empty( - (global_ctas_num,), dtype=torch.int32, device=sorted_data.device + (global_ctas_num,), dtype=torch.int64, device=sorted_data.device ) data_out = None if not return_counts: @@ -654,10 +654,10 @@ def sorted_indices_unique_flat( # allocate tensor ne_result = torch.empty_like(sorted_data, dtype=torch.bool) tile_sum = torch.empty( - (global_ctas_num,), dtype=torch.int32, device=sorted_data.device + (global_ctas_num,), dtype=torch.int64, device=sorted_data.device ) data_out = torch.empty_like(sorted_data) - inverse_indices = torch.empty_like(sorted_data, dtype=torch.int32) + inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) idx = None if return_counts: idx = torch.empty_like(inverse_indices) @@ -722,14 +722,14 @@ def simple_unique_flat( # allocate tensor data_out = torch.empty_like(sorted_data) if return_inverse: - inverse_indices = torch.empty_like(sorted_data, dtype=torch.int32) + inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64) else: inverse_indices = None if return_counts: - idx = torch.empty_like(sorted_data, dtype=torch.int32) + idx = torch.empty_like(sorted_data, dtype=torch.int64) else: idx = None - unique_size = torch.empty([1], dtype=torch.int32, device=sorted_data.device) + unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device) # launch kernel with torch.cuda.device(sorted_data.device.index): diff --git a/src/flag_gems/testing/__init__.py b/src/flag_gems/testing/__init__.py index 47ea9138..765051a8 100644 --- a/src/flag_gems/testing/__init__.py +++ b/src/flag_gems/testing/__init__.py @@ -18,5 +18,5 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1): torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan) -def assert_equal(res, ref): - assert torch.equal(res, ref) +def assert_equal(res, ref, equal_nan=False): + torch.testing.assert_close(res, ref, atol=0, rtol=0, equal_nan=equal_nan) diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py index 2b763324..9c51c4f9 100644 --- a/tests/accuracy_utils.py +++ b/tests/accuracy_utils.py @@ -96,9 +96,9 @@ def gems_assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1): ) -def gems_assert_equal(res, ref): +def gems_assert_equal(res, ref, equal_nan=False): res = to_cpu(res, ref) - flag_gems.testing.assert_equal(res, ref) + flag_gems.testing.assert_equal(res, ref, equal_nan=equal_nan) def unsqueeze_tuple(t, max_len): diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py index 19eef4af..6eb6ecb6 100644 --- a/tests/test_binary_pointwise_ops.py +++ b/tests/test_binary_pointwise_ops.py @@ -1,6 +1,7 @@ import logging import random +import numpy as np import pytest import torch @@ -76,6 +77,28 @@ def test_accuracy_add_scalar_tensor(shape, scalar, alpha, dtype): gems_assert_close(res_out, ref_out, dtype) +@pytest.mark.add +@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) +def test_accuracy_add_scalar_scalar(dtype): + if dtype == torch.float32: + inp1 = float(np.float32(random.random())) + inp2 = float(np.float32(random.random())) + alpha = float(np.float32(random.random())) + else: + inp1 = random.randint(0, 100) + inp2 = random.randint(0, 100) + alpha = random.randint(0, 100) + + ref_out = torch.add(inp1, inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp1, inp2, alpha=alpha) + + if dtype == torch.int64: + gems_assert_equal(res_out, ref_out) + else: + gems_assert_close(res_out, ref_out, dtype) + + @pytest.mark.bitwise_and @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) @@ -303,6 +326,26 @@ def test_accuracy_div_scalar_tensor(shape, scalar, dtype): gems_assert_close(res_out, ref_out, dtype, equal_nan=True) +@pytest.mark.div +@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) +def test_accuracy_div_scalar_scalar(dtype): + if dtype == torch.float32: + inp1 = float(np.float32(random.random() + 0.01)) + inp2 = float(np.float32(random.random() + 0.01)) + else: + inp1 = random.randint(1, 100) + inp2 = random.randint(1, 100) + + ref_out = torch.mul(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.mul(inp1, inp2) + + if dtype == torch.int64: + gems_assert_equal(res_out, ref_out) + else: + gems_assert_close(res_out, ref_out, dtype) + + @pytest.mark.trunc_divide @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", [torch.float32]) @@ -326,6 +369,26 @@ def test_accuracy_trunc_div(shape, dtype): gems_assert_close(res_out, ref_out, dtype, equal_nan=True) +@pytest.mark.trunc_divide +@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) +def test_accuracy_trunc_divide_scalar_scalar(dtype): + if dtype == torch.float32: + inp1 = float(np.float32(random.random() + 0.01)) + inp2 = float(np.float32(random.random() + 0.01)) + else: + inp1 = random.randint(1, 100) + inp2 = random.randint(1, 100) + + ref_out = torch.div(inp1, inp2, rounding_mode="trunc") + with flag_gems.use_gems(): + res_out = torch.div(inp1, inp2, rounding_mode="trunc") + + if dtype == torch.int64: + gems_assert_equal(res_out, ref_out) + else: + gems_assert_close(res_out, ref_out, dtype) + + # TODO: failed at large size, eg. (65536 * 2048,) @pytest.mark.floor_divide @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @@ -340,7 +403,7 @@ def test_accuracy_floor_div_float(shape, dtype): with flag_gems.use_gems(): res_out = torch.div(inp1, inp2, rounding_mode="floor") - gems_assert_equal(res_out, ref_out) + gems_assert_equal(res_out, ref_out, equal_nan=True) @pytest.mark.floor_divide @@ -386,6 +449,26 @@ def test_accuracy_floor_div_int(shape, dtype): gems_assert_equal(res_out, ref_out) +@pytest.mark.floor_divide +@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) +def test_accuracy_floor_divide_scalar_scalar(dtype): + if dtype == torch.float32: + inp1 = float(np.float32(random.random() + 0.01)) + inp2 = float(np.float32(random.random() + 0.01)) + else: + inp1 = random.randint(1, 100) + inp2 = random.randint(1, 100) + + ref_out = torch.floor_divide(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.floor_divide(inp1, inp2) + + if dtype == torch.int64: + gems_assert_equal(res_out, ref_out) + else: + gems_assert_close(res_out, ref_out, dtype) + + @pytest.mark.remainder @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", INT_DTYPES) @@ -651,6 +734,26 @@ def test_accuracy_mul_scalar_tensor(shape, scalar, dtype): gems_assert_close(res_out, ref_out, dtype) +@pytest.mark.mul +@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) +def test_accuracy_mul_scalar_scalar(dtype): + if dtype == torch.float32: + inp1 = float(np.float32(random.random())) + inp2 = float(np.float32(random.random())) + else: + inp1 = random.randint(0, 100) + inp2 = random.randint(0, 100) + + ref_out = torch.mul(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.mul(inp1, inp2) + + if dtype == torch.int64: + gems_assert_equal(res_out, ref_out) + else: + gems_assert_close(res_out, ref_out, dtype) + + @pytest.mark.ne @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @@ -846,6 +949,28 @@ def test_accuracy_sub_scalar_tensor(shape, scalar, alpha, dtype): gems_assert_close(res_out, ref_out, dtype) +@pytest.mark.sub +@pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) +def test_accuracy_sub_scalar_scalar(dtype): + if dtype == torch.float32: + inp1 = float(np.float32(random.random())) + inp2 = float(np.float32(random.random())) + alpha = float(np.float32(random.random())) + else: + inp1 = random.randint(0, 100) + inp2 = random.randint(0, 100) + alpha = random.randint(0, 100) + + ref_out = torch.sub(inp1, inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.sub(inp1, inp2, alpha=alpha) + + if dtype == torch.int64: + gems_assert_equal(res_out, ref_out) + else: + gems_assert_close(res_out, ref_out, dtype) + + @pytest.mark.where @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py index f60fecdb..dc70f123 100644 --- a/tests/test_special_ops.py +++ b/tests/test_special_ops.py @@ -544,7 +544,7 @@ def test_accuracy_stack(shape, dim, dtype): ) for s in shape ] - ref_inp = [to_reference(_, True) for _ in inp] + ref_inp = [to_reference(_) for _ in inp] ref_out = torch.stack(ref_inp, dim) with flag_gems.use_gems(): @@ -646,7 +646,7 @@ def test_accuracy_cat(shape, dim, dtype): ) for s in shape ] - ref_inp = [to_reference(_, True) for _ in inp] + ref_inp = [to_reference(_) for _ in inp] ref_out = torch.cat(ref_inp, dim) with flag_gems.use_gems():