@@ -116,19 +116,21 @@ def test_properties():
116
116
117
117
@pytest .mark .parametrize ("func" , [bool , float , int ])
118
118
@pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
119
- def test_copy_scalar_with_func (func , shape ):
120
- X = dpt .usm_ndarray (shape )
121
- Y = np .arange (1 , X .size + 1 , dtype = X .dtype )
122
- X .usm_data .copy_from_host (Y .view ("|u1" ))
119
+ @pytest .mark .parametrize ("dtype" , ["|b1" , "|f8" , "|i8" ])
120
+ def test_copy_scalar_with_func (func , shape , dtype ):
121
+ X = dpt .usm_ndarray (shape , dtype = dtype )
122
+ Y = np .arange (1 , X .size + 1 , dtype = dtype ).reshape (shape )
123
+ X .usm_data .copy_from_host (Y .reshape (- 1 ).view ("|u1" ))
123
124
assert func (X ) == func (Y )
124
125
125
126
126
127
@pytest .mark .parametrize ("method" , ["__bool__" , "__float__" , "__int__" ])
127
128
@pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
128
- def test_copy_scalar_with_method (method , shape ):
129
- X = dpt .usm_ndarray (shape )
130
- Y = np .arange (1 , X .size + 1 , dtype = X .dtype )
131
- X .usm_data .copy_from_host (Y .view ("|u1" ))
129
+ @pytest .mark .parametrize ("dtype" , ["|b1" , "|f8" , "|i8" ])
130
+ def test_copy_scalar_with_method (method , shape , dtype ):
131
+ X = dpt .usm_ndarray (shape , dtype = dtype )
132
+ Y = np .arange (1 , X .size + 1 , dtype = dtype ).reshape (shape )
133
+ X .usm_data .copy_from_host (Y .reshape (- 1 ).view ("|u1" ))
132
134
assert getattr (X , method )() == getattr (Y , method )()
133
135
134
136
@@ -140,6 +142,34 @@ def test_copy_scalar_invalid_shape(func, shape):
140
142
func (X )
141
143
142
144
145
+ @pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
146
+ @pytest .mark .parametrize ("index_dtype" , ["|i8" ])
147
+ def test_usm_ndarray_as_index (shape , index_dtype ):
148
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
149
+ Xnp = np .arange (1 , X .size + 1 , dtype = index_dtype ).reshape (shape )
150
+ X .usm_data .copy_from_host (Xnp .reshape (- 1 ).view ("|u1" ))
151
+ Y = np .arange (X .size + 1 )
152
+ assert Y [X ] == Y [1 ]
153
+
154
+
155
+ @pytest .mark .parametrize ("shape" , [(2 ,), (1 , 2 ), (3 , 4 , 5 ), (0 ,)])
156
+ @pytest .mark .parametrize ("index_dtype" , ["|i8" ])
157
+ def test_usm_ndarray_as_index_invalid_shape (shape , index_dtype ):
158
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
159
+ Y = np .arange (X .size + 1 )
160
+ with pytest .raises (IndexError ):
161
+ Y [X ]
162
+
163
+
164
+ @pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
165
+ @pytest .mark .parametrize ("index_dtype" , ["|f8" ])
166
+ def test_usm_ndarray_as_index_invalid_dtype (shape , index_dtype ):
167
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
168
+ Y = np .arange (X .size + 1 )
169
+ with pytest .raises (IndexError ):
170
+ Y [X ]
171
+
172
+
143
173
@pytest .mark .parametrize (
144
174
"ind" ,
145
175
[
0 commit comments