Skip to content

Commit 4617c1b

Browse files
authored
fix bug of paddle.to_tensor and paddle.moveaxis (#39662)
* fix bug of paddle.to_tensor and paddle.moveaxis * fix CI
1 parent 69ab270 commit 4617c1b

File tree

4 files changed

+31
-19
lines changed

4 files changed

+31
-19
lines changed

python/paddle/fluid/tests/unittests/test_transpose_op.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,14 @@ def test_moveaxis2(self):
423423
self.assertEqual(np.array_equal(out.numpy(), expected), True)
424424
paddle.enable_static()
425425

426+
def test_moveaxis3(self):
427+
paddle.disable_static()
428+
x = paddle.to_tensor(
429+
[[1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j]])
430+
out = x.moveaxis(0, 1)
431+
self.assertEqual(out.shape, [2, 3])
432+
paddle.enable_static()
433+
426434
def test_error(self):
427435
x = paddle.randn([2, 3, 4, 5])
428436
# src must have the same number with dst

python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def _test_place(place):
5151
np.array_equal(x.numpy(), np.array([1.2], 'float16')))
5252
self.assertEqual(x.dtype, core.VarDesc.VarType.FP16)
5353

54+
# set_default_dtype take effect on int
55+
x = paddle.to_tensor(1, place=place)
56+
self.assertTrue(x.dtype, core.VarDesc.VarType.INT64)
57+
5458
# set_default_dtype take effect on float
5559
x = paddle.to_tensor(1.2, place=place, stop_gradient=False)
5660
self.assertTrue(

python/paddle/tensor/creation.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
110110
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace"
111111
)
112112

113-
#Todo(zhouwei): Support allocate tensor on any other specified card
114-
if isinstance(place, core.CUDAPlace) and isinstance(
115-
_current_expected_place(), core.CUDAPlace) and place._get_device_id(
116-
) != _current_expected_place()._get_device_id():
117-
place = _current_expected_place()
118-
119113
if not isinstance(data, np.ndarray):
120114

121115
def _handle_dtype(data, dtype):
@@ -139,7 +133,7 @@ def _handle_dtype(data, dtype):
139133
data.stop_gradient = stop_gradient
140134
return data
141135
elif isinstance(data, (core.LoDTensor, core.Tensor)):
142-
# Note(zhouwei25): should't expose it to users, just for internal use.
136+
# should't expose it to users, just for internal use.
143137
# convert core.Tensor/core.LoDTensor to VarBase first
144138
# Currenly, there is no copy when places are same
145139
data = paddle.Tensor(data)
@@ -152,15 +146,20 @@ def _handle_dtype(data, dtype):
152146
raise TypeError(
153147
"Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor".
154148
format(type(data)))
155-
if not dtype and data.dtype in [
156-
'float16', 'float32', 'float64', 'complex64', 'complex128'
157-
]:
158-
default_type = paddle.get_default_dtype()
159-
if np.iscomplexobj(data):
160-
default_type = 'complex64' if default_type in [
161-
'float16', 'float32'
162-
] else 'complex128'
163-
data = data.astype(default_type)
149+
if not dtype:
150+
if data.dtype in [
151+
'float16', 'float32', 'float64', 'complex64', 'complex128'
152+
]:
153+
default_type = paddle.get_default_dtype()
154+
if np.iscomplexobj(data):
155+
default_type = 'complex64' if default_type in [
156+
'float16', 'float32'
157+
] else 'complex128'
158+
data = data.astype(default_type)
159+
# Windows default type is 'int32', while Linux/Mac is 'int64'. Unify they.
160+
if data.dtype in ['int32']:
161+
default_type = "int64"
162+
data = data.astype(default_type)
164163

165164
if dtype and convert_dtype(dtype) != data.dtype:
166165
data = data.astype(convert_dtype(dtype))

python/paddle/tensor/manipulation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2737,9 +2737,10 @@ def moveaxis(x, source, destination, name=None):
27372737
out, _ = _C_ops.transpose2(x, 'axis', perm)
27382738
return out
27392739

2740-
check_variable_and_dtype(
2741-
x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
2742-
'moveaxis')
2740+
check_variable_and_dtype(x, 'x', [
2741+
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
2742+
'complex128'
2743+
], 'moveaxis')
27432744

27442745
helper = LayerHelper('moveaxis', **locals())
27452746
out = helper.create_variable_for_type_inference(x.dtype)

0 commit comments

Comments
 (0)