diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index 73dc5ad3..3a2bebc6 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -113,7 +113,28 @@ def nonzero_args(dtype, batch, size): op_name="nonzero", torch_op=torch.nonzero, arg_func=nonzero_args, - dtypes=FLOAT_DTYPES + INT_DTYPES + [torch.bool], + dtypes=FLOAT_DTYPES, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_nonzero_int(): + def nonzero_args(dtype, batch, size): + if dtype == torch.bool: + inp = torch.randint(0, 2, [batch, size], dtype=torch.int, device="cuda").to( + torch.bool + ) + else: + inp = torch.randint(0, 2, [batch, size], dtype=dtype, device="cuda") + return (inp,) + + bench = Benchmark( + op_name="nonzero_int", + torch_op=torch.nonzero, + arg_func=nonzero_args, + dtypes=INT_DTYPES, batch=REDUCTION_BATCH, sizes=SIZES, )