Skip to content

Commit 68198a6

Browse files
committed
binding tensor method
1 parent 23d3aa1 commit 68198a6

File tree

4 files changed

+35
-62
lines changed

4 files changed

+35
-62
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
import paddle.batch
3333
batch = batch.batch
3434
import paddle.tensor
35-
from .fluid import monkey_patch_variable, monkey_patch_math_varbase
35+
from .fluid import monkey_patch_variable
36+
from .fluid.dygraph import monkey_patch_math_varbase
3637
monkey_patch_variable()
3738
monkey_patch_math_varbase()
3839
import paddle.framework

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
from .dygraph.base import enable_dygraph, disable_dygraph
8989
from .io import save, load, load_program_state, set_program_state
9090
from .dygraph.checkpoint import save_dygraph, load_dygraph
91-
from .dygraph.varbase_patch_methods import monkey_patch_varbase, monkey_patch_math_varbase
91+
from .dygraph.varbase_patch_methods import monkey_patch_varbase
9292
from . import generator
9393
Tensor = LoDTensor
9494
enable_imperative = enable_dygraph

python/paddle/fluid/dygraph/math_op_patch.py

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def __impl__(self, other_var):
218218
return __impl__
219219

220220
# Todo(zhouwei): implement dygraph template to adapt to any function, receive('op_type', 'arg_template')
221-
# Such as _method_creator_('addmm', 'x, y, alpha=1.0, beta=1.0, name=None')
221+
# Such as _method_creator_('addmm', 'x, y, alpha=1.0, beta=1.0, name=None'). It can reduce call time.
222222
def _method_creator_(op_type, arg_template=None):
223223
def __impl__(self):
224224
op = getattr(core.ops, op_type)
@@ -246,36 +246,22 @@ def __impl__(self):
246246
('ndim', _ndim_),
247247
('size', lambda x: x.shape),
248248
# Type2: From Template that create core.ops automatically. It's recommended.
249-
('__add__',
250-
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
249+
('__add__', _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
251250
## a+b == b+a. Do not need to reverse explicitly
252-
('__radd__',
253-
_binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
254-
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
255-
_scalar_sub_)),
256-
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
257-
_scalar_rsub_)),
258-
('__mul__',
259-
_binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_)),
251+
('__radd__', _binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
252+
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_)),
253+
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True, _scalar_rsub_)),
254+
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_)),
260255
## a*b == b*a. Do not need to reverse explicitly
261-
('__rmul__',
262-
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
263-
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
264-
_scalar_div_)),
265-
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
266-
False, _scalar_div_)),
267-
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
268-
None)),
269-
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True,
270-
None)),
271-
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
272-
None)),
273-
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
274-
None)),
275-
('__floordiv__', _binary_creator_('__floordiv__',
276-
'elementwise_floordiv', False, None)),
277-
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
278-
None)),
256+
('__rmul__', _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
257+
('__div__', _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_)),
258+
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', False, _scalar_div_)),
259+
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, None)),
260+
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, None)),
261+
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, None)),
262+
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, None)),
263+
('__floordiv__', _binary_creator_('__floordiv__', 'elementwise_floordiv', False, None)),
264+
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, None)),
279265
## for logical compare
280266
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
281267
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),

python/paddle/fluid/layers/math_op_patch.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -417,37 +417,23 @@ def __impl__(self, other_var):
417417
# b=-a
418418
('__neg__', _neg_),
419419
('astype', astype),
420-
('__add__', _binary_creator_('__add__', 'elementwise_add', False,
421-
_scalar_add_)),
422-
# a+b == b+a. Do not need to reverse explicitly
423-
('__radd__', _binary_creator_('__radd__', 'elementwise_add', False,
424-
_scalar_add_)),
425-
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
426-
_scalar_sub_)),
427-
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
428-
_scalar_rsub_)),
429-
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False,
430-
_scalar_mul_)),
431-
# a*b == b*a. Do not need to reverse explicitly
432-
('__rmul__', _binary_creator_('__rmul__', 'elementwise_mul', False,
433-
_scalar_mul_)),
434-
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
435-
_scalar_div_)),
436-
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
437-
False, _scalar_div_)),
438-
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
439-
None)),
440-
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True,
441-
None)),
442-
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
443-
None)),
444-
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
445-
None)),
446-
('__floordiv__', _binary_creator_('__floordiv__',
447-
'elementwise_floordiv', False, None)),
448-
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
449-
None)),
450-
# for logical compare
420+
('__add__', _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
421+
# a+b == b+a. Do not need to reverse explicitly
422+
('__radd__', _binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
423+
('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_)),
424+
('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True, _scalar_rsub_)),
425+
('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_)),
426+
# a*b == b*a. Do not need to reverse explicitly
427+
('__rmul__', _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
428+
('__div__', _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_)),
429+
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', False, _scalar_div_)),
430+
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True, None)),
431+
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, None)),
432+
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, None)),
433+
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, None)),
434+
('__floordiv__', _binary_creator_('__floordiv__', 'elementwise_floordiv', False, None)),
435+
('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, None)),
436+
# for logical compare
451437
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
452438
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
453439
('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),

0 commit comments

Comments
 (0)