@@ -117,19 +117,21 @@ def test_properties():
117
117
118
118
@pytest .mark .parametrize ("func" , [bool , float , int ])
119
119
@pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
120
- def test_copy_scalar_with_func (func , shape ):
121
- X = dpt .usm_ndarray (shape )
122
- Y = np .arange (1 , X .size + 1 , dtype = X .dtype )
123
- X .usm_data .copy_from_host (Y .view ("|u1" ))
120
+ @pytest .mark .parametrize ("dtype" , ["|b1" , "|f8" , "|i8" ])
121
+ def test_copy_scalar_with_func (func , shape , dtype ):
122
+ X = dpt .usm_ndarray (shape , dtype = dtype )
123
+ Y = np .arange (1 , X .size + 1 , dtype = dtype ).reshape (shape )
124
+ X .usm_data .copy_from_host (Y .reshape (- 1 ).view ("|u1" ))
124
125
assert func (X ) == func (Y )
125
126
126
127
127
128
@pytest .mark .parametrize ("method" , ["__bool__" , "__float__" , "__int__" ])
128
129
@pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
129
- def test_copy_scalar_with_method (method , shape ):
130
- X = dpt .usm_ndarray (shape )
131
- Y = np .arange (1 , X .size + 1 , dtype = X .dtype )
132
- X .usm_data .copy_from_host (Y .view ("|u1" ))
130
+ @pytest .mark .parametrize ("dtype" , ["|b1" , "|f8" , "|i8" ])
131
+ def test_copy_scalar_with_method (method , shape , dtype ):
132
+ X = dpt .usm_ndarray (shape , dtype = dtype )
133
+ Y = np .arange (1 , X .size + 1 , dtype = dtype ).reshape (shape )
134
+ X .usm_data .copy_from_host (Y .reshape (- 1 ).view ("|u1" ))
133
135
assert getattr (X , method )() == getattr (Y , method )()
134
136
135
137
@@ -141,6 +143,34 @@ def test_copy_scalar_invalid_shape(func, shape):
141
143
func (X )
142
144
143
145
146
+ @pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
147
+ @pytest .mark .parametrize ("index_dtype" , ["|i8" ])
148
+ def test_usm_ndarray_as_index (shape , index_dtype ):
149
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
150
+ Xnp = np .arange (1 , X .size + 1 , dtype = index_dtype ).reshape (shape )
151
+ X .usm_data .copy_from_host (Xnp .reshape (- 1 ).view ("|u1" ))
152
+ Y = np .arange (X .size + 1 )
153
+ assert Y [X ] == Y [1 ]
154
+
155
+
156
+ @pytest .mark .parametrize ("shape" , [(2 ,), (1 , 2 ), (3 , 4 , 5 ), (0 ,)])
157
+ @pytest .mark .parametrize ("index_dtype" , ["|i8" ])
158
+ def test_usm_ndarray_as_index_invalid_shape (shape , index_dtype ):
159
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
160
+ Y = np .arange (X .size + 1 )
161
+ with pytest .raises (IndexError ):
162
+ Y [X ]
163
+
164
+
165
+ @pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
166
+ @pytest .mark .parametrize ("index_dtype" , ["|f8" ])
167
+ def test_usm_ndarray_as_index_invalid_dtype (shape , index_dtype ):
168
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
169
+ Y = np .arange (X .size + 1 )
170
+ with pytest .raises (IndexError ):
171
+ Y [X ]
172
+
173
+
144
174
@pytest .mark .parametrize (
145
175
"ind" ,
146
176
[
0 commit comments