Allow caffe2-specific graph transformations for OperatorExportTypes.ONNX_ATEN_FALLBACK when BUILD_CAFFE2 is ON (pytorch#67460) (pytorch#68490)

Summary:
Pull Request resolved: pytorch#68490

The use of ATen as a fallback operator during ONNX conversion is important for increasing operator coverage, or even for providing implementations that are more efficient than some ONNX ops.

Currently this feature is available through `OperatorExportTypes.ONNX_ATEN_FALLBACK`,
but it also applies changes to the graph that are runnable only by Caffe2.

This PR restricts the Caffe2-specific graph transformations of the `ONNX_ATEN_FALLBACK`
operator export type to builds in which PyTorch has Caffe2 support (i.e. BUILD_CAFFE2=1 at build time).
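
For context, a minimal sketch of opting into the ATen fallback at export time (the model and output path are placeholders); with this change, the Caffe2-only graph rewrites behind this export type run only when PyTorch was built with BUILD_CAFFE2=1:

```python
import torch

model = torch.nn.Linear(4, 2).eval()
dummy_input = torch.randn(1, 4)

# Ops lacking a native ONNX mapping are emitted as ATen ops instead of failing;
# on builds without Caffe2 this no longer triggers the Caffe2-specific passes.
torch.onnx.export(
    model,
    dummy_input,
    "linear.onnx",
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    opset_version=11,
)
```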

The first version of this PR introduced a new operator export type, `ONNX_ATEN__STRICT_FALLBACK`,
which is essentially the same as `ONNX_ATEN_FALLBACK` but without the Caffe2 transformations.
It was preferred not to introduce a new operator export type, but to refine the existing ATen fallback one instead.

## BC-breaking note
### The global constant `torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE` is removed in favor of the less visible `torch.onnx._CAFFE2_ATEN_FALLBACK`.
`PYTORCH_ONNX_CAFFE2_BUNDLE` was effectively dead code: the flag was always set to False.
One alternative would have been to fix it, but pytorch#66658 disables the Caffe2 build by default.
Making a Caffe2 feature private seems to make more sense with future deprecation in mind.

### The method `torch.onnx.export` now defaults to ONNX when `operator_export_type` is not specified.
Previously, the `operator_export_type` of `torch.onnx.export` was intended to default to `ONNX_ATEN_FALLBACK` when `PYTORCH_ONNX_CAFFE2_BUNDLE` was set, but that never happened because `PYTORCH_ONNX_CAFFE2_BUNDLE` is always undefined.
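
As an illustrative sketch (the module, tensor, and file name are placeholders, and `_CAFFE2_ATEN_FALLBACK` is a private flag that may change), downstream code that previously keyed off `PYTORCH_ONNX_CAFFE2_BUNDLE` can pass the export type explicitly:

```python
import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2

# PYTORCH_ONNX_CAFFE2_BUNDLE no longer exists; the private flag below reflects
# whether this PyTorch build has Caffe2 support (BUILD_CAFFE2=1). Hypothetical usage.
if getattr(torch.onnx, "_CAFFE2_ATEN_FALLBACK", False):
    export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
    export_type = torch.onnx.OperatorExportTypes.ONNX  # matches the new default

torch.onnx.export(TinyModel(), torch.randn(2, 3), "tiny.onnx",
                  operator_export_type=export_type)
```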

 Co-authored-by: Nikita Shulga <nshulga@fb.com>

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D32483781

Pulled By: malfet

fbshipit-source-id: e9b447db9466b369e77d747188685495aec3f124
(cherry picked from commit 5fb1eb1)
BowenBao authored and pytorchmergebot committed Feb 10, 2022
1 parent 828e36c commit eb4238f
Showing 8 changed files with 35 additions and 26 deletions.
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
@@ -346,6 +346,7 @@ if(USE_NUMPY)
target_compile_definitions(torch_python PRIVATE USE_NUMPY)
endif()

list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS BUILD_CAFFE2)
if(HAVE_SOVERSION)
set_target_properties(torch_python PROPERTIES
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
2 changes: 1 addition & 1 deletion torch/_C/_onnx.pyi
@@ -2,7 +2,7 @@

from enum import Enum

PYTORCH_ONNX_CAFFE2_BUNDLE: bool
_CAFFE2_ATEN_FALLBACK: bool
PRODUCER_VERSION: str

class TensorProtoDataType(Enum):
5 changes: 5 additions & 0 deletions torch/csrc/jit/serialization/export.cpp
@@ -102,6 +102,10 @@ void validateBlock(
"\n\nDefined at:\n" + getNodeStackTraceString(node))
}
} else {
#ifdef BUILD_CAFFE2
// Assuming this is a Caffe2 change as it only modifies an aten op
// for operator_export_type == ONNX_ATEN_FALLBACK, which is a common
// pattern for Caffe2-specific scenarios.
if (node->kind() == aten::expand) {
if (operator_export_type ==
onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
@@ -117,6 +121,7 @@
new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
}
}
#endif
if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
if (operator_export_type !=
onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH) {
6 changes: 3 additions & 3 deletions torch/csrc/onnx/init.cpp
@@ -40,10 +40,10 @@ void initONNXBindings(PyObject* module) {

onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);

#ifdef PYTORCH_ONNX_CAFFE2_BUNDLE
onnx.attr("PYTORCH_ONNX_CAFFE2_BUNDLE") = true;
#ifdef BUILD_CAFFE2
onnx.attr("_CAFFE2_ATEN_FALLBACK") = true;
#else
onnx.attr("PYTORCH_ONNX_CAFFE2_BUNDLE") = false;
onnx.attr("_CAFFE2_ATEN_FALLBACK") = false;
#endif
}
} // namespace onnx
10 changes: 3 additions & 7 deletions torch/onnx/__init__.py
@@ -3,7 +3,7 @@
TensorProtoDataType = _C._onnx.TensorProtoDataType
OperatorExportTypes = _C._onnx.OperatorExportTypes
TrainingMode = _C._onnx.TrainingMode
PYTORCH_ONNX_CAFFE2_BUNDLE = _C._onnx.PYTORCH_ONNX_CAFFE2_BUNDLE
_CAFFE2_ATEN_FALLBACK = _C._onnx._CAFFE2_ATEN_FALLBACK

ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"

@@ -32,7 +32,7 @@ def _export(*args, **kwargs):


def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, operator_export_type=None,
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
opset_version=None, do_constant_folding=True, dynamic_axes=None,
keep_initializers_as_inputs=None, custom_opsets=None,
export_modules_as_functions=False):
@@ -119,11 +119,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM
input nodes of the graph, in order.
output_names (list of str, default empty list): names to assign to the
output nodes of the graph, in order.
operator_export_type (enum, default None):
None usually means ``OperatorExportTypes.ONNX``.
However if PyTorch was built with ``-DPYTORCH_ONNX_CAFFE2_BUNDLE``, None means
``OperatorExportTypes.ONNX_ATEN_FALLBACK``.
operator_export_type (enum, default OperatorExportTypes.ONNX):
* ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops
(in the default opset domain).
5 changes: 4 additions & 1 deletion torch/onnx/symbolic_opset11.py
@@ -550,7 +550,10 @@ def _get_arange_dtype(dtype):
def _dim_arange(g, like, dim):
like_shape = g.op("Shape", like)
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
# Caffe2-specific op
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback:
return g.op("_caffe2::Range", stop)
return arange(g, stop, 4, None, None, None)

7 changes: 5 additions & 2 deletions torch/onnx/symbolic_opset9.py
@@ -2405,7 +2405,10 @@ def symbolic(g, *args):
def _dim_arange(g, like, dim):
like_shape = g.op("Shape", like)
stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
# Caffe2-specific op
is_caffe2_aten_fallback = (sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback:
return g.op("_caffe2::Range", stop)
else:
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
@@ -2426,7 +2429,7 @@ def contiguous(g, input, memory_format):

@parse_args("v", "v", "i")
def _pack_padded_sequence(g, input, lengths, batch_first):
# There currently is no PackPadded operator in ONNX. We rely on an
# Currently there is no PackPadded operator in ONNX. We rely on an
# optimization pass to remove this later. It is an error if all
# PackPadded operators cannot be optimized out.
if batch_first:
25 changes: 13 additions & 12 deletions torch/onnx/utils.py
@@ -105,15 +105,10 @@ def exporter_context(model, mode):


def export(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=None,
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
opset_version=None, do_constant_folding=True, dynamic_axes=None,
keep_initializers_as_inputs=None, custom_opsets=None,
export_modules_as_functions=False):
if operator_export_type is None:
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
operator_export_type = OperatorExportTypes.ONNX

_export(model, args, f, export_params, verbose, training, input_names, output_names,
operator_export_type=operator_export_type, opset_version=opset_version,
@@ -205,7 +200,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
torch._C._jit_pass_onnx_remove_print(graph)
torch._C._jit_pass_onnx_preprocess_caffe2(graph)

if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
# Caffe2-specific optimization
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback:
torch.onnx.symbolic_helper._quantized_ops.clear()
# Unpack quantized weights for conv and linear ops and insert into graph.
torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict)
@@ -656,7 +654,7 @@ def _reset_trace_module_map():
torch.jit._trace._trace_module_map = None

def _export(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=None,
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
export_type=ExportTypes.PROTOBUF_FILE, opset_version=None,
do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None,
fixed_batch_size=False, custom_opsets=None, add_node_names=True,
@@ -685,7 +683,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None,
if opset_version is None:
opset_version = _default_onnx_opset_version
if not operator_export_type:
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
if torch.onnx._CAFFE2_ATEN_FALLBACK:
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
operator_export_type = OperatorExportTypes.ONNX
@@ -1021,8 +1019,10 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat

sym_registry.register_version("", opset_version)

# Quantized op symbolics are registered for opset 9 only.
if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and opset_version == 9:
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
is_caffe2_aten_fallback = (operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK and
torch.onnx._CAFFE2_ATEN_FALLBACK)
if is_caffe2_aten_fallback and opset_version == 9:
import torch.onnx.symbolic_caffe2
torch.onnx.symbolic_caffe2.register_quantized_ops("caffe2", opset_version)

@@ -1175,7 +1175,8 @@ def _run_symbolic_function(g, block, n, inputs, env, operator_export_type=Operat

elif ns == "quantized":
domain = ""
if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
# Caffe2-specific quantized op
if is_caffe2_aten_fallback:
domain = "caffe2"
symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
if symbolic_fn is None:
