Commit 37fe645
[Relax] Ingest Tensor.clamp from torch export (#17725)
Allow handling of torch.clamp when only min is passed, only max is passed, or when tensors are passed as the min/max arguments.
Parent: 36a92a4
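In plain PyTorch terms, the call patterns this commit makes ingestible look like the following (illustrative snippet, not part of the commit):

import torch

x = torch.randn(2, 3)
x.clamp(min=-1.0)                # only min given; max is left unbounded
x.clamp(max=1.0)                 # only max given; min is left unbounded
x.clamp(min=torch.tensor(-1.0),  # tensor bounds, broadcast against x
        max=torch.tensor(1.0))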

File tree: 6 files changed, +276 -33 lines


python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 86 additions & 10 deletions
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """Base class for PyTorch FX Graph importer."""
 import abc
+import math
 from typing import Callable, Dict, Optional, Tuple, Union
 
 from tvm import relax
@@ -141,19 +142,94 @@ def _celu(self, node: fx.Node) -> relax.Var:
 
     def _clamp(self, node: fx.Node) -> relax.Expr:
         args = self.retrieve_args(node)
-        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
-        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+        x = args[0]
+        a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+        a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)
+
+        a_min = -math.inf if a_min is None else a_min
+        a_max = math.inf if a_max is None else a_max
+
+        # Handle the case where a_min is a tensor
         if not isinstance(a_min, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant min value for torch.clamp/clip, "
-                f"but got {a_min} with type {type(a_min)}"
+            from torch import fx
+
+            if isinstance(a_min, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_min = self.env[a_min]
+            assert isinstance(a_min, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
             )
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.maximum(x, a_min))
+            a_min = -math.inf
+
+        # Handle the case where a_max is a tensor
         if not isinstance(a_max, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant max value for torch.clamp/clip, "
-                f"but got {a_max} with type {type(a_max)}"
+            from torch import fx
+
+            if isinstance(a_max, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_max = self.env[a_max]
+            assert isinstance(a_max, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+            )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.minimum(x, a_max))
+            a_max = math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+    def _clamp_min(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+        a_max = math.inf
+
+        a_min = -math.inf if a_min is None else a_min
+
+        # Handle the case where a_min is a tensor
+        if not isinstance(a_min, (int, float)):
+            from torch import fx
+
+            if isinstance(a_min, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_min = self.env[a_min]
+            assert isinstance(a_min, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
             )
-        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.maximum(x, a_min))
+            a_min = -math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+    def _clamp_max(self, node: fx.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        x = args[0]
+        a_min = -math.inf
+        a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)
+
+        a_max = math.inf if a_max is None else a_max
+
+        # Handle the case where a_max is a tensor
+        if not isinstance(a_max, (int, float)):
+            from torch import fx
+
+            if isinstance(a_max, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_max = self.env[a_max]
+            assert isinstance(a_max, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+            )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.minimum(x, a_max))
+            a_max = math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
 
     def _elu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -696,8 +772,8 @@ def _embedding_impl(
         return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
 
     def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
-        from torch.fx.immutable_collections import immutable_list
         import numpy as np  # type: ignore
+        from torch.fx.immutable_collections import immutable_list
 
         if isinstance(normalized_shape, (immutable_list, tuple)):
            normalized_shape = tuple(normalized_shape)
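The lowering pattern shared by all three handlers: relax.op.clip takes scalar (PrimValue) bounds, so a tensor bound is first folded into the data with an explicit broadcast_to plus maximum/minimum, and the corresponding scalar bound handed to clip is neutralized to +/-inf. A minimal NumPy sketch of the same idea (illustrative only, not TVM code):

import math
import numpy as np

def clamp_lowering(x, a_min=None, a_max=None):
    """Mirror of _clamp: fold tensor bounds into the data, then scalar-clip."""
    lo = -math.inf if a_min is None else a_min
    hi = math.inf if a_max is None else a_max
    if isinstance(lo, np.ndarray):              # tensor min -> explicit maximum
        x = np.maximum(x, np.broadcast_to(lo, x.shape))
        lo = -math.inf                          # clip becomes a no-op on this side
    if isinstance(hi, np.ndarray):              # tensor max -> explicit minimum
        x = np.minimum(x, np.broadcast_to(hi, x.shape))
        hi = math.inf
    return np.clip(x, lo, hi)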

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 3 additions & 0 deletions
@@ -193,6 +193,8 @@ def create_convert_map(
             "bitwise_not.default": self._unary_op(relax.op.bitwise_not),
             "ceil.default": self._unary_op(relax.op.ceil),
             "clamp.default": self._clamp,
+            "clamp_min.default": self._clamp_min,
+            "clamp_max.default": self._clamp_max,
             "cos.default": self._unary_op(relax.op.cos),
             "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
@@ -294,6 +296,7 @@ def create_convert_map(
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "cat.default": self._cat,
+            "clamp.Tensor": self._clamp,
             "concat.default": self._cat,
             "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 2 additions & 1 deletion
@@ -18,8 +18,8 @@
 # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Tuple, Union
 from functools import partial, reduce
+from typing import Callable, Dict, List, Tuple, Union
 
 import tvm
 from tvm import relax
@@ -598,6 +598,7 @@ def create_convert_map(
         self,
     ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
         import operator
+
         from torch import nn
 
         return {

tests/python/relax/test_from_exported_to_cuda.py

Lines changed: 81 additions & 0 deletions
@@ -56,6 +56,87 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor_clamp(target, dev):
+    class ClampBothTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("min_val", torch.tensor(-1.0))
+            self.register_buffer("max_val", torch.tensor(1.0))
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    class ClampBothInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = -1
+            self.max_val = 1
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    class ClampMinOnlyTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("min_val", torch.tensor(0.0))
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val)
+
+    class ClampMinOnlyInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = 0
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val)
+
+    class ClampMaxOnlyTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("max_val", torch.tensor(0.5))
+
+        def forward(self, x):
+            return x.clamp(max=self.max_val)
+
+    class ClampMaxOnlyInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.max_val = 0.5
+
+        def forward(self, x):
+            return x.clamp(max=self.max_val)
+
+    class ClampDifferentValues(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = -2
+            self.max_val = 2
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    # Create random data with values outside our clamp ranges
+    raw_data = np.random.uniform(-3.0, 3.0, (2, 3, 4, 5)).astype(np.float32)
+
+    torch_module0 = ClampBothTensor().eval()
+    torch_module1 = ClampBothInt().eval()
+    torch_module2 = ClampMinOnlyTensor().eval()
+    torch_module3 = ClampMinOnlyInt().eval()
+    torch_module4 = ClampMaxOnlyTensor().eval()
+    torch_module5 = ClampMaxOnlyInt().eval()
+    torch_module6 = ClampDifferentValues().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module4, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module5, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module6, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_expand_as(target, dev):
     class ExpandAs0(torch.nn.Module):
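To exercise just the new cases locally (assuming a CUDA-enabled TVM build and a visible GPU), something like `pytest tests/python/relax/test_from_exported_to_cuda.py -k test_tensor_clamp` should run only this test.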

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 59 additions & 3 deletions
@@ -135,18 +135,70 @@ def forward(self, input):
     class expected_clamp:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
-            # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5)
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    input,
+                    R.prim_value(T.float64(0.10000000000000001)),
+                    R.prim_value(T.float64(0.5)),
+                )
                 gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     verify_model(Clamp(), example_args, {}, expected_clamp)
 
+    class ClampMinOnly(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=0.5, max=None)
+
+    @tvm.script.ir_module
+    class expected_clamp_min_only:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    input, R.prim_value(T.float64(0.5)), R.prim_value(T.float64("inf"))
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)
+
+    class ClampTensors(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=input, max=input)
+
+    @tvm.script.ir_module
+    class expected_clamp_tensors:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    input, R.shape([1, 3, 10, 10])
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(input, lv)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    input, R.shape([1, 3, 10, 10])
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2)
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    lv3, R.prim_value(T.float64("-inf")), R.prim_value(T.float64("inf"))
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)
+
     # dropout
+
     class Dropout1(Module):
         def __init__(self):
             super().__init__()
@@ -3248,3 +3300,7 @@ def main(
     exported_program = export(Identity(), args=example_args)
     mod = from_exported_program(exported_program, no_bind_return_tuple=True)
     tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
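With the `tvm.testing.main()` hook added here, this test file should also be runnable directly, e.g. `python tests/python/relax/test_frontend_from_exported_program.py`.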
