@@ -91,7 +91,7 @@ def create_mix_model(size, axis):
9191 with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
9292 factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv , X_rv : x_vv })
9393
94- with pytest .raises (NotImplementedError ):
94+ with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
9595 axis_at = at .lscalar ("axis" )
9696 axis_at .tag .test_value = 0
9797 env = create_mix_model ((2 ,), axis_at )
@@ -139,17 +139,19 @@ def test_compute_test_value(op_constructor):
139139
140140
141141@pytest .mark .parametrize (
142- "p_val, size" ,
142+ "p_val, size, supported " ,
143143 [
144- (np .array (0.0 , dtype = pytensor .config .floatX ), ()),
145- (np .array (1.0 , dtype = pytensor .config .floatX ), ()),
146- (np .array (0.0 , dtype = pytensor .config .floatX ), (2 ,)),
147- (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 1 )),
148- (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 3 )),
149- (np .array ([0.1 , 0.9 ], dtype = pytensor .config .floatX ), (2 , 3 )),
144+ (np .array (0.0 , dtype = pytensor .config .floatX ), (), True ),
145+ (np .array (1.0 , dtype = pytensor .config .floatX ), (), True ),
146+ (np .array ([0.1 , 0.9 ], dtype = pytensor .config .floatX ), (), True ),
147+ # The cases belowe are not supported because they may pick repeated values via AdvancedIndexing
148+ (np .array (0.0 , dtype = pytensor .config .floatX ), (2 ,), False ),
149+ (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 1 ), False ),
150+ (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 3 ), False ),
151+ (np .array ([0.1 , 0.9 ], dtype = pytensor .config .floatX ), (2 , 3 ), False ),
150152 ],
151153)
152- def test_hetero_mixture_binomial (p_val , size ):
154+ def test_hetero_mixture_binomial (p_val , size , supported ):
153155 srng = at .random .RandomStream (29833 )
154156
155157 X_rv = srng .normal (0 , 1 , size = size , name = "X" )
@@ -175,7 +177,12 @@ def test_hetero_mixture_binomial(p_val, size):
175177 m_vv = M_rv .clone ()
176178 m_vv .name = "m"
177179
178- M_logp = joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
180+ if supported :
181+ M_logp = joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
182+ else :
183+ with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
184+ joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
185+ return
179186
180187 M_logp_fn = pytensor .function ([p_at , m_vv , i_vv ], M_logp )
181188
@@ -204,9 +211,9 @@ def test_hetero_mixture_binomial(p_val, size):
204211
205212
206213@pytest .mark .parametrize (
207- "X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis" ,
214+ "X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported " ,
208215 [
209- # Scalar mixture components, scalar index
216+ # Scalar components, scalar index
210217 (
211218 (
212219 np .array (0 , dtype = pytensor .config .floatX ),
@@ -225,6 +232,7 @@ def test_hetero_mixture_binomial(p_val, size):
225232 (),
226233 (),
227234 0 ,
235+ True ,
228236 ),
229237 # Degenerate vector mixture components, scalar index along join axis
230238 (
@@ -245,6 +253,7 @@ def test_hetero_mixture_binomial(p_val, size):
245253 (),
246254 (),
247255 0 ,
256+ True ,
248257 ),
249258 # Degenerate vector mixture components, scalar index along join axis (axis=1)
250259 (
@@ -265,6 +274,7 @@ def test_hetero_mixture_binomial(p_val, size):
265274 (),
266275 (slice (None ),),
267276 1 ,
277+ True ,
268278 ),
269279 # Vector mixture components, scalar index along the join axis
270280 (
@@ -285,6 +295,7 @@ def test_hetero_mixture_binomial(p_val, size):
285295 (),
286296 (),
287297 0 ,
298+ True ,
288299 ),
289300 # Vector mixture components, scalar index along the join axis (axis=1)
290301 (
@@ -305,6 +316,7 @@ def test_hetero_mixture_binomial(p_val, size):
305316 (),
306317 (slice (None ),),
307318 1 ,
319+ True ,
308320 ),
309321 # Vector mixture components, scalar index that mixes across components
310322 pytest .param (
@@ -325,6 +337,7 @@ def test_hetero_mixture_binomial(p_val, size):
325337 (),
326338 (),
327339 1 ,
340+ True ,
328341 marks = pytest .mark .xfail (
329342 AssertionError ,
330343 match = "Arrays are not almost equal to 6 decimals" , # This is ignored, but that's where it should fail!
@@ -350,7 +363,10 @@ def test_hetero_mixture_binomial(p_val, size):
350363 (),
351364 (),
352365 0 ,
366+ True ,
353367 ),
368+ # All the tests below rely on AdvancedIndexing, which is not supported at the moment
369+ # See https://github.com/pymc-devs/pymc/issues/6398
354370 # Scalar mixture components, vector index along first axis
355371 (
356372 (
@@ -370,6 +386,7 @@ def test_hetero_mixture_binomial(p_val, size):
370386 (6 ,),
371387 (),
372388 0 ,
389+ False ,
373390 ),
374391 # Vector mixture components, vector index along first axis
375392 (
@@ -390,9 +407,10 @@ def test_hetero_mixture_binomial(p_val, size):
390407 (2 ,),
391408 (slice (None ),),
392409 0 ,
410+ False ,
393411 ),
394412 # Vector mixture components, vector index along last axis
395- pytest . param (
413+ (
396414 (
397415 np .array (0 , dtype = pytensor .config .floatX ),
398416 np .array (1 , dtype = pytensor .config .floatX ),
@@ -410,7 +428,7 @@ def test_hetero_mixture_binomial(p_val, size):
410428 (4 ,),
411429 (slice (None ),),
412430 1 ,
413- marks = pytest . mark . xfail ( IndexError , reason = "Bug in AdvancedIndex Mixture logprob" ) ,
431+ False ,
414432 ),
415433 # Vector mixture components (with degenerate vector parameters), vector index along first axis
416434 (
@@ -431,6 +449,7 @@ def test_hetero_mixture_binomial(p_val, size):
431449 (2 ,),
432450 (),
433451 0 ,
452+ False ,
434453 ),
435454 # Vector mixture components (with vector parameters), vector index along first axis
436455 (
@@ -451,6 +470,7 @@ def test_hetero_mixture_binomial(p_val, size):
451470 (2 ,),
452471 (),
453472 0 ,
473+ False ,
454474 ),
455475 # Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
456476 (
@@ -471,6 +491,7 @@ def test_hetero_mixture_binomial(p_val, size):
471491 None ,
472492 (),
473493 0 ,
494+ False ,
474495 ),
475496 # Matrix mixture components, matrix index
476497 (
@@ -491,6 +512,7 @@ def test_hetero_mixture_binomial(p_val, size):
491512 (2 , 3 ),
492513 (),
493514 0 ,
515+ False ,
494516 ),
495517 # Vector components, matrix indexing (constant along first dimension, then random)
496518 (
@@ -511,6 +533,7 @@ def test_hetero_mixture_binomial(p_val, size):
511533 (5 ,),
512534 (np .arange (5 ),),
513535 0 ,
536+ False ,
514537 ),
515538 # Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
516539 (
@@ -531,11 +554,12 @@ def test_hetero_mixture_binomial(p_val, size):
531554 (5 ,),
532555 (np .arange (5 ), None ),
533556 0 ,
557+ False ,
534558 ),
535559 ],
536560)
537561def test_hetero_mixture_categorical (
538- X_args , Y_args , Z_args , p_val , comp_size , idx_size , extra_indices , join_axis
562+ X_args , Y_args , Z_args , p_val , comp_size , idx_size , extra_indices , join_axis , supported
539563):
540564 srng = at .random .RandomStream (29833 )
541565
@@ -561,7 +585,12 @@ def test_hetero_mixture_categorical(
561585 m_vv = M_rv .clone ()
562586 m_vv .name = "m"
563587
564- logp_parts = factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
588+ if supported :
589+ logp_parts = factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
590+ else :
591+ with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
592+ factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
593+ return
565594
566595 I_logp_fn = pytensor .function ([p_at , i_vv ], logp_parts [i_vv ])
567596 M_logp_fn = pytensor .function ([m_vv , i_vv ], logp_parts [m_vv ])
@@ -854,7 +883,7 @@ def test_mixture_with_DiracDelta():
854883 Y_rv = dirac_delta (0.0 )
855884 Y_rv .name = "Y"
856885
857- I_rv = srng .categorical ([0.5 , 0.5 ], size = 4 )
886+ I_rv = srng .categorical ([0.5 , 0.5 ], size = 1 )
858887
859888 i_vv = I_rv .clone ()
860889 i_vv .name = "i"
0 commit comments