Skip to content

Commit 96eec95

Browse files
anjali411facebook-github-bot
authored andcommitted
torch.from_numpy for complex dtypes (pytorch#35531)
Summary: Pull Request resolved: pytorch#35531 Differential Revision: D20693581 Pulled By: anjali411 fbshipit-source-id: d53e26b4175452fa00b287efbfceea18104c1364
1 parent f101949 commit 96eec95

File tree

4 files changed

+28
-14
lines changed

4 files changed

+28
-14
lines changed

caffe2/python/pybind_state.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ int CaffeToNumpyType(const TypeMeta& meta) {
113113
{TypeMeta::Id<bool>(), NPY_BOOL},
114114
{TypeMeta::Id<double>(), NPY_DOUBLE},
115115
{TypeMeta::Id<float>(), NPY_FLOAT},
116+
{TypeMeta::Id<std::complex<double>>(), NPY_COMPLEX128},
117+
{TypeMeta::Id<std::complex<float>>(), NPY_COMPLEX64},
116118
{TypeMeta::Id<at::Half>(), NPY_FLOAT16},
117119
{TypeMeta::Id<int>(), NPY_INT},
118120
{TypeMeta::Id<int8_t>(), NPY_INT8},

test/test_torch.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4527,6 +4527,8 @@ def test_from_numpy(self):
45274527
np.double,
45284528
np.float,
45294529
np.float16,
4530+
np.complex64,
4531+
np.complex128,
45304532
np.int64,
45314533
np.int32,
45324534
np.int16,
@@ -4535,22 +4537,29 @@ def test_from_numpy(self):
45354537
np.longlong,
45364538
np.bool,
45374539
]
4540+
complex_dtypes = [
4541+
np.complex64,
4542+
np.complex128,
4543+
]
4544+
45384545
for dtype in dtypes:
45394546
array = np.array([1, 2, 3, 4], dtype=dtype)
45404547
tensor_from_array = torch.from_numpy(array)
45414548
# TODO: change to tensor equality check once HalfTensor
45424549
# implements `==`
45434550
for i in range(len(array)):
45444551
self.assertEqual(tensor_from_array[i], array[i])
4545-
# This is a special test case for Windows
4546-
# https://github.com/pytorch/pytorch/issues/22615
4547-
array2 = array % 2
4548-
tensor_from_array2 = torch.from_numpy(array2)
4549-
for i in range(len(array2)):
4550-
self.assertEqual(tensor_from_array2[i], array2[i])
4552+
# ufunc 'remainder' not supported for complex dtypes
4553+
if dtype not in complex_dtypes:
4554+
# This is a special test case for Windows
4555+
# https://github.com/pytorch/pytorch/issues/22615
4556+
array2 = array % 2
4557+
tensor_from_array2 = torch.from_numpy(array2)
4558+
for i in range(len(array2)):
4559+
self.assertEqual(tensor_from_array2[i], array2[i])
45514560

45524561
# Test unsupported type
4553-
array = np.array([1, 2, 3, 4], dtype=np.complex)
4562+
array = np.array([1, 2, 3, 4], dtype=np.uint16)
45544563
with self.assertRaises(TypeError):
45554564
tensor_from_array = torch.from_numpy(array)
45564565

torch/_torch_docs.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def merge_dicts(*dicts):
7070
returned tensor. Default: ``False``.
7171
pin_memory (bool, optional): If set, returned tensor would be allocated in
7272
the pinned memory. Works only for CPU tensors. Default: ``False``.
73-
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
73+
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
7474
returned Tensor. Default: ``torch.contiguous_format``.
7575
"""))
7676

@@ -86,7 +86,7 @@ def merge_dicts(*dicts):
8686
returned tensor. Default: ``False``.
8787
pin_memory (bool, optional): If set, returned tensor would be allocated in
8888
the pinned memory. Works only for CPU tensors. Default: ``False``.
89-
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
89+
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
9090
returned Tensor. Default: ``torch.preserve_format``.
9191
""")
9292

@@ -2199,8 +2199,9 @@ def merge_dicts(*dicts):
21992199
tensor is not resizable.
22002200
22012201
It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``,
2202-
``numpy.float32``, ``numpy.float16``, ``numpy.int64``, ``numpy.int32``,
2203-
``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, and ``numpy.bool``.
2202+
``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``,
2203+
``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``,
2204+
and ``numpy.bool``.
22042205
22052206
Example::
22062207

torch/csrc/utils/tensor_numpy.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,11 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
190190

191191
int aten_to_numpy_dtype(const ScalarType scalar_type) {
192192
switch (scalar_type) {
193-
case kComplexDouble: return NPY_COMPLEX128;
194-
case kComplexFloat: return NPY_COMPLEX64;
195193
case kDouble: return NPY_DOUBLE;
196194
case kFloat: return NPY_FLOAT;
197195
case kHalf: return NPY_HALF;
196+
case kComplexDouble: return NPY_COMPLEX128;
197+
case kComplexFloat: return NPY_COMPLEX64;
198198
case kLong: return NPY_INT64;
199199
case kInt: return NPY_INT32;
200200
case kShort: return NPY_INT16;
@@ -211,6 +211,8 @@ ScalarType numpy_dtype_to_aten(int dtype) {
211211
case NPY_DOUBLE: return kDouble;
212212
case NPY_FLOAT: return kFloat;
213213
case NPY_HALF: return kHalf;
214+
case NPY_COMPLEX64: return kComplexFloat;
215+
case NPY_COMPLEX128: return kComplexDouble;
214216
case NPY_INT16: return kShort;
215217
case NPY_INT8: return kChar;
216218
case NPY_UINT8: return kByte;
@@ -236,7 +238,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
236238
if (!pytype) throw python_error();
237239
throw TypeError(
238240
"can't convert np.ndarray of type %s. The only supported types are: "
239-
"float64, float32, float16, int64, int32, int16, int8, uint8, and bool.",
241+
"float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
240242
((PyTypeObject*)pytype.get())->tp_name);
241243
}
242244

0 commit comments

Comments
 (0)