Skip to content

[SOT] Add inline call codeobj to global guard #69803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def __init__(
self._fn_var = fn_variable
self.return_value: VariableBase | None = None
self._fn_value = fn_variable.value
super().__init__(fn_variable.get_code(), fn_variable.graph)
self._code_var = fn_variable.get_code()
super().__init__(self._code_var.value, fn_variable.graph)
self._name = "Inline"
self._prepare_locals(*args, **kwargs)
self._prepare_closure()
Expand Down Expand Up @@ -273,6 +274,7 @@ def inline_call(self) -> VariableBase:
"""
Execute the inline call of the function.
"""
self._graph.add_global_guarded_variable(self._code_var)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放在 run 前后差别不大

  • 放在前面,如果内部 breakgraph,global guard 会 rollback
  • 放在后面,成功后才会加,也是一样的

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前没有专门设计 CodeObjectVariable,因此使用的是基类的 guard,会触发 xxx == ___object_yyy== 会比较重,这里后续可优化为 ID MATCH

self.run()
assert self.return_value is not None
return self.return_value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,11 @@ def __init__(
def get_py_value(self, allow_tensor=False):
return self.value

def get_code(self) -> types.CodeType:
return self.value.__code__
def get_code(self) -> VariableBase:
code_obj_var = VariableFactory.from_value(
self.value.__code__, self.graph, GetAttrTracker(self, "__code__")
)
return code_obj_var

def bind(self, instance: VariableBase, name: str):
method_var = MethodVariable(
Expand Down
30 changes: 29 additions & 1 deletion test/sot/test_06_call_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import unittest

from test_case_base import TestCaseBase
from test_case_base import (
TestCaseBase,
test_instruction_translator_cache_context,
)

import paddle

Expand Down Expand Up @@ -150,5 +153,30 @@ def test_call8(self):
self.assert_results(foo_8, paddle.to_tensor(9))


def apply_fn(fn, x):
return fn(x)


def fn1(x):
return x + 1


def fn2(x):
return x - 1


class TestApplyDifferentFunctions(TestCaseBase):
def test_apply_fn(self):
x = 1
with test_instruction_translator_cache_context() as ctx:
self.assertEqual(ctx.translate_count, 0)
self.assert_results(apply_fn, fn1, x)
self.assertEqual(ctx.translate_count, 1)
self.assert_results(apply_fn, fn2, x)
self.assertEqual(ctx.translate_count, 2)
self.assert_results(apply_fn, fn1, x)
self.assertEqual(ctx.translate_count, 2)


if __name__ == "__main__":
unittest.main()