add equal_nan=False for assert_equal (#225)
* add equal_nan=False for assert_equal, which is needed when testing floating-point arrays for exact equality

1. fix topk and unique to return int64 indices;
2. avoid upcasting to float64 in test cases for operators that involve no real computation;
3. fix floor divide: use the same logic as CPython, numpy, and torch;
4. return a tensor when all operands are Python scalars;
5. add scalar_scalar tests for binary ops; note that only float32 & int64 are tested, since those are the dtypes supported in that case.
iclementine authored Oct 8, 2024
1 parent 0534d4d commit e51b0b1
Showing 10 changed files with 186 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/flag_gems/ops/add.py
@@ -37,4 +37,4 @@ def add(A, B, *, alpha=1):
elif isinstance(B, torch.Tensor):
return add_func_scalar_tensor(A, B, alpha)
else:
- return A + B * alpha
+ return torch.tensor(A + B * alpha)
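
An illustration of change 4 from the commit message (not part of the diff): eager PyTorch already returns a 0-d tensor when a binary op is called with two Python scalars, which is what the new scalar-scalar tests below compare against, so the fallback branch now wraps its result in torch.tensor to match. A minimal sketch of the expected behaviour:

import torch

ref = torch.add(2, 3, alpha=4)   # tensor(14), a 0-d int64 tensor in eager mode
res = torch.tensor(2 + 3 * 4)    # what the scalar-scalar fallback now returns
assert torch.equal(ref, res)     # both carry the same int64 value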
52 changes: 42 additions & 10 deletions src/flag_gems/ops/div.py
@@ -7,12 +7,12 @@
from ..utils import pointwise_dynamic

try:
- from triton.language.extra.cuda.libdevice import div_rd, div_rz, trunc
+ from triton.language.extra.cuda.libdevice import div_rn, div_rz, fmod, trunc
except ImportError:
try:
- from triton.language.math import div_rd, div_rz, trunc
+ from triton.language.math import div_rn, div_rz, fmod, trunc
except ImportError:
- from triton.language.libdevice import div_rd, div_rz, trunc
+ from triton.language.libdevice import div_rn, div_rz, fmod, trunc


@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@@ -43,7 +43,7 @@ def true_divide(A, B):
return true_div_func_scalar_tensor(A, B)
else:
# Both scalar
- return A / B
+ return torch.tensor(A / B)


@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@@ -74,7 +74,7 @@ def trunc_divide(A, B):
return trunc_div_func_scalar_tensor(A, B)
else:
# Both scalar
- return A / B
+ return torch.tensor(A / B)


@triton.jit
@@ -98,13 +98,45 @@ def _int_floordiv(x, y):
return tl.where(c1 & c2, x // y - 1, x // y)


# TO be consistent with python, numpy and torch, we have to implement it in the
# following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
# NOTE: fmod's sign is the same as the dividend
remainder = fmod(x, y)
imperfect = remainder != 0.0
different_sign = (x < 0) ^ (y < 0)

# NOTE: we have to use div_rn explicitly here
q = div_rn(x - remainder, y)
q = tl.where(imperfect & different_sign, q - 1, q)

floor_q = tl.math.floor(q)
c = q - floor_q > 0.5
floor_q = tl.where(c, floor_q + 1.0, floor_q)

q_is_zeros = q == 0.0
floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

is_div_by_zero = y == 0.0
float_division = x / y
out = tl.where(is_div_by_zero, float_division, floor_q)
return out
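
For reference (illustration, not part of the diff), a plain-Python sketch of the CPython-style algorithm the kernel above mirrors, checked against Python's own // operator; it leaves out the div_rn rounding correction, the signed-zero handling, and the divide-by-zero branch that the kernel treats explicitly:

import math

def py_floordiv(x: float, y: float) -> float:
    r = math.fmod(x, y)  # remainder carries the sign of the dividend, like fmod above
    q = (x - r) / y
    if r != 0.0 and ((x < 0) != (y < 0)):
        q -= 1.0  # round toward -inf when signs differ and the division is inexact
    return float(math.floor(q))

for a, b in [(7.0, 2.0), (-7.0, 2.0), (7.0, -2.0), (-7.0, 2.5), (-7.5, 2.5)]:
    assert py_floordiv(a, b) == a // b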


@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
if x.type.scalar.is_int() & x.type.scalar.is_int():
return _int_floordiv(x, y)
else:
- return tl.math.floor(div_rd(x, y))
+ return _float_floordiv(x, y)


@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@@ -113,7 +145,7 @@ def floor_div_func_tensor_scalar(x, y):
if x.type.scalar.is_int() & x.type.scalar.is_int():
return _int_floordiv(x, y)
else:
- return tl.math.floor(div_rd(x, y))
+ return _float_floordiv(x, y)


@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@@ -122,7 +154,7 @@ def floor_div_func_scalar_tensor(x, y):
if x.type.scalar.is_int() & x.type.scalar.is_int():
return _int_floordiv(x, y)
else:
- return tl.math.floor(div_rd(x, y))
+ return _float_floordiv(x, y)


def floor_divide(A, B):
@@ -135,7 +167,7 @@ def floor_divide(A, B):
return floor_div_func_scalar_tensor(A, B)
else:
# Both scalar
- return A // B
+ return torch.tensor(A // B)


def div_mode(A, B, rounding_mode=None):
@@ -186,4 +218,4 @@ def remainder(A, B):
return rem_st(A, B)
else:
# Both scalar
- return A % B
+ return torch.tensor(A % B)
2 changes: 1 addition & 1 deletion src/flag_gems/ops/mul.py
@@ -28,4 +28,4 @@ def mul(A, B):
return mul_func_scalar(B, A)
else:
# Both scalar
- return A * B
+ return torch.tensor(A * B)
2 changes: 1 addition & 1 deletion src/flag_gems/ops/sub.py
@@ -38,4 +38,4 @@ def sub(A, B, *, alpha=1):
return sub_func_scalar_tensor(A, B, alpha)
else:
# Both scalar
- return A - B * alpha
+ return torch.tensor(A - B * alpha)
4 changes: 2 additions & 2 deletions src/flag_gems/ops/topk.py
@@ -261,12 +261,12 @@ def topk(x, k, dim=-1, largest=True, sorted=True):

stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
stage1_out_idx = torch.empty(
- batch_size * chunk_num * k, device=x.device, dtype=torch.int32
+ batch_size * chunk_num * k, device=x.device, dtype=torch.int64
)

out_shape = x.shape[:-1] + (k,)
stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
- stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int32)
+ stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

with torch.cuda.device(x.device):
topk_stage1_kernel[
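Context for the int32 -> int64 switch here and in unique.py below (illustration, not part of the diff): eager PyTorch returns int64 index tensors from topk and unique, and torch.testing.assert_close checks dtypes by default, so int32 outputs would fail the exact-equality tests.

import torch

x = torch.tensor([3.0, 1.0, 2.0])
_, topk_idx = torch.topk(x, k=2)
_, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)
print(topk_idx.dtype, inverse.dtype, counts.dtype)  # torch.int64 torch.int64 torch.int64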
14 changes: 7 additions & 7 deletions src/flag_gems/ops/unique.py
@@ -373,15 +373,15 @@ def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool):
# allocate tensor
if return_counts:
local_unique = None
- origin_idx = torch.empty_like(sorted_data, dtype=torch.int32)
+ origin_idx = torch.empty_like(sorted_data, dtype=torch.int64)
idx = torch.empty_like(origin_idx)
else:
local_unique = torch.empty_like(sorted_data)
origin_idx = None
idx = None
counts = None
tile_sum = torch.empty(
- (global_ctas_num,), dtype=torch.int32, device=sorted_data.device
+ (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
)
data_out = None
if not return_counts:
@@ -654,10 +654,10 @@ def sorted_indices_unique_flat(
# allocate tensor
ne_result = torch.empty_like(sorted_data, dtype=torch.bool)
tile_sum = torch.empty(
- (global_ctas_num,), dtype=torch.int32, device=sorted_data.device
+ (global_ctas_num,), dtype=torch.int64, device=sorted_data.device
)
data_out = torch.empty_like(sorted_data)
- inverse_indices = torch.empty_like(sorted_data, dtype=torch.int32)
+ inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
idx = None
if return_counts:
idx = torch.empty_like(inverse_indices)
@@ -722,14 +722,14 @@ def simple_unique_flat(
# allocate tensor
data_out = torch.empty_like(sorted_data)
if return_inverse:
- inverse_indices = torch.empty_like(sorted_data, dtype=torch.int32)
+ inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
else:
inverse_indices = None
if return_counts:
- idx = torch.empty_like(sorted_data, dtype=torch.int32)
+ idx = torch.empty_like(sorted_data, dtype=torch.int64)
else:
idx = None
- unique_size = torch.empty([1], dtype=torch.int32, device=sorted_data.device)
+ unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device)

# launch kernel
with torch.cuda.device(sorted_data.device.index):
4 changes: 2 additions & 2 deletions src/flag_gems/testing/__init__.py
@@ -18,5 +18,5 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)


- def assert_equal(res, ref):
-     assert torch.equal(res, ref)
+ def assert_equal(res, ref, equal_nan=False):
+     torch.testing.assert_close(res, ref, atol=0, rtol=0, equal_nan=equal_nan)
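
The motivation for rewriting assert_equal (illustration, not part of the diff): torch.equal treats NaN as unequal to itself, while assert_close with zero tolerances still enforces exact equality elsewhere and can opt into NaN matching through equal_nan.

import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([1.0, float("nan")])
print(torch.equal(a, b))  # False: NaN != NaN
torch.testing.assert_close(a, b, atol=0, rtol=0, equal_nan=True)  # passes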
4 changes: 2 additions & 2 deletions tests/accuracy_utils.py
@@ -96,9 +96,9 @@ def gems_assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
)


- def gems_assert_equal(res, ref):
+ def gems_assert_equal(res, ref, equal_nan=False):
res = to_cpu(res, ref)
- flag_gems.testing.assert_equal(res, ref)
+ flag_gems.testing.assert_equal(res, ref, equal_nan=equal_nan)


def unsqueeze_tuple(t, max_len):
127 changes: 126 additions & 1 deletion tests/test_binary_pointwise_ops.py
@@ -1,6 +1,7 @@
import logging
import random

import numpy as np
import pytest
import torch

@@ -76,6 +77,28 @@ def test_accuracy_add_scalar_tensor(shape, scalar, alpha, dtype):
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.add
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
def test_accuracy_add_scalar_scalar(dtype):
if dtype == torch.float32:
inp1 = float(np.float32(random.random()))
inp2 = float(np.float32(random.random()))
alpha = float(np.float32(random.random()))
else:
inp1 = random.randint(0, 100)
inp2 = random.randint(0, 100)
alpha = random.randint(0, 100)

ref_out = torch.add(inp1, inp2, alpha=alpha)
with flag_gems.use_gems():
res_out = torch.add(inp1, inp2, alpha=alpha)

if dtype == torch.int64:
gems_assert_equal(res_out, ref_out)
else:
gems_assert_close(res_out, ref_out, dtype)
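
A note on the float(np.float32(...)) pattern used in these scalar-scalar tests (an assumption about intent, not stated in the diff): rounding the random double to the nearest float32 first means the same exactly-representable float32 value is fed to both the reference and the flag_gems path, so tight comparisons at the tested dtype stay meaningful.

import random

import numpy as np

x = random.random()                    # a Python float (double precision)
x32 = float(np.float32(x))             # nearest float32 value, back as a Python float
print(x32 == float(np.float32(x32)))   # True: the round-trip is stable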


@pytest.mark.bitwise_and
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES)
@@ -303,6 +326,26 @@ def test_accuracy_div_scalar_tensor(shape, scalar, dtype):
gems_assert_close(res_out, ref_out, dtype, equal_nan=True)


@pytest.mark.div
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
def test_accuracy_div_scalar_scalar(dtype):
if dtype == torch.float32:
inp1 = float(np.float32(random.random() + 0.01))
inp2 = float(np.float32(random.random() + 0.01))
else:
inp1 = random.randint(1, 100)
inp2 = random.randint(1, 100)

ref_out = torch.mul(inp1, inp2)
with flag_gems.use_gems():
res_out = torch.mul(inp1, inp2)

if dtype == torch.int64:
gems_assert_equal(res_out, ref_out)
else:
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.trunc_divide
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32])
@@ -326,6 +369,26 @@ def test_accuracy_trunc_div(shape, dtype):
gems_assert_close(res_out, ref_out, dtype, equal_nan=True)


@pytest.mark.trunc_divide
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
def test_accuracy_trunc_divide_scalar_scalar(dtype):
if dtype == torch.float32:
inp1 = float(np.float32(random.random() + 0.01))
inp2 = float(np.float32(random.random() + 0.01))
else:
inp1 = random.randint(1, 100)
inp2 = random.randint(1, 100)

ref_out = torch.div(inp1, inp2, rounding_mode="trunc")
with flag_gems.use_gems():
res_out = torch.div(inp1, inp2, rounding_mode="trunc")

if dtype == torch.int64:
gems_assert_equal(res_out, ref_out)
else:
gems_assert_close(res_out, ref_out, dtype)


# TODO: failed at large size, eg. (65536 * 2048,)
@pytest.mark.floor_divide
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@@ -340,7 +403,7 @@ def test_accuracy_floor_div_float(shape, dtype):
with flag_gems.use_gems():
res_out = torch.div(inp1, inp2, rounding_mode="floor")

- gems_assert_equal(res_out, ref_out)
+ gems_assert_equal(res_out, ref_out, equal_nan=True)
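
Why equal_nan=True is needed here (illustration, not part of the diff): floor division on floats propagates NaN, e.g. for 0.0 / 0.0, and an exact comparison would otherwise fail on those elements.

import torch

x = torch.tensor([0.0, 1.0])
y = torch.tensor([0.0, 0.0])
print(torch.div(x, y, rounding_mode="floor"))  # tensor([nan, inf])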


@pytest.mark.floor_divide
@@ -386,6 +449,26 @@ def test_accuracy_floor_div_int(shape, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.floor_divide
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
def test_accuracy_floor_divide_scalar_scalar(dtype):
if dtype == torch.float32:
inp1 = float(np.float32(random.random() + 0.01))
inp2 = float(np.float32(random.random() + 0.01))
else:
inp1 = random.randint(1, 100)
inp2 = random.randint(1, 100)

ref_out = torch.floor_divide(inp1, inp2)
with flag_gems.use_gems():
res_out = torch.floor_divide(inp1, inp2)

if dtype == torch.int64:
gems_assert_equal(res_out, ref_out)
else:
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.remainder
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", INT_DTYPES)
@@ -651,6 +734,26 @@ def test_accuracy_mul_scalar_tensor(shape, scalar, dtype):
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.mul
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
def test_accuracy_mul_scalar_scalar(dtype):
if dtype == torch.float32:
inp1 = float(np.float32(random.random()))
inp2 = float(np.float32(random.random()))
else:
inp1 = random.randint(0, 100)
inp2 = random.randint(0, 100)

ref_out = torch.mul(inp1, inp2)
with flag_gems.use_gems():
res_out = torch.mul(inp1, inp2)

if dtype == torch.int64:
gems_assert_equal(res_out, ref_out)
else:
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.ne
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@@ -846,6 +949,28 @@ def test_accuracy_sub_scalar_tensor(shape, scalar, alpha, dtype):
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.sub
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
def test_accuracy_sub_scalar_scalar(dtype):
if dtype == torch.float32:
inp1 = float(np.float32(random.random()))
inp2 = float(np.float32(random.random()))
alpha = float(np.float32(random.random()))
else:
inp1 = random.randint(0, 100)
inp2 = random.randint(0, 100)
alpha = random.randint(0, 100)

ref_out = torch.sub(inp1, inp2, alpha=alpha)
with flag_gems.use_gems():
res_out = torch.sub(inp1, inp2, alpha=alpha)

if dtype == torch.int64:
gems_assert_equal(res_out, ref_out)
else:
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.where
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)