@@ -110,12 +110,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
110
110
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace"
111
111
)
112
112
113
- #Todo(zhouwei): Support allocate tensor on any other specified card
114
- if isinstance (place , core .CUDAPlace ) and isinstance (
115
- _current_expected_place (), core .CUDAPlace ) and place ._get_device_id (
116
- ) != _current_expected_place ()._get_device_id ():
117
- place = _current_expected_place ()
118
-
119
113
if not isinstance (data , np .ndarray ):
120
114
121
115
def _handle_dtype (data , dtype ):
@@ -139,7 +133,7 @@ def _handle_dtype(data, dtype):
139
133
data .stop_gradient = stop_gradient
140
134
return data
141
135
elif isinstance (data , (core .LoDTensor , core .Tensor )):
142
- # Note(zhouwei25): should't expose it to users, just for internal use.
136
+ # should't expose it to users, just for internal use.
143
137
# convert core.Tensor/core.LoDTensor to VarBase first
144
138
# Currenly, there is no copy when places are same
145
139
data = paddle .Tensor (data )
@@ -152,15 +146,20 @@ def _handle_dtype(data, dtype):
152
146
raise TypeError (
153
147
"Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor" .
154
148
format (type (data )))
155
- if not dtype and data .dtype in [
156
- 'float16' , 'float32' , 'float64' , 'complex64' , 'complex128'
157
- ]:
158
- default_type = paddle .get_default_dtype ()
159
- if np .iscomplexobj (data ):
160
- default_type = 'complex64' if default_type in [
161
- 'float16' , 'float32'
162
- ] else 'complex128'
163
- data = data .astype (default_type )
149
+ if not dtype :
150
+ if data .dtype in [
151
+ 'float16' , 'float32' , 'float64' , 'complex64' , 'complex128'
152
+ ]:
153
+ default_type = paddle .get_default_dtype ()
154
+ if np .iscomplexobj (data ):
155
+ default_type = 'complex64' if default_type in [
156
+ 'float16' , 'float32'
157
+ ] else 'complex128'
158
+ data = data .astype (default_type )
159
+ # Windows default type is 'int32', while Linux/Mac is 'int64'. Unify they.
160
+ if data .dtype in ['int32' ]:
161
+ default_type = "int64"
162
+ data = data .astype (default_type )
164
163
165
164
if dtype and convert_dtype (dtype ) != data .dtype :
166
165
data = data .astype (convert_dtype (dtype ))
0 commit comments