@@ -226,7 +226,7 @@ def test_hetero_mixture_binomial(p_val, size):
226226 (),
227227 0 ,
228228 ),
229- # Degenerate vector mixture components, scalar index
229+ # Degenerate vector mixture components, scalar index along join axis
230230 (
231231 (
232232 np .array ([0 ], dtype = pytensor .config .floatX ),
@@ -246,7 +246,27 @@ def test_hetero_mixture_binomial(p_val, size):
246246 (),
247247 0 ,
248248 ),
249- # Scalar mixture components, vector index
249+ # Degenerate vector mixture components, scalar index along join axis (axis=1)
250+ (
251+ (
252+ np .array ([0 ], dtype = pytensor .config .floatX ),
253+ np .array (1 , dtype = pytensor .config .floatX ),
254+ ),
255+ (
256+ np .array ([0.5 ], dtype = pytensor .config .floatX ),
257+ np .array (0.5 , dtype = pytensor .config .floatX ),
258+ ),
259+ (
260+ np .array ([100 ], dtype = pytensor .config .floatX ),
261+ np .array (1 , dtype = pytensor .config .floatX ),
262+ ),
263+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
264+ None ,
265+ (),
266+ (slice (None ),),
267+ 1 ,
268+ ),
269+ # Vector mixture components, scalar index along the join axis
250270 (
251271 (
252272 np .array (0 , dtype = pytensor .config .floatX ),
@@ -261,49 +281,72 @@ def test_hetero_mixture_binomial(p_val, size):
261281 np .array (1 , dtype = pytensor .config .floatX ),
262282 ),
263283 np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
284+ (4 ,),
264285 (),
265- (6 ,),
266286 (),
267287 0 ,
268288 ),
289+ # Vector mixture components, scalar index along the join axis (axis=1)
269290 (
270291 (
271- np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
292+ np .array (0 , dtype = pytensor .config .floatX ),
272293 np .array (1 , dtype = pytensor .config .floatX ),
273294 ),
274295 (
275- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
276- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
296+ np .array (0.5 , dtype = pytensor .config .floatX ),
297+ np .array (0.5 , dtype = pytensor .config .floatX ),
277298 ),
278299 (
279- np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
300+ np .array (100 , dtype = pytensor .config .floatX ),
280301 np .array (1 , dtype = pytensor .config .floatX ),
281302 ),
282- np .array ([[0.1 , 0.5 , 0.4 ], [0.4 , 0.1 , 0.5 ]], dtype = pytensor .config .floatX ),
283- (2 ,),
284- (2 ,),
303+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
304+ (4 ,),
305+ (),
306+ (slice (None ),),
307+ 1 ,
308+ ),
309+ # Matrix components, scalar index along first axis
310+ (
311+ (
312+ np .array (0 , dtype = pytensor .config .floatX ),
313+ np .array (1 , dtype = pytensor .config .floatX ),
314+ ),
315+ (
316+ np .array (0.5 , dtype = pytensor .config .floatX ),
317+ np .array (0.5 , dtype = pytensor .config .floatX ),
318+ ),
319+ (
320+ np .array (100 , dtype = pytensor .config .floatX ),
321+ np .array (1 , dtype = pytensor .config .floatX ),
322+ ),
323+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
324+ (2 , 3 ),
325+ (),
285326 (),
286327 0 ,
287328 ),
329+ # Scalar mixture components, vector index along first axis
288330 (
289331 (
290- np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
332+ np .array (0 , dtype = pytensor .config .floatX ),
291333 np .array (1 , dtype = pytensor .config .floatX ),
292334 ),
293335 (
294- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
295- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
336+ np .array (0.5 , dtype = pytensor .config .floatX ),
337+ np .array (0.5 , dtype = pytensor .config .floatX ),
296338 ),
297339 (
298- np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
340+ np .array (100 , dtype = pytensor .config .floatX ),
299341 np .array (1 , dtype = pytensor .config .floatX ),
300342 ),
301- np .array ([[ 0.1 , 0.5 , 0.4 ], [ 0.4 , 0.1 , 0.5 ] ], dtype = pytensor .config .floatX ),
302- None ,
303- None ,
343+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
344+ () ,
345+ ( 6 ,) ,
304346 (),
305347 0 ,
306348 ),
349+ # Vector mixture components, vector index along first axis
307350 (
308351 (
309352 np .array (0 , dtype = pytensor .config .floatX ),
@@ -320,10 +363,31 @@ def test_hetero_mixture_binomial(p_val, size):
320363 np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
321364 (2 ,),
322365 (2 ,),
323- (),
366+ (slice ( None ), ),
324367 0 ,
325368 ),
326- # Same as before but with degenerate vector parameters
369+ # Vector mixture components, vector index along last axis
370+ pytest .param (
371+ (
372+ np .array (0 , dtype = pytensor .config .floatX ),
373+ np .array (1 , dtype = pytensor .config .floatX ),
374+ ),
375+ (
376+ np .array (0.5 , dtype = pytensor .config .floatX ),
377+ np .array (0.5 , dtype = pytensor .config .floatX ),
378+ ),
379+ (
380+ np .array (100 , dtype = pytensor .config .floatX ),
381+ np .array (1 , dtype = pytensor .config .floatX ),
382+ ),
383+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
384+ (2 ,),
385+ (4 ,),
386+ (slice (None ),),
387+ 1 ,
388+ marks = pytest .mark .xfail (IndexError , reason = "Bug in AdvancedIndex Mixture logprob" ),
389+ ),
390+ # Vector mixture components (with degenerate vector parameters), vector index along first axis
327391 (
328392 (
329393 np .array ([0 ], dtype = pytensor .config .floatX ),
@@ -343,45 +407,48 @@ def test_hetero_mixture_binomial(p_val, size):
343407 (),
344408 0 ,
345409 ),
410+ # Vector mixture components (with vector parameters), vector index along first axis
346411 (
347412 (
348- np .array (0 , dtype = pytensor .config .floatX ),
413+ np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
349414 np .array (1 , dtype = pytensor .config .floatX ),
350415 ),
351416 (
352- np .array (0.5 , dtype = pytensor .config .floatX ),
353- np .array (0.5 , dtype = pytensor .config .floatX ),
417+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
418+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
354419 ),
355420 (
356- np .array (100 , dtype = pytensor .config .floatX ),
421+ np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
357422 np .array (1 , dtype = pytensor .config .floatX ),
358423 ),
359- np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
360- (2 , 3 ),
361- (2 , 3 ),
424+ np .array ([[ 0.1 , 0.5 , 0.4 ], [ 0.4 , 0.1 , 0.5 ] ], dtype = pytensor .config .floatX ),
425+ (2 ,),
426+ (2 ,),
362427 (),
363428 0 ,
364429 ),
430+ # Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
365431 (
366432 (
367- np .array (0 , dtype = pytensor .config .floatX ),
433+ np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
368434 np .array (1 , dtype = pytensor .config .floatX ),
369435 ),
370436 (
371- np .array (0.5 , dtype = pytensor .config .floatX ),
372- np .array (0.5 , dtype = pytensor .config .floatX ),
437+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
438+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
373439 ),
374440 (
375- np .array (100 , dtype = pytensor .config .floatX ),
441+ np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
376442 np .array (1 , dtype = pytensor .config .floatX ),
377443 ),
378- np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
379- ( 2 , 3 ) ,
380- () ,
444+ np .array ([[ 0.1 , 0.5 , 0.4 ], [ 0.4 , 0.1 , 0.5 ] ], dtype = pytensor .config .floatX ),
445+ None ,
446+ None ,
381447 (),
382448 0 ,
383449 ),
384- pytest .param (
450+ # Matrix mixture components, matrix index
451+ (
385452 (
386453 np .array (0 , dtype = pytensor .config .floatX ),
387454 np .array (1 , dtype = pytensor .config .floatX ),
@@ -395,12 +462,12 @@ def test_hetero_mixture_binomial(p_val, size):
395462 np .array (1 , dtype = pytensor .config .floatX ),
396463 ),
397464 np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
398- (3 ,),
399- (3 ,),
400- (slice (None ),),
401- 1 ,
402- marks = pytest .mark .xfail (IndexError , reason = "Bug in AdvancedIndex Mixture logprob" ),
465+ (2 , 3 ),
466+ (2 , 3 ),
467+ (),
468+ 0 ,
403469 ),
470+ # Vector components, matrix indexing (constant along first dimension, then random)
404471 (
405472 (
406473 np .array (0 , dtype = pytensor .config .floatX ),
@@ -420,6 +487,7 @@ def test_hetero_mixture_binomial(p_val, size):
420487 (np .arange (5 ),),
421488 0 ,
422489 ),
490+ # Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
423491 (
424492 (
425493 np .array (0 , dtype = pytensor .config .floatX ),
0 commit comments