Move Dynamo docs back to core (pytorch#89769)
With contributions from @svekars and @malfet

Waiting for doc build job to complete
Pull Request resolved: pytorch#89769
Approved by: https://github.com/soumith, https://github.com/malfet
msaroufim authored and pytorchmergebot committed Nov 29, 2022
1 parent 2b52267 commit 9048cf1
Showing 12 changed files with 2,168 additions and 0 deletions.
Binary file added docs/source/_static/img/dynamo/TorchDynamo.png
Binary file added docs/source/_static/img/dynamo/td_stack.png
154 changes: 154 additions & 0 deletions docs/source/dynamo/custom-backends.rst
@@ -0,0 +1,154 @@
Custom Backends
===============

Debugging Backend
-----------------

Suppose you want to better understand what is going on during a
compilation. You can create a custom compiler, which we'll refer to as a
backend, that pretty prints the FX ``GraphModule`` extracted from
Dynamo's bytecode analysis and returns a ``forward()`` callable.

.. code-block:: python

    from typing import List

    import torch
    import torch._dynamo as dynamo

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @dynamo.optimize(my_compiler)
    def fn(x, y):
        a = torch.cos(x)
        b = torch.sin(y)
        return a + b

    fn(torch.randn(10), torch.randn(10))

Running the above example produces the following output:

::

my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ---------- --------
placeholder x x () {}
placeholder y y () {}
call_function cos <built-in method cos of type object at 0x7f1a894649a8> (x,) {}
call_function sin <built-in method sin of type object at 0x7f1a894649a8> (y,) {}
call_function add <built-in function add> (cos, sin) {}
output output output ((add,),) {}

This works for ``torch.nn.Module`` as well, as shown below:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.cos(x))

    mod = MockModule()
    optimized_mod = dynamo.optimize(my_compiler)(mod)
    optimized_mod(torch.randn(10))

Let’s take a look at one more example with control flow.

.. code-block:: python

    from typing import List

    import torch
    import torch._dynamo as dynamo

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @dynamo.optimize(my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

Running this example produces the following output:

::

my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f8d259298a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}

my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}

my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}

The order of the last two graphs is nondeterministic: it depends on
which branch the just-in-time compiler encounters first.

Speedy Backend
--------------

Integrating a custom backend that offers superior performance is also
easy. As a real example, we'll integrate
`optimize_for_inference <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`__:

.. code-block:: python

    from typing import List

    import torch

    def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        scripted = torch.jit.trace(gm, example_inputs)
        return torch.jit.optimize_for_inference(scripted)

And then you should be able to optimize any existing code with:

.. code-block:: python

    @dynamo.optimize(optimize_for_inference_compiler)
    def code_to_accelerate():
        ...

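To sanity-check the speedup, here is a minimal sketch that times the
eager function against the compiled one with the standard library's
``timeit`` (``fn`` and the tensor sizes are illustrative, not part of
any API):

.. code-block:: python

    import timeit
    from typing import List

    import torch
    import torch._dynamo as dynamo

    def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        scripted = torch.jit.trace(gm, example_inputs)
        return torch.jit.optimize_for_inference(scripted)

    def fn(x):
        return torch.relu(x @ x)

    opt_fn = dynamo.optimize(optimize_for_inference_compiler)(fn)

    x = torch.randn(128, 128)
    opt_fn(x)  # warm up so compilation cost is not measured
    print("eager:   ", timeit.timeit(lambda: fn(x), number=100))
    print("compiled:", timeit.timeit(lambda: opt_fn(x), number=100))
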
Composable Backends
-------------------

TorchDynamo includes many backends, which can be found in
`backends.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py>`__
or listed with ``torch._dynamo.list_backends()``. You can combine these
backends with the following code:

.. code-block:: python

    from typing import List

    import torch
    from torch._dynamo.optimizations import BACKENDS

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        trt_compiled = BACKENDS["tensorrt"](gm, example_inputs)
        if trt_compiled is not None:
            return trt_compiled
        # first backend failed, try something else...
        cudagraphs_compiled = BACKENDS["cudagraphs"](gm, example_inputs)
        if cudagraphs_compiled is not None:
            return cudagraphs_compiled
        return gm.forward
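
As a usage sketch, the composed backend plugs into ``dynamo.optimize``
like any other (this assumes a CUDA-capable setup with the TensorRT
backend installed; ``forward_pass`` is an illustrative name):

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    @dynamo.optimize(my_compiler)
    def forward_pass(x):
        return torch.relu(x) + 1

    forward_pass(torch.randn(8, device="cuda"))
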
145 changes: 145 additions & 0 deletions docs/source/dynamo/deep-dive.rst
@@ -0,0 +1,145 @@
TorchDynamo Deeper Dive
=======================
**Author**: `Jason Ansel <https://github.com/jansel>`_

What is a guard?
----------------

TorchDynamo operates just-in-time and specializes graphs based on
dynamic properties. For example, the first graph from the
``toy_example`` above has the following guards:

::

GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH

If any of those guards fail, the graph will be recaptured and
recompiled. The interesting guard type here is ``TENSOR_MATCH``, which
checks the following ``torch.Tensor`` properties:

- Python class of the tensor (tensor subclassing, etc)
- dtype
- device
- requires_grad
- dispatch_key (with thread-local includes/excludes applied)
- ndim
- sizes\* (optional)
- strides\* (optional)

For sizes/strides you can disable this specialization by setting the
following parameter:

.. code-block:: python

    torch._dynamo.config.dynamic_shapes = True

The full specialization mode allows the backend compiler to assume an
entirely static graph. Unfortunately, most backends require this.
Operators which return dynamic shapes will trigger a graph break when
not in dynamic shape mode.
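
To see a guard failure in practice, here is a minimal sketch (the
printing backend and ``f`` are illustrative) in which changing a
tensor's ``dtype`` invalidates the ``TENSOR_MATCH`` guard and triggers a
recompile:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    def noisy_compiler(gm, example_inputs):
        print("compiling a graph")
        return gm.forward

    @dynamo.optimize(noisy_compiler)
    def f(x):
        return x * 2

    f(torch.randn(4))                       # prints: compiling a graph
    f(torch.randn(4))                       # guards hold; no recompile
    f(torch.randn(4, dtype=torch.float64))  # dtype guard fails; prints again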

What is Dynamo doing?
---------------------

If you want to better understand what TorchDynamo is doing, you can set:

.. code-block:: python

    torch._dynamo.config.debug = True

which triggers useful (but spammy) printouts.

For example, the printouts for the first graph in the ``toy_example``
above are:

::

__compiled_fn_0 <eval_with_key>.1
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f9ca082f8a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}

ORIGINAL BYTECODE toy_example example.py 9
10 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)

11 18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 38

12 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)

13 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE

MODIFIED BYTECODE
9 0 LOAD_GLOBAL 3 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 24
14 LOAD_GLOBAL 4 (__resume_at_30_1)
16 LOAD_FAST 1 (b)
18 LOAD_FAST 2 (x)
20 CALL_FUNCTION 2
22 RETURN_VALUE
>> 24 LOAD_GLOBAL 5 (__resume_at_38_2)
26 LOAD_FAST 1 (b)
28 LOAD_FAST 2 (x)
30 CALL_FUNCTION 2
32 RETURN_VALUE

GUARDS:
- local 'a' TENSOR_MATCH
- local 'b' TENSOR_MATCH
- global 'torch' FUNCTION_MATCH

At the top you can see the FX graph (which we already shared above).
Next you see the original bytecode of the function, followed by the
modified bytecode generated by TorchDynamo. Finally, you see the guards
which we covered above.
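
If you want to reproduce the original bytecode listing yourself,
Python's standard ``dis`` module prints it; a minimal sketch (exact
offsets vary by Python version):

.. code-block:: python

    import dis

    import torch

    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    dis.dis(toy_example)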

In the modified bytecode, ``__compiled_fn_0`` is the return value of
``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and
``__resume_at_38_2`` are both generated continuation functions that pick
up execution after a graph break (at bytecode offsets 30 and 38). Each
of these functions takes the form:

::

__resume_at_<offset>:
... restore stack state if needed ...
JUMP_ABSOLUTE <offset> into toy_example
... original bytecode of toy_example ...

By generating these ``__resume_at`` functions, we force the remainder of
the original function to be executed in a new Python frame, which
recursively triggers TorchDynamo to restart its capture once execution
reaches that point for the first time.
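
To observe this end to end, here is a minimal sketch (the counting
backend is illustrative) that counts how many separate graphs Dynamo
hands to the backend for the ``toy_example`` above; the graph break at
the branch splits the function into two graphs on the first call:

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

    graph_count = 0

    def counting_compiler(gm, example_inputs):
        global graph_count
        graph_count += 1
        return gm.forward

    @dynamo.optimize(counting_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    toy_example(torch.ones(10), torch.ones(10))
    print(graph_count)  # 2: the code before the branch, plus the taken branch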