@@ -410,7 +410,7 @@ def transform_input(inputs):
410
410
marginalized_rv .type , dependent_logps
411
411
)
412
412
413
- rv_shape = constant_fold (tuple (marginalized_rv .shape ))
413
+ rv_shape = constant_fold (tuple (marginalized_rv .shape ), raise_not_constant = False )
414
414
rv_domain = get_domain_of_finite_discrete_rv (marginalized_rv )
415
415
rv_domain_tensor = pt .moveaxis (
416
416
pt .full (
@@ -579,6 +579,15 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
579
579
return True
580
580
581
581
582
+ from pytensor .graph .basic import graph_inputs
583
+
584
+
585
+ def collect_shared_vars (outputs , blockers ):
586
+ return [
587
+ inp for inp in graph_inputs (outputs , blockers = blockers ) if isinstance (inp , SharedVariable )
588
+ ]
589
+
590
+
582
591
def replace_finite_discrete_marginal_subgraph (fgraph , rv_to_marginalize , all_rvs ):
583
592
# TODO: This should eventually be integrated in a more general routine that can
584
593
# identify other types of supported marginalization, of which finite discrete
@@ -621,27 +630,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
621
630
rvs_to_marginalize = [rv_to_marginalize , * dependent_rvs ]
622
631
623
632
outputs = rvs_to_marginalize
624
- # Clone replace inner RV rng inputs so that we can be sure of the update order
625
- # replace_inputs = {rng: rng.type() for rng in updates_rvs_to_marginalize.keys()}
626
- # Clone replace outter RV inputs, so that their shared RNGs don't make it into
627
- # the inner graph of the marginalized RVs
628
- # FIXME: This shouldn't be needed!
629
- replace_inputs = {}
630
- replace_inputs .update ({input_rv : input_rv .type () for input_rv in input_rvs })
631
- cloned_outputs = clone_replace (outputs , replace = replace_inputs )
633
+ # We are strict about shared variables in SymbolicRandomVariables
634
+ inputs = input_rvs + collect_shared_vars (rvs_to_marginalize , blockers = input_rvs )
632
635
633
636
if isinstance (rv_to_marginalize .owner .op , DiscreteMarkovChain ):
634
637
marginalize_constructor = DiscreteMarginalMarkovChainRV
635
638
else :
636
639
marginalize_constructor = FiniteDiscreteMarginalRV
637
640
638
641
marginalization_op = marginalize_constructor (
639
- inputs = list ( replace_inputs . values ()) ,
640
- outputs = cloned_outputs ,
642
+ inputs = inputs ,
643
+ outputs = outputs ,
641
644
ndim_supp = ndim_supp ,
642
645
)
643
646
644
- marginalized_rvs = marginalization_op (* replace_inputs . keys () )
647
+ marginalized_rvs = marginalization_op (* inputs )
645
648
fgraph .replace_all (tuple (zip (rvs_to_marginalize , marginalized_rvs )))
646
649
return rvs_to_marginalize , marginalized_rvs
647
650
0 commit comments