[PIR] support Value.to (PaddlePaddle#66003)
gouzil authored Jul 20, 2024
1 parent 69378d8 commit 95293c0
Showing 4 changed files with 522 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python/paddle/base/dygraph/tensor_patch_methods.py
@@ -660,7 +660,10 @@ def transform(t, device, dtype, blocking):

     @overload
     def to(
-        self: Tensor, device: PlaceLike, blocking: bool | None = ...
+        self: Tensor,
+        device: PlaceLike,
+        dtype: DTypeLike | None = ...,
+        blocking: bool | None = ...,
     ) -> Tensor:
         ...

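The added `dtype` parameter makes the device-first overload match calls that were already valid at runtime. A minimal sketch of such a call (assumes a CUDA build of Paddle, since it moves the tensor to GPU):

import paddle

x = paddle.to_tensor([1, 2, 3])
y = x.to("gpu", "float16", blocking=False)  # device, dtype, blocking
assert y.dtype == paddle.float16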
248 changes: 248 additions & 0 deletions python/paddle/pir/math_op_patch.py
@@ -13,8 +13,11 @@
# limitations under the License.


import inspect
import warnings

import numpy as np

from paddle import _C_ops
from paddle.base.libpaddle import DataType
from paddle.base.wrapped_decorator import wrap_decorator
@@ -602,6 +605,249 @@ def set_shape(self, shape):
    def value_hash(self):
        return hash(id(self))

    def _to(
        self,
        device=None,
        dtype=None,
        blocking=None,
    ):
        if device is None and dtype is None and blocking is None:
            return self

        if device is not None:
            if isinstance(device, str):
                device = paddle.device._convert_to_place(device)
            elif isinstance(
                device,
                (
                    paddle.core.Place,
                    paddle.CPUPlace,
                    paddle.CUDAPlace,
                    paddle.CUDAPinnedPlace,
                    # paddle.XPUPlace,  # not supported
                    # paddle.CustomPlace,  # not supported
                ),
            ):
                pass
            else:
                raise ValueError(
                    "device value error: must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace(), paddle.XPUPlace() or paddle.CustomPlace(), but got "
                    + type(device).__name__
                )

        if blocking is None:
            blocking = True
        else:
            assert isinstance(
                blocking, bool
            ), "blocking value error, must be True, False or None"

        def transform(t, device, dtype, blocking):
            if dtype is None:
                dtype = t.dtype
            t_used = t

            # 1. cast Tensor to dtype
            if dtype != t_used.dtype:
                with paddle.base.framework._dygraph_place_guard(
                    place=t_used.place
                ):
                    t_casted = t_used.cast(dtype=dtype)
            else:
                t_casted = t_used

            # 2. Copy casted Tensor (on CPU or GPU) to device
            if isinstance(device, paddle.CUDAPlace):
                new_t = t_casted.cuda(blocking=blocking)
            elif isinstance(device, paddle.CUDAPinnedPlace):
                if blocking is not True:
                    warnings.warn(
                        "blocking is not supported, and it will be ignored."
                    )
                # copy the casted value; dst_place_type 2 selects CUDAPinnedPlace
                new_t = _C_ops.memcpy(t_casted, 2)
            elif isinstance(device, paddle.CPUPlace):
                new_t = t_casted.cpu()
            else:
                new_t = t_casted

            return new_t

        return transform(self, device, dtype, blocking)
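    # In effect, _to casts first (when dtype differs) and then inserts the
    # device copy (cuda()/cpu()/memcpy), so a single call can change both the
    # dtype and the place of a Value.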

    def to(self, *args, **kwargs):
        """
        Performs Tensor dtype and/or device conversion. A paddle.dtype and place
        are inferred from the arguments of ``self.to(*args, **kwargs)``. There are
        three ways to call `to`:

            1. to(dtype, blocking=True)
            2. to(device, dtype=None, blocking=True)
            3. to(other, blocking=True)

        Returns:
            Tensor: self

        Examples:
            .. code-block:: python

                >>> import paddle
                >>> tensorx = paddle.to_tensor([1, 2, 3])
                >>> print(tensorx)
                Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                [1, 2, 3])
                >>> tensorx = tensorx.to("cpu")
                >>> print(tensorx.place)
                Place(cpu)
                >>> tensorx = tensorx.to("float32")
                >>> print(tensorx.dtype)
                paddle.float32
                >>> tensorx = tensorx.to("gpu", "int16")
                >>> print(tensorx)
                Tensor(shape=[3], dtype=int16, place=Place(gpu:0), stop_gradient=True,
                [1, 2, 3])
                >>> tensor2 = paddle.to_tensor([4, 5, 6])
                >>> tensor2
                Tensor(shape=[3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                [4, 5, 6])
                >>> tensor2 = tensor2.to(tensorx)
                >>> print(tensor2)
                Tensor(shape=[3], dtype=int16, place=Place(gpu:0), stop_gradient=True,
                [4, 5, 6])
        """

        size_args = len(args)
        size_kwargs = len(kwargs)

        if size_args + size_kwargs > 3 or size_args + size_kwargs == 0:
            raise TypeError(
                "to() received an invalid combination of arguments - expected one of:\n"
                "  * (Union[str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace(), paddle.XPUPlace(), paddle.CustomPlace()] device, Union[str, paddle.dtype, numpy.dtype] dtype, bool blocking)\n"
                "  * (Union[str, paddle.dtype, numpy.dtype] dtype, bool blocking)\n"
                "  * (paddle.Tensor other, bool blocking)"
            )
        valid_keys = {"device", "dtype", "blocking", "other"}
        invalid_keys = set(kwargs.keys()) - valid_keys
        if len(invalid_keys) != 0:
            raise TypeError(
                "to() got an unexpected keyword argument "
                + list(invalid_keys)[0]
            )
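        # For example (hypothetical calls on a pir Value `x`):
        #   x.to("float32", "cpu", True, False)  -> TypeError (wrong arity)
        #   x.to(foo="float32")                  -> TypeError (unexpected keyword)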

        def dtype_first_sig(dtype, blocking=None):
            ...

        def device_first_sig(device, dtype=None, blocking=None):
            ...

        def tensor_like_first_sig(other, blocking=None):
            ...

        class _NoArg:
            ...

        def is_dtype(arg):
            valid_dtypes = [
                "bfloat16",
                "float16",
                "float32",
                "float64",
                "int8",
                "int16",
                "int32",
                "int64",
                "uint8",
                "complex64",
                "complex128",
                "bool",
            ]
            return isinstance(arg, (paddle.dtype, np.dtype)) or (
                isinstance(arg, str) and arg.lower() in valid_dtypes
            )
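        # e.g. is_dtype(paddle.float32) -> True, is_dtype("FLOAT32") -> True,
        #      but is_dtype("fp32") -> False ("fp32" is not in valid_dtypes)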

        def is_device(arg):
            return isinstance(arg, (paddle.core.Place, str))

        def is_tensor(arg):
            return isinstance(arg, paddle.pir.Value)

        def create_positional_arg_extractor(position: int):
            def extract_positional_arg(args, kwargs):
                if len(args) > position:
                    return args[position]
                return _NoArg()

            return extract_positional_arg

        def create_keyword_arg_extractor(key: str, position: int):
            def extract_keyword_arg(args, kwargs):
                if (
                    key in kwargs
                    and len(kwargs) > position
                    and list(kwargs.keys())[position] == key
                ):
                    return kwargs[key]
                return _NoArg()

            return extract_keyword_arg
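        # e.g. extract = create_keyword_arg_extractor(key="dtype", position=0)
        #      extract((), {"dtype": "float32"}) -> "float32"
        #      extract((), {"device": "cpu", "dtype": "float32"}) -> _NoArg(),
        #      because the keyword must also sit at the expected position.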

        def chain_extractors(*extractors):
            def chain(args, kwargs):
                for extractor in extractors:
                    if not isinstance(arg := extractor(args, kwargs), _NoArg):
                        return arg
                return _NoArg()

            return chain

        def dispatch_to_signature(*args, **kwargs):
            # dict[signature, (extractor, condition)]
            signature_map = {
                dtype_first_sig: (
                    chain_extractors(
                        create_positional_arg_extractor(position=0),
                        create_keyword_arg_extractor(key="dtype", position=0),
                    ),
                    is_dtype,
                ),
                device_first_sig: (
                    chain_extractors(
                        create_positional_arg_extractor(position=0),
                        create_keyword_arg_extractor(key="device", position=0),
                    ),
                    is_device,
                ),
                tensor_like_first_sig: (
                    chain_extractors(
                        create_positional_arg_extractor(position=0),
                        create_keyword_arg_extractor(key="other", position=0),
                    ),
                    is_tensor,
                ),
            }

            for sig, (extractor, condition) in signature_map.items():
                if not isinstance(
                    arg := extractor(args, kwargs), _NoArg
                ) and condition(arg):
                    bound_args = inspect.signature(sig).bind(*args, **kwargs)
                    bound_args.apply_defaults()
                    return bound_args.arguments
            raise ValueError("No matching signature found.")
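        # e.g. dispatch_to_signature("float32", True)
        #      -> {"dtype": "float32", "blocking": True}
        #      dispatch_to_signature("cpu")
        #      -> {"device": "cpu", "dtype": None, "blocking": None}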

        args = dispatch_to_signature(*args, **kwargs)
        other = args.get("other", None)
        if other is not None:
            args.pop("other")
            args["dtype"] = other.dtype
            # in dy2static, we need to show a warning for this case
            other.place  # noqa: B018

        return self._to(**args)
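        # Note: to(other) adopts only other's dtype; no device is forwarded,
        # and other.place is touched above solely to surface the dy2static
        # warning about reading a Value's place at graph-build time.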

    @fake_interface_only
    def numpy(self):
        """
@@ -660,6 +906,8 @@ def register_hook(self):
        ('to_dense', to_dense),
        ('indices', indices),
        ('values', values),
        ("_to", _to),
        ("to", to),
        ("numpy", numpy),
        ("register_hook", register_hook),
        # For basic operators
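Taken together, the patch lets `Value.to` be exercised while building a PIR program. A minimal sketch follows; enabling PIR via `paddle.pir_utils.IrGuard` is an assumption about the setup, not something this diff shows:

import paddle
from paddle.pir_utils import IrGuard

with IrGuard():
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name="x", shape=[3], dtype="float32")
        y = x.to("float16")         # dtype-first signature
        z = x.to("cpu", "float16")  # device-first signature
        w = x.to(y)                 # tensor-like signature: adopts y.dtype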
1 change: 0 additions & 1 deletion test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -72,7 +72,6 @@
     'set_value',
     'set_vocab',
     'strides',
-    'to',
     'to_sparse_coo',
     'to_sparse_csr',
     'tolist',