[SOT][PIR] support numpy and register_hook attr (PaddlePaddle#66008)
gouzil authored Jul 19, 2024
1 parent d5d1cb9 commit 4311592
Showing 7 changed files with 88 additions and 18 deletions.
11 changes: 6 additions & 5 deletions python/paddle/jit/sot/infer_meta.py
@@ -31,6 +31,7 @@
from .utils import Cache, Singleton, map_if_extend, meta_str

DynamicSymbolT = TypeVar("DynamicSymbolT")
+SOT_INFER_META_INNER_VAR = "___SOT_INFER_META_INNER_VAR"


class SymbolicInt(metaclass=Singleton):
@@ -144,10 +145,7 @@ def from_tensor(

@staticmethod
def from_value(value) -> MetaInfo:
-        if isinstance(value, paddle.pir.Value):
-            name = "Value@NoName"
-        else:
-            name = value.name
+        name = SOT_INFER_META_INNER_VAR
dtype = MetaInfo._handle_legacy_ir_amp_dtype(value.dtype)
shape = [SymbolicInt() if dim == -1 else dim for dim in value.shape]
return MetaInfo(
@@ -160,6 +158,9 @@ def from_value(value) -> MetaInfo:
value.place,
)

+    def is_inner_var(self):
+        return self.name == SOT_INFER_META_INNER_VAR

def is_dynamic_shape(self):
"""
if SymbolicInt in shape, return True
@@ -202,7 +203,7 @@ def __init__(self):
# self.var_cache = {}
# self.main_program = paddle.static.Program()
# self.startup_program = paddle.static.Program()
-        self.var_name_generator = UniqueNameGenerator("infer_meta_variable_")
+        self.var_name_generator = UniqueNameGenerator(SOT_INFER_META_INNER_VAR)

def gen_name(self, meta):
name = f"{meta.dtype}_{meta.stop_gradient}"
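The infer_meta.py change above replaces per-IR name handling with a single sentinel: every MetaInfo produced by from_value is named SOT_INFER_META_INNER_VAR, and is_inner_var() tests for that marker. A minimal standalone sketch of the pattern, where MetaInfo and FakeValue are simplified stand-ins rather than Paddle's real classes:

# Sketch of the sentinel-name pattern; `MetaInfo` is a simplified
# stand-in for paddle.jit.sot.infer_meta.MetaInfo, not the real class.
from dataclasses import dataclass

SOT_INFER_META_INNER_VAR = "___SOT_INFER_META_INNER_VAR"

@dataclass
class MetaInfo:
    name: str
    shape: tuple

    @staticmethod
    def from_value(value):
        # Inner (inferred-meta) variables all share the sentinel name,
        # so no real tensor name leaks out of the infer-meta pass.
        return MetaInfo(name=SOT_INFER_META_INNER_VAR, shape=tuple(value.shape))

    def is_inner_var(self) -> bool:
        return self.name == SOT_INFER_META_INNER_VAR

class FakeValue:
    shape = (2, 3)

meta = MetaInfo.from_value(FakeValue())
assert meta.is_inner_var()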
@@ -32,8 +32,8 @@
ConstTypes,
FallbackError,
NameGenerator,
+    get_tensor_methods,
log,
-    paddle_tensor_methods,
printable,
)
from ....utils.exceptions import HasNoAttributeError, InnerError
@@ -564,11 +564,18 @@ def getattr(self, name: str, default=None):
"is_integer": paddle.is_integer,
"is_floating_point": paddle.is_floating_point,
}
if name in ["dtype", "type", "name", "persistable", "stop_gradient"]:
if name == "name" and self.meta.name.startswith(
"infer_meta_variable_tmp"
):
raise BreakGraphError(f"{self.meta.name} is a middle tensor.")
if name in ["name", "place", "type"] and self.meta.is_inner_var():
raise BreakGraphError(
f"{self.meta.name} is a middle tensor. get {name} property."
)
if name in [
"dtype",
"type",
"name",
"persistable",
"stop_gradient",
"place",
]:
return VariableFactory.from_value(
getattr(self.meta, name),
self.graph,
@@ -585,7 +592,7 @@ def getattr(self, name: str, default=None):
return BuiltinVariable(
builtin_fn, self.graph, DanglingTracker()
).bind(self, name)
-        elif name in paddle_tensor_methods:
+        elif name in get_tensor_methods():
from .callable import TensorFunctionVariable

fn_var = TensorFunctionVariable(
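With the sentinel in place, getattr can refuse to surface attributes that are only meaningful for real (named) tensors. A toy sketch of that guard, where BreakGraphError and the meta object are simplified stand-ins rather than the actual SOT classes:

# Toy sketch of the inner-var attribute guard; names are stand-ins.
from dataclasses import dataclass

class BreakGraphError(Exception):
    pass

PROXYABLE_ATTRS = ["dtype", "type", "name", "persistable", "stop_gradient", "place"]

def get_tensor_attr(meta, name):
    # Inner variables carry a synthetic name, so name/place/type would be
    # meaningless; break the graph instead of returning bogus values.
    if name in ["name", "place", "type"] and meta.is_inner_var():
        raise BreakGraphError(f"{meta.name} is a middle tensor. get {name} property.")
    if name in PROXYABLE_ATTRS:
        return getattr(meta, name)
    raise AttributeError(name)

@dataclass
class Meta:
    name: str
    dtype: str = "float32"

    def is_inner_var(self):
        return self.name.startswith("___SOT")

print(get_tensor_attr(Meta(name="x"), "dtype"))     # ok: float32
# get_tensor_attr(Meta(name="___SOT_tmp"), "name")  # would raise BreakGraphError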
@@ -270,7 +270,14 @@ class TensorFunctionVariable(FunctionVariable):
def __init__(
self, method_name: str, graph: FunctionGraph, tracker: Tracker
):
-        fn = getattr(paddle.static.Variable, method_name)
+        fn = getattr(
+            (
+                paddle.pir.Value
+                if paddle.framework.use_pir_api()
+                else paddle.static.Variable
+            ),
+            method_name,
+        )
super().__init__(fn, graph, tracker)
self.method_name = method_name

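TensorFunctionVariable now resolves the unbound method from paddle.pir.Value when the PIR API is enabled and from paddle.static.Variable otherwise. The same feature-flag dispatch in isolation, with illustrative stand-in classes and flag in place of Paddle's:

# Generic sketch of flag-based method resolution; the classes and the
# use_pir_api() flag are illustrative stand-ins, not Paddle's.
USE_PIR_API = True

class LegacyVariable:
    def astype(self, dtype):
        return f"legacy astype -> {dtype}"

class PirValue:
    def astype(self, dtype):
        return f"pir astype -> {dtype}"

def use_pir_api() -> bool:
    return USE_PIR_API

def resolve_tensor_method(method_name: str):
    # Pick the host class at lookup time so the same SOT code path
    # serves both IRs.
    host = PirValue if use_pir_api() else LegacyVariable
    return getattr(host, method_name)

fn = resolve_tensor_method("astype")
print(fn(PirValue(), "float16"))  # pir astype -> float16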
2 changes: 1 addition & 1 deletion python/paddle/jit/sot/utils/__init__.py
@@ -38,10 +38,10 @@
)
from .magic_methods import magic_method_builtin_dispatch # noqa: F401
from .paddle_api_config import ( # noqa: F401
+    get_tensor_methods,
is_break_graph_tensor_methods,
is_inplace_api,
is_not_supported_paddle_layer,
-    paddle_tensor_methods,
)
from .utils import ( # noqa: F401
Cache,
25 changes: 23 additions & 2 deletions python/paddle/jit/sot/utils/paddle_api_config.py
@@ -22,14 +22,22 @@ def is_inplace_api(func):
return func in inplace_apis


-def get_tensor_methods():
+def get_variable_methods():
return [
member_name
for member_name, member in inspect.getmembers(paddle.static.Variable)
if inspect.isfunction(member)
]


+def get_value_methods():
+    return [
+        member_name
+        for member_name, member in inspect.getmembers(paddle.pir.Value)
+        if inspect.isfunction(member) or inspect.ismethoddescriptor(member)
+    ]


def get_paddle_api():
modules = [
paddle,
@@ -74,7 +82,20 @@ def get_paddle_api():
)


-paddle_tensor_methods = get_tensor_methods()
+def create_tensor_methods_getter():
+    value_methods = get_value_methods()
+    variable_methods = get_variable_methods()
+
+    def _get_tensor_methods():
+        if paddle.framework.use_pir_api():
+            return value_methods
+        else:
+            return variable_methods
+
+    return _get_tensor_methods
+
+
+get_tensor_methods = create_tensor_methods_getter()
paddle_api_list = get_paddle_api()

# TODO(Aurelius84): It seems that we use it to judge 'in_paddle_module()'.
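create_tensor_methods_getter computes both method lists once at import time and closes over them, so each later get_tensor_methods() call is just a flag check instead of a fresh inspect.getmembers walk; the extra ismethoddescriptor test is presumably needed because natively bound methods (such as those on paddle.pir.Value) show up as method descriptors rather than plain Python functions. The same caching pattern in isolation, with stand-in classes and flag:

# The cache-in-a-closure pattern from above, in isolation; the flag and
# example classes are stand-ins for paddle.framework.use_pir_api() etc.
import inspect

class OldAPI:
    def a(self): ...

class NewAPI:
    def b(self): ...

_use_new_api = False

def use_new_api() -> bool:
    return _use_new_api

def create_methods_getter():
    # inspect.getmembers is relatively expensive, so walk each class once
    # and capture the results in the closure.
    old_methods = [n for n, m in inspect.getmembers(OldAPI) if inspect.isfunction(m)]
    new_methods = [
        n for n, m in inspect.getmembers(NewAPI)
        if inspect.isfunction(m) or inspect.ismethoddescriptor(m)
    ]

    def _get_methods():
        return new_methods if use_new_api() else old_methods

    return _get_methods

get_methods = create_methods_getter()
print(get_methods())  # ['a'] while _use_new_api is False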
36 changes: 36 additions & 0 deletions python/paddle/pir/math_op_patch.py
@@ -602,6 +602,40 @@ def set_shape(self, shape):
def value_hash(self):
return hash(id(self))

+@fake_interface_only
+def numpy(self):
+    """
+    **Notes**:
+        **This API is ONLY available in Dygraph mode**
+
+    Returns a numpy array that shows the value of the current :ref:`api_guide_Variable_en`.
+
+    Returns:
+        ndarray: The numpy value of the current Variable.
+
+    Return type:
+        ndarray: dtype is the same as the current Variable.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.base as base
+            >>> from paddle.nn import Linear
+            >>> import numpy as np
+
+            >>> data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
+            >>> with base.dygraph.guard():
+            ...     linear = Linear(32, 64)
+            ...     data_tensor = paddle.to_tensor(data)
+            ...     x = linear(data_tensor)
+            ...     print(x.numpy())
+    """
+    pass
+
+@fake_interface_only
+def register_hook(self):
+    """
+    Value doesn't have a 'register_hook' interface in static graph mode,
+    but this interface can greatly facilitate dy2static, so we raise an
+    error here.
+    """
+    pass

import paddle

value_methods = [
Expand All @@ -626,6 +660,8 @@ def value_hash(self):
('to_dense', to_dense),
('indices', indices),
('values', values),
("numpy", numpy),
("register_hook", register_hook),
# For basic operators
(
'__add__',
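Both new Value methods are stubs whose bodies never run; the fake_interface_only decorator makes static-graph code fail with a pointed message instead of an AttributeError. A minimal sketch of how such a decorator can behave (an assumption for illustration, not Paddle's actual implementation):

# Minimal sketch of a fake_interface_only-style decorator; this is an
# assumed behavior, not Paddle's actual implementation.
import functools

def fake_interface_only(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Raise eagerly: the decorated body is a placeholder and must
        # never execute in static graph mode.
        raise TypeError(
            f"'{func.__name__}' is only available in dygraph mode; "
            "it is patched onto Value so dy2static can report a clear error."
        )
    return wrapper

class Value:  # stand-in for paddle.pir.Value
    @fake_interface_only
    def numpy(self):
        pass

try:
    Value().numpy()
except TypeError as e:
    print(e)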
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -61,13 +61,11 @@
'layout',
'nnz',
'num_shard',
-    'numpy',
'offset',
'pin_memory',
'placements',
'process_mesh',
'reconstruct_from_',
-    'register_hook',
'retain_grads',
'rows',
'set_string_list',
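Since numpy and register_hook are now patched onto Value, the test no longer needs to exclude them. Conceptually the test compares the public attribute surface of the dygraph Tensor with that of Value, minus an allowlist; a simplified sketch with stand-in classes and hypothetical allowlist entries:

# Simplified sketch of an attribute-consistency check; the classes and
# allowlist entries are stand-ins, not the real test's Tensor/Value data.
DYGRAPH_ONLY_ALLOWLIST = {"offset", "pin_memory"}  # hypothetical entries

class Tensor:        # stand-in for the dygraph Tensor surface
    def numpy(self): ...
    def register_hook(self, hook): ...
    def offset(self): ...

class Value:         # stand-in for the static-graph Value surface
    def numpy(self): ...
    def register_hook(self, hook): ...

tensor_attrs = {n for n in dir(Tensor) if not n.startswith("_")}
value_attrs = {n for n in dir(Value) if not n.startswith("_")}

# Every public Tensor attribute must exist on Value unless allowlisted.
missing = tensor_attrs - value_attrs - DYGRAPH_ONLY_ALLOWLIST
assert not missing, f"Value is missing: {sorted(missing)}"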
