cudagraphs dynamo backend #80566

Closed

ezyang wants to merge 30 commits

Conversation

@ezyang (Contributor) commented on Jun 29, 2022

Stack from ghstack (oldest at bottom):

This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: pytorch/functorch#935

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

These tests demonstrate cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jun 29, 2022
These tests demonstrate cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 0e75258940714543465013512e07d877b6886e57
Pull Request resolved: #80566
@facebook-github-bot (Contributor) commented on Jun 29, 2022

✅ No Failures (0 Pending)

As of commit ac5b125 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.
Please report bugs/suggestions to the (internal) Dr. CI Users group.

@ezyang (Author) commented on Jun 29, 2022

cc @ngimel @wconstab

@ezyang changed the title from "PoC tests for dynamo cudagraphs" to "PoC more robust cudagraphs dynamo backend" on Jul 1, 2022
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
@ezyang (Author) commented on Jul 1, 2022

cc @SherlockNoMad

This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 1, 2022
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 20158e76b4374cd5efb82a59bc72621b52e1f75c
Pull Request resolved: #80566
            return tree_map(cloner, self.static_outputs)
        else:
            # warmup
Contributor

what does warmup actually do for graph recording?

our docs suggest to warm up for "a few iterations" (in their example they use 3) but idk what for.
https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/#api-example

Contributor (Author)

It is for handling internal libraries which change what cuda kernels they call based on a cache. A simple example is cudnn benchmarking: the first run will trigger a bunch of benchmarking cuda kernels which you definitely don't want to record. According to the doc nccl needs something like 11 warmup iterations LOL

Collaborator

I think it's ddp, not nccl, but I'm not sure. Cudnn benchmarking throws an error if someone tries to capture it.

Contributor (Author)

oops you're right
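
To make the warmup discussion above concrete, here is a minimal warmup-then-capture sketch along the lines of the pattern in the linked blog post (the model, iteration count, and side-stream usage are illustrative, not taken from this PR):

import torch

# Hypothetical model/input pair, for illustration only.
model = torch.nn.Linear(64, 64).cuda()
static_input = torch.randn(8, 64, device="cuda")

# Warm up on a side stream so lazy initialization (benchmarking kernels,
# workspace allocations, etc.) happens before capture begins.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: kernels are recorded into the graph, not executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay: copy new data into the static input buffer, then replay.
static_input.copy_(torch.randn(8, 64, device="cuda"))
g.replay()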

This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 6, 2022
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 19a440b3dd9c5cb31f951fc2e575b62b6bb1e053
Pull Request resolved: #80566
@ezyang (Author) commented on Jul 6, 2022

Substantial refactor of many parts of the code; I also found a bunch of bugs which I will be filing issues for tomorrow.

@ezyang (Author) commented on Jul 6, 2022

The screamer bug is that somehow cuda graphs, during recording, just loses input mutations. Crazy.

@jansel (Contributor) commented on Jul 6, 2022

> The screamer bug is that somehow cuda graphs, during recording, just loses input mutations. Crazy.

Don't you copy all the inputs to invoke the cudagraphs? I'd expect the static_inputs to get mutated, then you need to copy those mutations to the real inputs.
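
For context, a sketch of the copy-in/replay/copy-out convention jansel is describing (the names static_inputs, static_outputs, and mutated_idxs are illustrative placeholders, not this PR's actual code):

def run_cudagraph(graph, static_inputs, static_outputs, mutated_idxs, real_inputs):
    # Copy the caller's tensors into the static buffers the graph was captured on.
    for static, real in zip(static_inputs, real_inputs):
        static.copy_(real)
    # Replay the recorded kernels; they only ever read/write the static buffers.
    graph.replay()
    # Propagate input mutations back to the caller's tensors.
    for i in mutated_idxs:
        real_inputs[i].copy_(static_inputs[i])
    # Return clones so later replays don't clobber values the caller holds onto.
    return [out.clone() for out in static_outputs]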

@ngimel (Collaborator) commented on Jul 6, 2022

I don't think that's a bug; graph recording doesn't execute the kernels, it only records the succession of launches.

In [1]: import torch

In [2]: a=torch.ones(4, device="cuda")

In [3]: graph = torch.cuda.CUDAGraph()

In [4]: with torch.cuda.graph(graph):
   ...:     for _ in range(4):
   ...:         a += 1
   ...: 

In [5]: a
Out[5]: tensor([1., 1., 1., 1.], device='cuda:0') # expected, nothing happened to a

In [7]: graph.replay()

In [8]: a
Out[8]: tensor([5., 5., 5., 5.], device='cuda:0') # expected, now a is modified

In [9]: graph = torch.cuda.CUDAGraph()
In [10]: with torch.cuda.graph(graph):
    ...:     for _ in range(4):
    ...:         b=2*a
    ...: 

In [11]: b
Out[11]: tensor([0., 0., 0., 0.], device='cuda:0') # expected, b is uninitialized
In [12]: graph.replay()

In [13]: b
Out[13]: tensor([10., 10., 10., 10.], device='cuda:0') #expected, b is computed

@ezyang (Author) commented on Jul 6, 2022

@ngimel is right. Need to update the docs to make this more clear lol

This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
@ezyang ezyang requested a review from Chillee July 6, 2022 14:56
        super().run(*args)
        return self.mutated_inputs

class ProxyTensorInterpreter(torch.fx.Interpreter):
Contributor (Author)

TODO: Move this into a shared area. @Chillee I wanted you to look over this

        mutated_inputs = FindInputMutations(submod)(*map(unwrap_elem, args))
        # smh the module didn't get transferred wut
        self.new_module.add_submodule(target, CudaGraphModule(submod, mutated_inputs))
        return wrap_output(out, torch.fx.Proxy(self.new_graph.call_module(target, tree_map(unwrap_proxy_node, args), tree_map(unwrap_proxy_node, kwargs)), self.tracer))
Contributor (Author)

This is kind of blegh

with FakeTensorMode.push() as mode:
    t.run(*map(mode.from_tensor, inputs))
model = t.new_module
model.recompile()
Contributor (Author)

This is a lot wordier than it should be

        # TODO: this is not compositional
        with FakeTensorMode.push() as mode:
            fake_args = [mode.from_tensor(a) for a in args]
            return super().run(*fake_args)
Contributor (Author)

This should be moved somewhere else

Contributor (Author)

This pass isn't sound: because we save fake tensors directly on nodes, if a graph has a metadata-changing operation like resize_, it will mutate the fake tensor.
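
A small illustration of the aliasing problem described above, using plain tensors for simplicity (the same thing happens when a fake tensor stored as node metadata is later resized):

import torch

t = torch.empty(4)
saved_on_node = t           # stored by reference, e.g. as node.meta["val"]
print(saved_on_node.shape)  # torch.Size([4])

t.resize_(8)                # a metadata-changing op later in the graph
print(saved_on_node.shape)  # torch.Size([8]): the saved "snapshot" changed too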

This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: pytorch/functorch#935

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 6, 2022
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 45b7ef1cd924b460857859435d8c188a3774d821
Pull Request resolved: #80566
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: pytorch/functorch#935

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: pytorch/functorch#935

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 21, 2022
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 27084075f7c6d60ce0762ed7f3f94b92a7a6e9bd
Pull Request resolved: #80566
@ezyang (Author) commented on Jul 21, 2022

@pytorchbot merge -g

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here

@pytorchmergebot (Collaborator)

Merge failed due to Refusing to merge as mandatory check(s) Lint failed for rule superuser
Raised by https://github.com/pytorch/pytorch/actions/runs/2714874815

This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: pytorch/functorch#935

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 22, 2022
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: aaac41e05840ee08ff955cf95bc813b7f9f9e8df
Pull Request resolved: #80566
@ezyang (Author) commented on Jul 22, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here

@github-actions (Contributor)

Hey @ezyang.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

def model(x, y):
    return (x + y) * y

with torchdynamo.optimize(aot_autograd_cudagraphs):
Member

n00b q: Is the optimization unrolled outside of the with scope? If you called torchdynamo.optimize() in a loop would the result be the same as calling it once?

Contributor (Author)

this is a little subtle

First, the context manager is misleading. It doesn't actually turn on optimization for the inside of the manager. Optimization only turns on when you hit a new frame (e.g., do a function call).

With that out of the way, what if you have the optimization inside or outside of a loop? It will depend. If the loop successfully unrolls, then you will get a compiled graph outside the loop that has the unrolled graph. But let's say there's some reason we can't compile the outer frame. Then we will compile the inner function, and the two applications are equivalent.
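
A hedged sketch of what that means in practice, using the torchdynamo API of this era (my_compiler and the inputs are placeholders, not this PR's backend):

import torch
import torchdynamo

def my_compiler(gm, example_inputs):
    # Called once per captured frame; return any callable with the same signature.
    return gm.forward

def model(x, y):
    return (x + y) * y

x = torch.randn(4)
y = torch.randn(4)

with torchdynamo.optimize(my_compiler):
    # Code written directly in the with-body is not itself compiled; dynamo only
    # kicks in when a new frame starts, i.e. on the function call below.
    out = model(x, y)

# Entering the context manager again (e.g. inside a loop) reuses the cached
# compiled frame for model rather than recompiling it each time.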

loss = model(x, y).sum()
loss.backward()

@patch("functorch._src.config.use_functionalize", True)
@msaroufim (Member) commented on Jul 22, 2022

noob q: I've seen the terminology functionalize a few times and my understanding is you remove stateful operations to ship to compilers that can't represent aliasing. Is that the majority of compilers? Some compilers we really care about?

EDIT: nvm saw Jason's comment about how CUDA graphs don't support input mutation

Contributor (Author)

Functionalize gives you a graph that doesn't have mutating operations in it. CUDA graphs actually doesn't want functionalization, but we need it because there are passes like the partitioner we use here which are unsound in the presence of mutation.
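
Roughly, functionalization rewrites in-place ops into out-of-place ops, with any input mutation re-applied at the graph boundary; a schematic before/after (illustrative, not the exact aten IR):

# Before functionalization: the graph contains a mutation.
def f(x):
    x.add_(1)               # in-place; visible to whoever aliases x
    return x * 2

# After functionalization: same semantics, but no mutating ops inside the graph.
def f_functional(x):
    x_updated = x.add(1)    # out-of-place
    out = x_updated * 2
    # The caller-visible mutation (x.copy_(x_updated)) is replayed by a runtime
    # wrapper at the boundary, so downstream passes see a purely functional graph.
    return out, x_updated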

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
Member

Is this comment the general PT 2.0 strategy where you do graph surgery by swapping graph nodes for compiled code?

Contributor (Author)

Yes, this is a general pattern for doing compilation on an FX graph directly.
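
A minimal sketch of that pattern on an fx.GraphModule (compile_submodule is a placeholder for whatever produces the optimized replacement):

import torch
import torch.fx as fx

def compile_submodule(submod: torch.nn.Module) -> torch.nn.Module:
    # Placeholder: return a wrapper that records/replays a CUDA graph,
    # runs generated code, etc.
    return submod

def swap_in_compiled(gm: fx.GraphModule) -> fx.GraphModule:
    # Graph surgery: the graph's call_module nodes stay put; only the modules
    # they point at are swapped for compiled replacements.
    for node in gm.graph.nodes:
        if node.op == "call_module":
            submod = gm.get_submodule(node.target)
            gm.delete_submodule(node.target)
            gm.add_submodule(node.target, compile_submodule(submod))
    gm.recompile()
    return gm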

    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
Member

is graph capture NVIDIA's term for what we call tracing?

Contributor (Author)

sort of. But our tracing (proxy tensor) operates at a different level than cuda graph capture. Our tracing captures aten ops; cuda graph captures cuda kernel launches.
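
A side-by-side sketch of the two capture levels (make_fx is used here just as one example of proxy-tensor tracing; the import path varies by PyTorch version):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.relu(x) + 1

x = torch.randn(8, device="cuda")

# Proxy-tensor tracing records aten ops into an FX graph.
gm = make_fx(f)(x)
gm.graph.print_tabular()    # rows like aten.relu, aten.add

# CUDA graph capture records the underlying kernel launches for later replay.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = f(x)
g.replay()                  # re-launches the recorded kernels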

@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1228/head branch July 25, 2022 14:18
facebook-github-bot pushed a commit that referenced this pull request Jul 26, 2022
Summary:
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: pytorch/functorch#935

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: #80566
Approved by: https://github.com/ngimel, https://github.com/wconstab

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/3c2c2cc9474b46238bf2f517762ab853b84bbf4d

Reviewed By: osalpekar

Differential Revision: D38114100

Pulled By: ezyang

fbshipit-source-id: 3fb056e599cef605792cea9d794de701c596a9d8
ezyang added a commit to ezyang/torchdynamo that referenced this pull request Aug 9, 2022
Previously it was in pytorch/pytorch but it depends on torchdynamo
code more closely, so this seems like the logical place.

Previously at pytorch/pytorch#80566

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ezyang added a commit to pytorch/torchdynamo that referenced this pull request Aug 10, 2022
* Move aot_cudagraphs backend here

Previously it was in pytorch/pytorch but it depends on torchdynamo
code more closely, so this seems like the logical place.

Previously at pytorch/pytorch#80566

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
8 participants