Skip to content

Commit de50a92

Browse files
SigureMoCopilotgouzil
authored
[SOT][DynamicShape] Fix SymbolicVariable get_py_value reversed call and break graph on infermeta encount TypeError (#70009)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gouzil <66515297+gouzil@users.noreply.github.com>
1 parent b31ee6b commit de50a92

File tree

4 files changed

+65
-7
lines changed

4 files changed

+65
-7
lines changed

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
map_if,
5353
switch_symbol_registry,
5454
)
55+
from ...utils.exceptions import BreakGraphError
5556
from ..instruction_utils import get_instructions
5657
from .guard import Guard, StringifiedExpression, make_guard
5758
from .mutable_data import MutationDel, MutationNew, MutationSet
@@ -661,7 +662,9 @@ def try_infer_meta_fn(args, kwargs) -> Any:
661662
for arg in flatten_vars
662663
):
663664
# TODO(zrr1999): maybe we can continue to fallback to all args are constant.
664-
raise e
665+
raise BreakGraphError(
666+
f"InferMeta encount {type(e)}, but all args are not symbolic."
667+
)
665668

666669
args, kwargs = map_if(
667670
(args, kwargs),
@@ -686,7 +689,9 @@ def try_infer_meta_fn(args, kwargs) -> Any:
686689
isinstance(arg, SymbolicVariable)
687690
for arg in flatten_vars
688691
):
689-
raise e
692+
raise BreakGraphError(
693+
f"InferMeta encount {type(e)}, but all args are not symbolic."
694+
)
690695

691696
args, kwargs = map_structure(
692697
replace_symbolic_var_with_constant_var, (args, kwargs)

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@
106106
STATIC_DIM_FREQ_THRESHOLD = 5
107107

108108

109+
def method_to_reverse_method(method_name: str) -> str | None:
110+
if not method_name.startswith("__") or not method_name.endswith("__"):
111+
return None
112+
name = method_name[2:-2]
113+
return f"__r{name}__"
114+
115+
109116
class ConstantVariable(VariableBase):
110117
"""
111118
ConstantVariable is a subclass of VariableBase used to wrap a Variable of the const type.
@@ -802,11 +809,30 @@ def get_py_value(self, allow_tensor: bool = False) -> bool | int | float:
802809
), f"self.value is None, but tracker is not SymbolicOperationTracker. tracker: {self.tracker}"
803810
inputs = self.tracker.inputs
804811
assert len(inputs) >= 1
805-
other_inputs_value = [x.get_py_value() for x in inputs[1:]]
806-
self.value = getattr(
807-
inputs[0].get_py_value(), self.tracker.method_name
808-
)(*other_inputs_value)
809-
assert isinstance(self.value, (bool, int, float))
812+
input_values = [x.get_py_value() for x in inputs]
813+
value = getattr(input_values[0], self.tracker.method_name)(
814+
*input_values[1:]
815+
)
816+
# TODO(SigureMo): A Temporary solution for the case that the method is not implemented.
817+
# e.g. In user code, we have `1 * 0.1`, the lhs is a SymbolicVariable, and the rhs is a float.
818+
# We trace the method `__mul__` from the lhs, but actually, python use `float.__rmul__`,
819+
# `int.__mul__(float)` is not implemented. So we get NotImplemented here.
820+
# We need to find a better way to handle this case.
821+
if isinstance(value, type(NotImplemented)):
822+
reversed_method = method_to_reverse_method(
823+
self.tracker.method_name
824+
)
825+
if reversed_method is None:
826+
raise InnerError(
827+
f"Unsupported method {self.tracker.method_name} for SymbolicVariable"
828+
)
829+
value = getattr(input_values[1], reversed_method)(
830+
input_values[0], *input_values[2:]
831+
)
832+
self.value = value
833+
assert isinstance(
834+
self.value, (bool, int, float)
835+
), f"SymbolicVariable.get_py_value() should return bool, int or float, but got {type(self.value)}"
810836
return self.value
811837

812838
def get_py_type(self):

test/sot/test_16_paddle_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ def paddle_api_function_call_concat(
3838
return paddle.concat([x, y], axis=axis)
3939

4040

41+
def paddle_api_function_breakgraph_when_type_error(
42+
x: paddle.Tensor, axis: paddle.Tensor
43+
):
44+
return paddle.nn.functional.softmax(x, axis=axis)
45+
46+
4147
class TestPaddleApiCall(TestCaseBase):
4248
def test_paddle_api_method_call(self):
4349
self.assert_results(paddle_api_method_call, paddle.to_tensor(2.0))
@@ -55,6 +61,13 @@ def test_paddle_api_function_call_concat(self):
5561
self.assert_results(paddle_api_function_call_concat, a, b, 0)
5662
self.assert_results(paddle_api_function_call_concat, a, b, 1)
5763

64+
def test_paddle_api_function_breakgraph_when_type_error(self):
65+
x = paddle.to_tensor([[1, 2], [3, 4]], dtype=paddle.float32)
66+
axis = paddle.to_tensor(1)
67+
self.assert_results(
68+
paddle_api_function_breakgraph_when_type_error, x, axis
69+
)
70+
5871

5972
if __name__ == "__main__":
6073
unittest.main()

test/sot/test_sot_dynamic_shape.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import math
1718
import unittest
1819

1920
from test_case_base import (
@@ -58,6 +59,12 @@ def dynamic_shape_in_list(x, shape):
5859
return x.reshape(shape)
5960

6061

62+
def dynamic_shape_int_mul_float(x):
63+
y = x * 0.5
64+
z = math.sin(y) # Trigger get_py_value
65+
return z
66+
67+
6168
class CustomConv(paddle.nn.Conv2D):
6269
def __init__(self, *args, **kwargs):
6370
super().__init__(*args, **kwargs)
@@ -193,6 +200,13 @@ def test_pad_dynamic_shape_fallback(self):
193200
self.assert_results(pad_func, paddle.randn([1, 3, 224, 224]), i)
194201
self.assertEqual(ctx.translate_count, i)
195202

203+
def test_dynamic_shape_int_mul_float(self):
204+
with allow_dynamic_shape_guard(
205+
True
206+
), test_instruction_translator_cache_context() as ctx:
207+
for i in range(1, 6):
208+
self.assert_results(dynamic_shape_int_mul_float, i)
209+
196210

197211
if __name__ == '__main__':
198212
unittest.main()

0 commit comments

Comments
 (0)