Add aten_convolution_backward function #1707

Status: Open. Wants to merge 22 commits into main.
1 change: 1 addition & 0 deletions .gitignore
@@ -106,3 +106,4 @@ tests/mylib.onnxlib
**/serde_test_profiles/*
tools/ort_rewriter_profiling/.logs/*
tools/ort_rewriter_profiling/onnx_models/*
/dump_TestOperatorsOnnxrt
97 changes: 93 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -2093,6 +2093,7 @@
return result


@torch_op("aten::convolution_backward", trace_only=True)
def aten_convolution_backward(
grad_output: TensorType,
input: TensorType,
@@ -2108,7 +2109,87 @@
) -> tuple[TensorType, TensorType, TensorType]:
"""convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)"""

raise NotImplementedError()
# Compute weight.grad : dW_t = X_t * dZ_t
input_t = op.Transpose(input, perm=[1, 0, 2, 3])
dz_t = op.Transpose(grad_output, perm=[1, 0, 2, 3])
dw_t = op.Conv(input_t, dz_t)
dw = op.Transpose(dw_t, perm=[1, 0, 2, 3])
# Compute bias.grad: db = reduce grad_output over N, H, W
axes = op.Constant(value_ints=[0, 2, 3])
db = op.ReduceSum(grad_output, axes, keepdims=0)

# Compute input.grad: dx = Conv(dZ padded with zeros, rot180(W))
# Assume: grad_output=(20,13,48,38)
z_height = op.Shape(grad_output, start=2, end=3) # 48
z_width = op.Shape(grad_output, start=3, end=4) # 38

if stride[0] != 1 or stride[1] != 1:
raise NotImplementedError("stride != 1 is not supported yet")

# if stride[0] != 1: # dilation
# dz_height = z_height * stride[0] - stride[0] + 1
# dz_width = z_width * stride[1] - stride[1] + 1
# pos = _help(z_height, dz_width, stride)
# pos = []
# for j in range(z_height):
# for i in range(0, dz_width, stride[1]):
# pos.append(i + j * dz_width * stride[0])

# index_tensor = op.Constant(value_ints=pos)
# index_tensor = op.Reshape(index_tensor, z_shape)
# # this should not work because the kernel_shape is attribute
# dz = op.MaxUnpool(grad_output, index_tensor, kernel_shape=[dz_height - z_height + 1, dz_width - z_width + 1])

# # Computing padding size
# Assume: input=(20,16,50,40)
x_height = op.Shape(input, start=2, end=3) # 50
x_width = op.Shape(input, start=3, end=4) # 40
# Assume: weight=(13,16,3,3)
w_height = op.Shape(weight, start=2, end=3) # 3
w_width = op.Shape(weight, start=3, end=4) # 3
tmp_int = x_height - z_height + w_height - 1 # 50-48+3-1=4
tmp_float = op.Cast(tmp_int, to=FLOAT.dtype)
pad_height = op.Cast(
op.Div(tmp_float, op.Constant(value_floats=[2.0])), to=INT64.dtype
) # 4/2=2
tmp_int = x_width - z_width + w_width - 1 # 40-38+3-1=4
tmp_float = op.Cast(tmp_int, to=FLOAT.dtype)
pad_width = op.Cast(
op.Div(tmp_float, op.Constant(value_floats=[2.0])), to=INT64.dtype
) # 4/2=2
pads = op.Concat( # [0,0,2,2,0,0,2,2]
# begin of dim0, dim1, dim2, dim3
op.Constant(value_ints=[0]),
op.Constant(value_ints=[0]),
pad_height,
pad_width,
# end of dim0, dim1, dim2, dim3
op.Constant(value_ints=[0]),
op.Constant(value_ints=[0]),
pad_height,
pad_width,
axis=0,
)
dz_pad = op.Pad(grad_output, pads) # enlarge the grad_output to (20,13,52,42)

# Transpose from (13,16,3,3) to (16,13,3,3)
w_transpose = op.Transpose(weight, perm=[1, 0, 2, 3])
# Rotate the transposed weight (16,13,3,3) by 180 degrees in the spatial dims, like np.rot90(w, 2, axes=(2, 3)) -> (16,13,3,3)
w_shape_0 = op.Shape(w_transpose, start=0, end=1) # 16
w_shape_1 = op.Shape(w_transpose, start=1, end=2) # 13
w_shape_2 = op.Constant(value_ints=[1]) # 1
w_shape_3 = op.Constant(value_ints=[-1]) # -1
w_shape_new = op.Concat(w_shape_0, w_shape_1, w_shape_2, w_shape_3, axis=0) # (16,13,1,-1)
w_new = op.Reshape(w_transpose, w_shape_new) # flatten the spatial dims to (16,13,1,-1)
# reverse the values in the last dim (axes=3), e.g. [1,2,3,...,9] -> [9,...,3,2,1]
starts = op.Constant(value_ints=[-1])
ends = op.Constant(value_ints=[-1000])
axes = op.Constant(value_ints=[3])
steps = op.Constant(value_ints=[-1])
w_slice = op.Slice(w_new, starts, ends, axes, steps) # weight[:, :, :, -1:-1000:-1]
weight_rot180 = op.Reshape(w_slice, op.Shape(w_transpose)) # reshape back to (16,13,3,3)
# dx = dz(pad0) * w(rot180)
dx = op.Conv(dz_pad, weight_rot180)
# TODO: when dx is bigger than the input, e.g. 29x29 vs. 28x28, the last row and column of dx need to be removed
return dx, dw, db
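For reference (not part of the diff), a minimal PyTorch sketch that checks the gradient formulas the function above relies on, for the only case it currently supports: stride=1, padding=0, dilation=1, groups=1, non-transposed. The shapes follow the comments in the code.

```python
import torch
import torch.nn.functional as F

# Shapes follow the comments in the function above.
x = torch.randn(20, 16, 50, 40, dtype=torch.float64, requires_grad=True)  # input
w = torch.randn(13, 16, 3, 3, dtype=torch.float64, requires_grad=True)    # weight
b = torch.randn(13, dtype=torch.float64, requires_grad=True)              # bias

z = F.conv2d(x, w, b)      # (20, 13, 48, 38)
dz = torch.randn_like(z)   # plays the role of grad_output
z.backward(dz)

# weight.grad: convolve the channel-transposed input with the channel-transposed grad_output.
dw = F.conv2d(x.detach().transpose(0, 1), dz.transpose(0, 1)).transpose(0, 1)
# bias.grad: reduce grad_output over the batch and spatial dims.
db = dz.sum(dim=(0, 2, 3))
# input.grad: pad grad_output by (kernel_size - 1) on each spatial side and convolve
# with the weight rotated 180 degrees spatially, with in/out channels swapped.
w_rot180 = torch.flip(w.detach(), dims=(2, 3)).transpose(0, 1)
dx = F.conv2d(F.pad(dz, (2, 2, 2, 2)), w_rot180)

torch.testing.assert_close(dw, w.grad)
torch.testing.assert_close(db, b.grad)
torch.testing.assert_close(dx, x.grad)
```

For the stride != 1 branch that is still commented out above, one possible approach (an assumption, not what the PR implements) is to dilate grad_output by inserting stride - 1 zeros between its elements along each spatial dim and then reuse the stride-1 rule for the input gradient:

```python
import torch
import torch.nn.functional as F

s = 2  # forward stride
x = torch.randn(1, 1, 9, 9, dtype=torch.float64, requires_grad=True)
w = torch.randn(1, 1, 3, 3, dtype=torch.float64)
z = F.conv2d(x, w, stride=s)   # (1, 1, 4, 4)
dz = torch.randn_like(z)
z.backward(dz)

# Dilate grad_output: place its values on a stride-s grid of zeros.
dzd = torch.zeros(1, 1, (z.shape[2] - 1) * s + 1, (z.shape[3] - 1) * s + 1, dtype=torch.float64)
dzd[..., ::s, ::s] = dz
# Then apply the stride-1 rule: pad by (kernel_size - 1) and convolve with rot180(W).
w_rot180 = torch.flip(w, dims=(2, 3)).transpose(0, 1)
dx = F.conv2d(F.pad(dzd, (2, 2, 2, 2)), w_rot180)

torch.testing.assert_close(dx, x.grad)
```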


def aten_convolution_backward_overrideable(
@@ -2934,7 +3015,7 @@
indices_1d = op.Reshape(indices, neg_1)
# Get weight out according to indices_1d,
new_weight = op.Gather(weight, indices_1d)
# This happens after the first step of Gather, because Shape(indices) == Shape(per_sample_weights)
new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
indices_size = op.Shape(indices_1d)
@@ -3074,7 +3155,7 @@
# Get weight out according to indices,
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
indices_weight = op.Gather(weight, indices)
# This happens after the first step of Gather, because Shape(indices) == Shape(per_sample_weights)
indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))

# The element in sequence must be FLOAT32 dtype due to ORT bug
@@ -4659,7 +4740,7 @@
return op.LessOrEqual(self, other)


@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(("aten::le.Scalar", "aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5564,7 +5645,7 @@
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

# TODO(justinchuby): Handle cases where type reconciliation is not enough,
# since different ONNX operators are used based on different data types.

return op.And(self, other)
@@ -6583,10 +6664,18 @@
raise NotImplementedError()


def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@torch_op(("aten::prod"), trace_only=True)
def aten_prod(self: TReal, dtype: Optional[int] = None) -> TReal:
"""prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""

raise NotImplementedError()
return op.ReduceProd(self)



@torch_op("aten::prod.dim_int", trace_only=True)
def aten_prod_dim(self: TReal, dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> TReal:
"""prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""

return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
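For context, a tiny PyTorch sketch (not from the PR) of the behaviour these two ops are meant to mirror: ReduceProd over all axes for aten::prod, and over the given dim for aten::prod.dim_int.

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(torch.prod(t))                       # tensor(24.)      -> ReduceProd over all axes
print(torch.prod(t, dim=1))                # tensor([ 2., 12.]) -> ReduceProd over axis 1, keepdims=0
print(torch.prod(t, dim=1, keepdim=True))  # tensor([[ 2.], [12.]])
```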



def aten_promote_types(type1: int, type2: int) -> int:
67 changes: 63 additions & 4 deletions onnxscript/tools/training_helper.py
@@ -2,22 +2,59 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

import glob
import os
from typing import Any

import torch
from torch.onnx import ExportOptions
from torch.onnx import _OrtBackend as OrtBackend
from torch.onnx import _OrtBackendOptions as OrtBackendOptions


def make_aot_ort(dynamic: bool = False):
def make_aot_ort(dynamic: bool = False) -> Any:
"""Implements an autograd backend for torch.compile based on onnxrt backend."""
export_options = ExportOptions(dynamic_shapes=dynamic)
options = OrtBackendOptions(export_options=export_options)
ort_backend = OrtBackend(options=options)
return ort_backend


def train_loop(model, *args, loss_fn=None, optimizer=None):
"""Implements a training loop to be used in tests."""
def train_loop(
model: Any,
*args,
loss_fn: Any | None = None,
optimizer: Any | None = None,
dump_onnx_models: bool = False,
dump_prefix: str = "dump_train_loop",
dump_clean_first: bool = True,
) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]:
"""Implements a training loop to be used in tests.
The function returns the forward output and gradients in a tuple.

if dump_onnx_models is True, the function returns the forward output,
the gradients in a tuple and the generated onnx_files.
If there is no graph break, there should be
two graphs, one for forward, one for backward.

Args:
model: pytorch model
args: inputs
loss_fn: loss function, default is MSELoss
optimizer: optimizer, default is SGD
dump_onnx_models: dumps the model onnxrt backend is producing
dump_prefix: names will be `<dump_prefix>0.onnx`, `<dump_prefix>1.onnx`, ...
dump_clean_first: clean all files starting with the given prefix

Returns:
- the forward outputs
- the backwards gradients
- the dumped onnw models, 2 at least unless the forward, backward
were called before this function is executed or if the model
is not a compiled model
"""

if loss_fn is None:
loss_fn = torch.nn.MSELoss()
@@ -28,6 +65,16 @@
# Unnecessary in this situation but added for best practices
model.train()

if dump_onnx_models:
if dump_clean_first:
names = glob.glob(f"{dump_prefix}*")
for name in names:
os.remove(name)


old_value = os.environ.get("ONNXRT_DUMP_PATH", None)
os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_forward"
existing_files = glob.glob(f"{dump_prefix}*.onnx")

# Compute prediction and loss
pred = model(*args)
if isinstance(pred, tuple):
@@ -39,6 +86,8 @@
loss = loss_fn(v, torch.ones_like(v))

# Backpropagation
if dump_onnx_models:
os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_backward"
loss.backward()
optimizer.step()
# skip that part to retrieve the gradients
@@ -47,4 +96,14 @@
# returns the gradients
res = tuple(p.grad for p in model.parameters() if p.grad is not None)
assert len(res) > 0, f"No gradient, loss is {loss}"
return res

if dump_onnx_models:
if old_value is None:
del os.environ["ONNXRT_DUMP_PATH"]
else:
os.environ["ONNXRT_DUMP_PATH"] = old_value

new_files = glob.glob(f"{dump_prefix}*.onnx")
added_files = set(new_files) - set(existing_files)
return pred, res, [f for f in new_files if f in added_files]

return pred, res
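A hedged usage sketch of the extended train_loop API above; the model, input, and dump prefix are illustrative, and the test file below exercises the same path:

```python
import torch
import onnxscript.tools.training_helper as training_helper

model = torch.nn.Linear(14, 10)
compiled = torch.compile(
    model,
    backend=training_helper.make_aot_ort(dynamic=False),
    fullgraph=True,
)

pred, grads, onnx_files = training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
    compiled,
    torch.randn(2, 14),
    dump_onnx_models=True,
    dump_prefix="dump_example",
)
# With no graph break there should be exactly two dumped graphs:
# one for the forward pass and one for the backward pass.
assert len(onnx_files) == 2
```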
109 changes: 109 additions & 0 deletions tests/function_libs/torch_lib/backward_test.py
@@ -0,0 +1,109 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=not-callable

import copy
import sys
import unittest

import torch

import onnxscript.tools.training_helper
import onnxscript.tools.transformers_models
import onnxscript.tools.transformers_models.llama
Review comment (Collaborator) on lines +12 to +13. Suggested change: remove the two imports
import onnxscript.tools.transformers_models
import onnxscript.tools.transformers_models.llama
"I wonder why ruff doesn't warn the unused imports"

from onnxscript._internal.version_utils import has_transformers, torch_older_than


class TestBackward(unittest.TestCase):
@unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
@unittest.skipIf(not has_transformers(), reason="transformers is missing")
Review comment (Collaborator). Suggested change: remove the line
@unittest.skipIf(not has_transformers(), reason="transformers is missing")
@unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
def test_backward_working(self):
class SimpleCNNN(torch.nn.Module):
def __init__(self):
super().__init__()

self.fc1 = torch.nn.Linear(14, 10)

def forward(self, x):
return torch.nn.functional.relu(self.fc1(x))

input_tensors = (torch.randn(1, 1, 14, 14),)
model = SimpleCNNN()
local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

compiled_model = torch.compile(
copy.deepcopy(model),
backend=local_aot_ort,
dynamic=False,
fullgraph=True,
)

expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_testbw_working",
dump_clean_first=True,
)
torch.testing.assert_close(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

# Check that only two graphs were generated; otherwise there are graph breaks.
self.assertEqual(len(onnx_models), 2)
torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)

@unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
# @unittest.skipIf(not has_transformers(), reason="transformers is missing")
@unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
# @unittest.skipIf(True, reason="aten.conv_backward not implemented yet.")
def test_backward_conv(self):
class SimpleCNNN(torch.nn.Module):
def __init__(self):
super().__init__()

self.conv1 = torch.nn.Conv2d(
in_channels=1,
out_channels=2,
kernel_size=3,
padding=(0, 0), # padding=1 is not supported yet, will be added soon
)
self.fc1 = torch.nn.Linear(12, 10)

def forward(self, x):
y = torch.nn.functional.relu(self.conv1(x))
z = self.fc1(y)
return z

input_tensors = (torch.randn(1, 1, 14, 14),)
model = SimpleCNNN()
local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)

compiled_model = torch.compile(
copy.deepcopy(model),
backend=local_aot_ort,
dynamic=False,
fullgraph=True,
)

expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking
model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop( # pylint: disable=unbalanced-tuple-unpacking
compiled_model,
*input_tensors,
dump_onnx_models=True,
dump_prefix="_dump_testbw_conv",
dump_clean_first=True,
)
torch.testing.assert_close(expected_results[0], results[0], atol=1e-5, rtol=1e-5)

# Check that only two graphs were generated; otherwise there are graph breaks.
self.assertEqual(len(onnx_models), 2)
torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
unittest.main(verbosity=2)