@@ -193,16 +193,16 @@ def test_mean_over_axis_0_unsupported_out_types(
193
193
input = dpt .empty ((height , width ), dtype = input_type , device = device )
194
194
output = dpt .empty (width , dtype = output_type , device = device )
195
195
196
- if func (input , output ):
197
- print (output_type )
198
196
assert func (input , output ) is None
199
197
200
198
201
199
@pytest .mark .parametrize (
202
200
"func, device, input_type, output_type" ,
203
201
product (mean_sum , all_devices , [dpt .float32 ], [dpt .float32 ]),
204
202
)
205
- def test_mean_over_axis_0_f_contig_input (func , device , input_type , output_type ):
203
+ def test_mean_sum_over_axis_0_f_contig_input (
204
+ func , device , input_type , output_type
205
+ ):
206
206
skip_unsupported (device , input_type )
207
207
skip_unsupported (device , output_type )
208
208
@@ -212,16 +212,14 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
212
212
input = dpt .empty ((height , width ), dtype = input_type , device = device ).T
213
213
output = dpt .empty (width , dtype = output_type , device = device )
214
214
215
- if func (input , output ):
216
- print (output_type )
217
215
assert func (input , output ) is None
218
216
219
217
220
218
@pytest .mark .parametrize (
221
219
"func, device, input_type, output_type" ,
222
220
product (mean_sum , all_devices , [dpt .float32 ], [dpt .float32 ]),
223
221
)
224
- def test_mean_over_axis_0_f_contig_output (
222
+ def test_mean_sum_over_axis_0_f_contig_output (
225
223
func , device , input_type , output_type
226
224
):
227
225
skip_unsupported (device , input_type )
@@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output(
230
228
height = 1
231
229
width = 10
232
230
233
- input = dpt .empty ((height , 10 ), dtype = input_type , device = device )
234
- output = dpt .empty (20 , dtype = output_type , device = device )[::2 ]
231
+ input = dpt .empty ((height , width ), dtype = input_type , device = device )
232
+ output = dpt .empty (width * 2 , dtype = output_type , device = device )[::2 ]
233
+
234
+ assert func (input , output ) is None
235
+
236
+
237
+ @pytest .mark .parametrize (
238
+ "func, device, input_type, output_type" ,
239
+ product (mean_sum , all_devices , [dpt .float32 ], [dpt .float32 , dpt .float64 ]),
240
+ )
241
+ def test_mean_sum_over_axis_0_big_output (func , device , input_type , output_type ):
242
+ skip_unsupported (device , input_type )
243
+ skip_unsupported (device , output_type )
244
+
245
+ local_mem_size = device .local_mem_size
246
+ height = 1
247
+ width = 1 + local_mem_size // output_type .itemsize
248
+
249
+ input = dpt .empty ((height , width ), dtype = input_type , device = device )
250
+ output = dpt .empty (width , dtype = output_type , device = device )
235
251
236
- if func (input , output ):
237
- print (output_type )
238
252
assert func (input , output ) is None
0 commit comments