Free inputs to inductor when they are no longer needed #1600

Merged
merged 10 commits into from
Oct 12, 2022
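
A minimal sketch of the calling-convention change this PR makes (a hand-written stand-in with illustrative names, not actual generated code): compiled inductor callables now take their inputs as a single list and empty that list, so the caller's references stop keeping the input tensors alive.

import torch

def compiled(args):            # boxed convention: one mutable list of inputs
    arg0, arg1 = args          # unpack into locals
    args.clear()               # the caller's list no longer holds the tensors
    out = arg0 + arg1
    del arg0, arg1             # inputs can be freed before the call returns
    return (out,)

inputs = [torch.randn(8, 4), torch.randn(8, 4)]
(result,) = compiled(inputs)   # previously: compiled(*inputs)
assert inputs == []            # the list was emptied by the callee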
70 changes: 64 additions & 6 deletions test/inductor/test_torchinductor.py
@@ -6,12 +6,14 @@
import random
import sys
import unittest
import weakref
from unittest.mock import patch

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing._internal.common_utils import TestCase as TorchTestCase
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_unflatten

@@ -207,7 +209,6 @@ def run(*ex, **kwargs):
# for graph in exp[2]:
# print("Graph", graph)
assert called, "Ran graph without calling compile_fx"

assert type(actual) == type(correct)

if reference_in_float:
@@ -3625,10 +3626,10 @@ def f(x):
)
compiled = compile_fx_inner(traced, [torch.randn(8, 4, device=self.device)])

out = compiled(torch.randn(8, 4, device=self.device))
out = compiled([torch.randn(8, 4, device=self.device)])
self.assertEqual(out[0].shape, (16, 2))

out = compiled(torch.randn(12, 4, device=self.device))
out = compiled([torch.randn(12, 4, device=self.device)])
self.assertEqual(out[0].shape, (24, 2))

@requires_cuda()
@@ -3655,6 +3656,63 @@ def fn(x, y):
)
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

@patch.object(config.triton, "mm", "aten")
def test_list_clearing(self):

if self.device == "cpu":
contexts = [contextlib.nullcontext]
else:
contexts = [
contextlib.nullcontext,
lambda: patch.object(config.triton, "cudagraphs", True),
]

for context in contexts:
with context():
inps = [
torch.rand([5, 5]).to(self.device),
torch.rand([5, 5]).to(self.device),
]
inp_refs = [weakref.ref(inp) for inp in inps]

def fn(x, y):
a = x + y
return (a @ a,)

fn_fx = make_fx(fn)(inps[0], inps[1])
fn_compiled = compile_fx_inner(fn_fx, inps)

test_self = self
matmul_seen = False

class TestRefMode(TorchDispatchMode):
Contributor: Lol, clever test.

def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}

nonlocal inps
nonlocal inp_refs
nonlocal test_self
nonlocal matmul_seen

# by matmul, inputs should be deallocated
if func is aten.mm.out:
matmul_seen = True
test_self.assertEqual(len(inps), 0)
test_self.assertIsNone(inp_refs[0]())
test_self.assertIsNone(inp_refs[1]())

return func(*args, **kwargs)

with TestRefMode():
fn_compiled(inps)

# for some reason, TorchDispatch doesn't capture the
# cuda mm call (even without cudagraphs)
if self.device == "cpu":
self.assertTrue(matmul_seen)
else:
self.assertEqual(len(inps), 0)


if HAS_CPU:

@@ -3692,7 +3750,7 @@ def fn(x, y):
fn_fx = make_fx(fn)(x1, y)
fn_compiled = compile_fx_inner(fn_fx, [x1, y])
fn(x2, y)
fn_compiled(x3, y)
fn_compiled([x3, y])
assert same(x2, x3)

def test_no_op_squeeze(self):
@@ -3783,7 +3841,7 @@ def forward(
]
mod = make_fx(forward)(*inps)
compiled = compile_fx_inner(mod, inps)
compiled(*inps)
compiled(inps)

@patch.object(config, "fallback_random", True)
def test_dtype_factory_issue(self):
@@ -3799,7 +3857,7 @@ def forward():

mod = make_fx(forward)()
compiled = compile_fx_inner(mod, ())
assert compiled()[0].device.type == "cuda"
assert compiled([])[0].device.type == "cuda"

@patch.object(config.triton, "cudagraphs", True)
def test_expanded_inputs_cudagraphs(self):
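The list-clearing test above relies on CPython reference counting; a minimal standalone sketch of the mechanism it checks (assuming no other strong references to the tensor exist):

import weakref
import torch

inps = [torch.rand(5, 5)]
ref = weakref.ref(inps[0])

inps.clear()          # drop the only remaining strong reference
assert ref() is None  # the tensor has been deallocated
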
2 changes: 1 addition & 1 deletion torchdynamo/debug_utils.py
@@ -293,7 +293,7 @@ def inductor_fails(fx_g, args, check_str=None):

try:
compile_mod = compile_fx_inner(fx_g, args)
compile_mod(*args)
compile_mod(args)
except Exception as e:
if check_str is not None and check_str not in repr(e):
return False
17 changes: 14 additions & 3 deletions torchinductor/codegen/wrapper.py
@@ -227,15 +227,20 @@ def __init__(self):
)

self.prefix.splice(
f"""
"""

async_compile.wait(globals())
del async_compile

def call({', '.join(V.graph.graph_inputs.keys())}):
def call(args):
"""
)
with self.prefix.indent():
inp_len = len(V.graph.graph_inputs.keys())
if inp_len != 0:
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
self.prefix.writeline(f"{lhs} = args")
self.prefix.writeline("args.clear()")
for name in V.graph.randomness_seeds:
self.prefix.writeline(
f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
@@ -285,6 +290,12 @@ def codegen_allocation(self, buffer):

def codegen_free(self, buffer):
name = buffer.get_name()

# can be freed but not reused
if isinstance(buffer, ir.InputBuffer):
self.writeline(f"del {name}")
return

if not self.can_reuse(buffer):
return
self.freed.add(name)
@@ -390,7 +401,7 @@ def add_fake_input(name, shape, stride, device, dtype):
)

output.writeline(
f"print_performance(lambda: call({', '.join(V.graph.graph_inputs.keys())}))"
f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
)

def define_kernel(self, name: str, kernel: str):
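Putting the wrapper.py changes together, the generated call function now roughly follows this shape. This is a hand-written stand-in with made-up buffer names and toy kernels, not real inductor output:

import torch

def call(args):
    arg0_1, arg1_1 = args
    args.clear()                # release the caller's references
    buf0 = arg0_1 + arg1_1      # stand-in for the first generated kernel
    del arg0_1                  # emitted by codegen_free for InputBuffers
    del arg1_1
    buf1 = buf0 @ buf0          # stand-in for the extern mm kernel
    del buf0
    return (buf1,)

args = [torch.rand(5, 5), torch.rand(5, 5)]
(out,) = call(args)
assert args == []
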
34 changes: 20 additions & 14 deletions torchinductor/compile_fx.py
@@ -6,7 +6,6 @@

import functorch
import torch.fx
from functorch.compile import make_boxed_compiler
from functorch.compile import min_cut_rematerialization_partition
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch
@@ -135,6 +134,9 @@ def compile_fx_inner(
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
f"graph {graph_id}",
)

# aot autograd needs to know to pass in inputs as a list
result._boxed_call = True
return result
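
For context, _boxed_call is the flag AOT Autograd checks to decide whether to call a compiled function with a single list of inputs instead of unpacked positional tensors; setting it directly replaces wrapping the compilers with make_boxed_compiler. A toy backend sketch (illustrative only, not the real compile_fx_inner):

def toy_backend(gm, example_inputs):
    def run(args):              # boxed signature: one mutable list
        outs = gm(*args)
        args.clear()            # safe: the caller no longer needs these references
        return outs

    run._boxed_call = True      # tell AOT Autograd to pass inputs as a list
    return run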


@@ -157,14 +159,15 @@ def align_inputs(model, inputs, static_input_idxs=()):
if len(check_inputs) == 0:
return model

def run(*new_inputs):
def run(new_inputs):
for i in check_inputs:
if new_inputs[i].data_ptr() % ALIGNMENT:
if isinstance(new_inputs, tuple):
new_inputs = list(new_inputs)
new_inputs[i] = clone_preserve_strides(new_inputs[i])
new_inputs = [x.to("cuda") if is_unspec_input(x) else x for x in new_inputs]
return model(*new_inputs)
new_inputs_to_cuda = [
x.to("cuda") if is_unspec_input(x) else x for x in new_inputs
]
new_inputs.clear()
return model(new_inputs_to_cuda)

return run

@@ -177,13 +180,13 @@ def cudagraphify(model, inputs, static_input_idxs=()):

compiled_fn = None

def run(*new_inputs):
def run(new_inputs):
nonlocal compiled_fn
if compiled_fn is None:
with dynamo_utils.preserve_rng_state():
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)

return compiled_fn(*new_inputs)
return compiled_fn(new_inputs)

return run

@@ -237,22 +240,23 @@ def static_input(x):
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
# copy static_inputs because it will be cleared in model
with torch.cuda.stream(stream):
model(*static_inputs)
model(list(static_inputs))
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()

# record
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = model(*static_inputs)
static_outputs = model(list(static_inputs))
if not isinstance(static_outputs, (list, tuple)):
static_outputs = (static_outputs,)

if config.size_asserts:

def run(*new_inputs):
def run(new_inputs):
assert len(static_inputs) == len(new_inputs)
for idx, (dst, src, expanded_dims) in enumerate(
zip(static_inputs, new_inputs, inps_expanded_dims)
@@ -266,6 +270,7 @@ def run(*new_inputs):
dst = index_expanded_dims(dst, expanded_dims)
src = index_expanded_dims(src, expanded_dims)
dst.copy_(src)
new_inputs.clear()
graph.replay()
return static_outputs

@@ -274,11 +279,12 @@ def run(*new_inputs):
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
]

def run(*new_inputs):
def run(new_inputs):
for idx in copy_indices:
src = index_expanded_dims(static_inputs[idx], inps_expanded_dims[idx])
dst = index_expanded_dims(new_inputs[idx], inps_expanded_dims[idx])
dst.copy_(src)
new_inputs.clear()
graph.replay()
return static_outputs

@@ -357,8 +363,8 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs):
return aot_autograd(
model_,
example_inputs_,
fw_compiler=make_boxed_compiler(fw_compiler),
bw_compiler=make_boxed_compiler(bw_compiler),
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=select_decomp_table(),
partition_fn=functools.partial(
min_cut_rematerialization_partition, compiler="inductor"
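A simplified sketch of the cudagraph replay wrapper after this change, with the expanded-dims handling and static-input indices omitted; it assumes graph is a torch.cuda.CUDAGraph already recorded against static_inputs:

def make_run(graph, static_inputs, static_outputs):
    def run(new_inputs):
        assert len(static_inputs) == len(new_inputs)
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)      # feed the recorded graph through its placeholders
        new_inputs.clear()      # inputs are no longer needed once copied
        graph.replay()
        return static_outputs

    return run
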
5 changes: 5 additions & 0 deletions torchinductor/scheduler.py
@@ -993,6 +993,11 @@ def free_buffers(self):
node = self.name_to_node[name]
if node.can_free():
V.graph.wrapper_code.codegen_free(node.node)
elif name in V.graph.graph_inputs:
storage = V.graph.graph_inputs[name].data
assert storage.is_input_buffer()
V.graph.wrapper_code.codegen_free(storage.data)

self.buffer_names_to_free.clear()

def remove_kernel_local_buffers(self):
Expand Down