[inductor] Rewrite Triton templates + epilogue fusion (retry) (pytorch#91575)

This reverts commit 94262ef to reland pytorch#91105 / pytorch#90738.

Fixes pytorch/torchdynamo#2015

Pull Request resolved: pytorch#91575
Approved by: https://github.com/ngimel
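
For orientation, here is a minimal sketch of the path this change exercises, adapted from the new test/inductor/test_select_algorithm.py added below. It is illustrative only and assumes a CUDA device: with `max_autotune` enabled, Inductor autotunes between the ATen kernel and the rewritten Triton matmul templates, and `epilogue_fusion` is meant to let the trailing relu fuse into the selected template (per the comments in the new tests).

```python
import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F

# Illustrative settings; the new test suite patches these same flags.
inductor_config.max_autotune = True      # autotune ATen vs. Triton template choices
inductor_config.epilogue_fusion = True   # let the relu epilogue fuse into the chosen template

@torch.compile
def linear_relu(x, w, b):
    return F.relu(F.linear(x, w, b))

out = linear_relu(
    torch.randn(64, 32, device="cuda"),
    torch.randn(16, 32, device="cuda"),
    torch.randn(16, device="cuda"),
)
```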
jansel authored and pytorchmergebot committed Jan 11, 2023
1 parent 6912f7c commit 7c1c239
Showing 34 changed files with 1,584 additions and 1,956 deletions; only six of the changed files are reproduced below.
2 changes: 1 addition & 1 deletion benchmarks/dynamo/microbenchmarks/microbench.py
@@ -139,7 +139,7 @@ def main():
     if args.verbose:
         torch._inductor.config.debug = True

-    torch._inductor.config.triton.autotune = True
+    torch._inductor.config.triton.autotune_pointwise = True

     rows = []
     for model in (MicroBenchmarks.sum,):
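
The only substantive change in this file is the flag rename; a one-line sketch of the new spelling (same config module, shown here outside the diff):

```python
import torch._inductor.config as inductor_config

# Pointwise-kernel autotuning is now controlled by `triton.autotune_pointwise`
# (previously `triton.autotune`).
inductor_config.triton.autotune_pointwise = True
```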
6 changes: 6 additions & 0 deletions benchmarks/dynamo/torchbench.py
@@ -118,6 +118,10 @@ def setup_torchbench_cwd():
     "tacotron2",
 }

+REQUIRE_HIGHER_FP16_TOLERANCE = {
+    "drq",
+}
+
 REQUIRE_COSINE_TOLERACE = {
     # Just keeping it here even though its empty, if we need this in future.
 }
@@ -335,6 +339,8 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
         cosine = self.args.cosine
         # Increase the tolerance for torch allclose
         if self.args.float16 or self.args.amp:
+            if name in REQUIRE_HIGHER_FP16_TOLERANCE:
+                return 1e-2, cosine
             return 1e-3, cosine
         if is_training and current_device == "cuda":
             tolerance = 1e-3
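
For clarity, a standalone sketch (not torchbench itself) of the tolerance selection added above: under float16/amp, models listed in REQUIRE_HIGHER_FP16_TOLERANCE get the looser 1e-2 threshold, while everything else keeps the default 1e-3 (the name "some_other_model" below is just a placeholder).

```python
# Self-contained mirror of the new fp16/amp tolerance branch.
REQUIRE_HIGHER_FP16_TOLERANCE = {"drq"}

def fp16_tolerance(name):
    if name in REQUIRE_HIGHER_FP16_TOLERANCE:
        return 1e-2  # looser tolerance for models known to drift under fp16/amp
    return 1e-3      # default fp16/amp tolerance

assert fp16_tolerance("drq") == 1e-2
assert fp16_tolerance("some_other_model") == 1e-3
```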
1 change: 0 additions & 1 deletion setup.py
@@ -1156,7 +1156,6 @@ def main():
         'include/THH/generic/*.h',
         'include/sleef.h',
         "_inductor/codegen/*.h",
-        "_inductor/codegen/*.j2",
         'share/cmake/ATen/*.cmake',
         'share/cmake/Caffe2/*.cmake',
         'share/cmake/Caffe2/public/*.cmake',
146 changes: 146 additions & 0 deletions test/inductor/test_select_algorithm.py
@@ -0,0 +1,146 @@
# Owner(s): ["module: inductor"]
import functools
import logging
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA

torch.backends.cuda.matmul.allow_tf32 = False


def patches(fn):
def skip_cache(self, key, generate):
return generate()

for patcher in [
patch.object(dynamo_config, "log_level", logging.INFO),
patch.object(dynamo_config, "verbose", True),
patch.object(inductor_config, "debug", True),
patch.object(inductor_config, "max_autotune", True),
patch.object(inductor_config, "epilogue_fusion", True),
patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
]:
fn = patcher(fn)

@functools.wraps(fn)
def wrapped(*args, **kwargs):
counters.clear()
torch.manual_seed(12345)
assert (
not torch.backends.cuda.matmul.allow_tf32
), "correctness testing is allergic to tf32"
return fn(*args, **kwargs)

return wrapped


class TestSelectAlgorithm(TestCase):
@patches
def test_linear_relu(self):
@torch.compile
def foo(input, weight, bias):
return F.relu(F.linear(input, weight, bias))

foo(
torch.randn(64, 32, device="cuda"),
torch.randn(16, 32, device="cuda"),
torch.randn(16, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
# It would be nice to assert this got fused into a single kernel, but that
# only happens if we select a triton template (and not aten).

@patches
def test_addmm(self):
@torch.compile
def foo(input, weight, bias):
return torch.addmm(bias, input, weight)

foo(
torch.randn(20, 33, device="cuda"),
torch.randn(33, 16, device="cuda"),
torch.randn(20, 16, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_mm(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)

foo(
torch.randn(8, 32, device="cuda"),
torch.randn(32, 8, device="cuda"),
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_mm_skip(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)

foo(
torch.randn(8, 32, device="cuda", dtype=torch.float64),
torch.randn(32, 8, device="cuda", dtype=torch.float64),
)
# float64 not supported by tl.dot()
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

@patches
def test_bmm(self):
@torch.compile
def foo(a, b):
return torch.bmm(a, b)

foo(
torch.randn(2, 8, 32, device="cuda"),
torch.randn(2, 32, 8, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_mm_not_even_k(self):
@torch.compile
def foo(a, b):
return torch.mm(a, b)

foo(
torch.randn(11, 22, device="cuda"),
torch.randn(22, 33, device="cuda"),
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_baddbmm(self):
@torch.compile
def foo(a, b, c):
return torch.baddbmm(c, a, b)

foo(
torch.randn(2, 8, 32, device="cuda"),
torch.randn(2, 32, 8, device="cuda"),
torch.randn(2, 1, 8, device="cuda"),
)
# Autotuning checks correctness of each version
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

if IS_LINUX and HAS_CUDA and is_big_gpu(0):
run_tests()
92 changes: 3 additions & 89 deletions test/inductor/test_torchinductor.py
@@ -79,7 +79,7 @@
     unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
 )

-torch._inductor.config.triton.autotune = False  # too slow
+torch._inductor.config.triton.autotune_pointwise = False  # too slow


 # For OneDNN bf16 path, OneDNN requires the cpu has intel avx512 with avx512bw,
@@ -2505,76 +2505,6 @@ def fn(x, y):
         self.assertEqual(a.stride(), c.stride())
         self.assertEqual(c.stride()[2], 1)

-    @requires_cuda()
-    @patch.object(config.triton, "convolution", "triton")
-    @patch.object(config.triton, "dense_indexing", "True")
-    def test_triton_conv(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def triton_conv(
-            x,
-            w,
-            bias,
-            stride,
-            padding,
-            dilation,
-            groups,
-        ):
-            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-            return y
-
-        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
-        dtype = torch.float32
-        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
-        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
-        bias = torch.randn((32), dtype=dtype, device=self.device)
-
-        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
-        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
-
-    @requires_cuda()
-    @patch.object(config.triton, "convolution", "autotune")
-    @patch.object(config.triton, "dense_indexing", "True")
-    def test_conv_autotune(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def triton_conv(
-            x,
-            w,
-            bias,
-            stride,
-            padding,
-            dilation,
-            groups,
-        ):
-            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-            return y
-
-        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
-        dtype = torch.float32
-        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
-        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
-        bias = torch.randn((32), dtype=dtype, device=self.device)
-
-        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
-        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
-
-    @patch.object(config.triton, "mm", "triton")
-    def test_triton_mm2(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def fn(x, y):
-            return torch.relu(torch.mm(x, y))
-
-        N = 1024
-        a = torch.randn([N, N], device=self.device, dtype=torch.float32)
-        b = torch.randn([N, N], device=self.device, dtype=torch.float32)
-        c1 = torch.relu(torch.mm(a, b))
-        torch._inductor.metrics.reset()
-        c = fn(a, b)
-        assert torch.allclose(c1, c, atol=1e-3, rtol=1e-3)
-        if self.device == "cuda":
-            assert torch._inductor.metrics.generated_kernel_count == 1
-
     def test_std(self):
         def fn(x):
             return (
@@ -4560,12 +4490,6 @@ def fn(a, b):
         )
         expected_kernel = 0
         # codegen mm kernel from template
-        if config.triton.mm != "aten" and self.device == "cuda":
-            expected_kernel = 1
-        if config.triton.mm == "autotune":
-            self.assertLessEqual(
-                torch._inductor.metrics.generated_kernel_count, expected_kernel
-            )
         self.assertEqual(
             torch._inductor.metrics.generated_kernel_count, expected_kernel
         )
@@ -4641,15 +4565,6 @@ def run(x):
         result.sum().backward()

         expected_kernel = 4
-        if config.triton.mm != "aten" and self.device == "cuda":
-            # fwd: 2 * (mm+dropout) kernels = 2 kernels
-            # bwd: dropout + (mm) + 2 * (mm+dropout) kernels = 4 kernels
-            # expect 2 + 4 = 6 kernels
-            expected_kernel = 6
-        if config.triton.mm == "autotune":
-            self.assertLessEqual(
-                torch._inductor.metrics.generated_kernel_count, expected_kernel
-            )
         self.assertEqual(
             torch._inductor.metrics.generated_kernel_count, expected_kernel
         )
@@ -4979,7 +4894,6 @@ def fn(x, y):
         inputs = (inputs[1], inputs[0])
         self.assertTrue(same(opt(*inputs), fn(*inputs)))

-    @patch.object(config.triton, "mm", "aten")
     def test_list_clearing(self):

         if self.device == "cpu":
@@ -5685,7 +5599,7 @@ def forward(self, view, reshape_2):
         res = opt_mod(*args)
         self.assertTrue(same(ref, res))

-    @patch.object(config.triton, "autotune", True)
+    @patch.object(config.triton, "autotune_pointwise", True)
     def test_inplace_add_alpha_autotune(self):
         def fn(x, y):
             aten.add_.Tensor(x, y, alpha=0.55)
@@ -5703,7 +5617,7 @@ def fn(x, y):
         fn_compiled([x3, y])
         assert same(x2, x3)

-    @patch.object(config.triton, "autotune", True)
+    @patch.object(config.triton, "autotune_pointwise", True)
     def test_inplace_buffer_autotune(self):
         def foo(x, y, z):
             a = x @ y
8 changes: 6 additions & 2 deletions torch/_dynamo/testing.py
@@ -242,8 +242,12 @@ def _fn(*args, **kwargs):
     return _fn


-def rand_strided(size, stride, dtype=torch.float32, device="cpu"):
-    needed_size = sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1
+def rand_strided(size, stride, dtype=torch.float32, device="cpu", extra_size=0):
+    needed_size = (
+        sum((shape - 1) * stride for shape, stride in zip(size, stride))
+        + 1
+        + extra_size
+    )
     if dtype.is_floating_point:
         buffer = torch.randn(needed_size, dtype=dtype, device=device)
     else:
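
A quick standalone check of the updated size arithmetic in rand_strided: for size=(2, 3) and stride=(4, 1), the farthest reachable element sits at offset (2-1)*4 + (3-1)*1 = 6, so 7 elements are needed, plus whatever extra_size of trailing slack the caller requests. The helper below only mirrors that computation; it is a sketch, not the library function itself.

```python
# Mirrors the needed_size computation in the updated rand_strided helper.
def needed_size(size, stride, extra_size=0):
    return sum((s - 1) * st for s, st in zip(size, stride)) + 1 + extra_size

assert needed_size((2, 3), (4, 1)) == 7                  # (2-1)*4 + (3-1)*1 + 1
assert needed_size((2, 3), (4, 1), extra_size=5) == 12   # same buffer, plus 5 elements of slack
```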