The FX Intermediate Representation (FX IR) is a specialized dialect of PyTorch's FX intermediate representation. By restricting what a graph may contain, FX IR reduces the ambiguities inherent in the original representation, which makes it easier to apply compiler algorithms such as static analysis and graph rewriting.
FX IR enforces the following invariants on a `Graph`:
- Node opcodes. Each node must have one of the following opcodes:
  - `"placeholder"`, representing input `Tensor`s;
  - `"output"`, representing output `Tensor`s;
  - `"call_module"`, representing individual `Tensor`s and operations.
- I/O nodes. A `Graph` must contain exactly one `"placeholder"` node and exactly one `"output"` node:
  - the `"placeholder"` node represents the collection of model inputs;
  - the `"output"` node represents the collection of model outputs.
- `Array` nodes. A `"call_module"` node representing an individual `Tensor` has the following signature:
  - it consumes a tuple of `Tensor`s;
  - it produces a single `Tensor`.
- Non-`Array` nodes. A `"call_module"` node representing an operation has the following signature:
  - it consumes one or more `Tensor`s;
  - it produces a tuple of `Tensor`s.
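The invariants above lend themselves to a mechanical check. The following is a minimal sketch, not part of FX IR itself. It assumes only that nodes expose the `op` and `all_input_nodes` attributes of `torch.fx.Node`. It also takes a hypothetical, caller-supplied predicate `is_array`, because this section does not say how an `Array` module is recognized. The two signature invariants are approximated by counting producer nodes, under the additional assumption that the tuple an `Array` node consumes is produced by a single upstream node.

```python
from collections import Counter
from typing import Callable, Iterable, Protocol, Sequence


class NodeLike(Protocol):
    """The two torch.fx.Node attributes this sketch relies on."""

    op: str
    all_input_nodes: Sequence["NodeLike"]


ALLOWED_OPCODES = {"placeholder", "output", "call_module"}


def check_fx_ir_invariants(
    nodes: Iterable[NodeLike], is_array: Callable[[NodeLike], bool]
) -> None:
    """Raise ValueError if `nodes` breaks one of the FX IR invariants.

    `is_array` is a hypothetical, caller-supplied predicate; it is only
    called on "call_module" nodes.
    """
    nodes = list(nodes)

    # Node opcodes: only three opcodes are legal.
    for n in nodes:
        if n.op not in ALLOWED_OPCODES:
            raise ValueError(f"illegal opcode {n.op!r}")

    # I/O nodes: exactly one "placeholder" and exactly one "output".
    counts = Counter(n.op for n in nodes)
    for op in ("placeholder", "output"):
        if counts[op] != 1:
            raise ValueError(f"expected one {op!r} node, found {counts[op]}")

    # Array / non-Array signatures, approximated by counting producer nodes.
    # Assumption: the tuple an Array node consumes comes from one producer.
    for n in nodes:
        if n.op != "call_module":
            continue
        n_inputs = len(n.all_input_nodes)
        if is_array(n) and n_inputs != 1:
            raise ValueError(f"Array node has {n_inputs} producers, expected 1")
        if not is_array(n) and n_inputs < 1:
            raise ValueError("non-Array node consumes no Tensor")
```

With a `torch.fx.GraphModule` named `gm`, a call might look like `check_fx_ir_invariants(gm.graph.nodes, lambda n: isinstance(gm.get_submodule(n.target), ArrayModule))`, where `ArrayModule` is a hypothetical stand-in for whichever module class implements `Array` nodes.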
Thanks to these invariants, users of FX IR can make the following assumptions:
- Every operation is represented as a non-`Array` `"call_module"` node.
- Every `Graph` is bipartite, meaning that every edge connects a node of the array partition to a node of the operator partition:
  - the array partition contains all the `Array` `"call_module"` nodes;
  - the operator partition contains the `"placeholder"` node, the non-`Array` `"call_module"` nodes, and the `"output"` node.
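The bipartite property can be tested directly. This is a minimal sketch under stated assumptions, not an FX IR API: nodes are assumed to expose the `op` and `all_input_nodes` attributes of `torch.fx.Node`, and `is_array` is a hypothetical, caller-supplied predicate that identifies `Array` `"call_module"` nodes. A graph that satisfies the invariants should pass the check. A non-`Array` node that directly consumes the output of another non-`Array` node should fail it.

```python
from typing import Any, Callable, Iterable, List, Tuple


def partition(
    nodes: Iterable[Any], is_array: Callable[[Any], bool]
) -> Tuple[List[Any], List[Any]]:
    """Split nodes into (array partition, operator partition)."""
    array_part, operator_part = [], []
    for n in nodes:
        if n.op == "call_module" and is_array(n):
            array_part.append(n)
        else:
            # "placeholder", "output" and non-Array "call_module" nodes.
            operator_part.append(n)
    return array_part, operator_part


def is_bipartite(nodes: Iterable[Any], is_array: Callable[[Any], bool]) -> bool:
    """True if every edge joins an Array node to a non-Array node."""
    nodes = list(nodes)
    array_part, _ = partition(nodes, is_array)
    # Track membership by identity, so nodes need not be hashable.
    array_ids = {id(n) for n in array_part}
    for consumer in nodes:
        for producer in consumer.all_input_nodes:
            if (id(producer) in array_ids) == (id(consumer) in array_ids):
                return False  # both endpoints lie in the same partition
    return True
```

Membership is tracked with `id()` rather than by putting nodes in a set, so the sketch also works with lightweight stand-in objects that do not define `__hash__`.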