@@ -657,6 +657,82 @@ def check_selfconsistency_discrete_logcdf(
657657 )
658658
659659
660+ def check_selfconsistency_continuous_icdf (
661+ distribution : Distribution ,
662+ paramdomains : Dict [str , Domain ],
663+ decimal : Optional [int ] = None ,
664+ n_samples : int = 100 ,
665+ ) -> None :
666+ """
667+ Check that the icdf and logcdf functions of the distribution are consistent for a sample of probability values.
668+ """
669+ if decimal is None :
670+ decimal = select_by_precision (float64 = 6 , float32 = 3 )
671+
672+ dist = create_dist_from_paramdomains (distribution , paramdomains )
673+ value = dist .type ()
674+ value .name = "value"
675+
676+ dist_icdf = icdf (dist , value )
677+ dist_icdf_fn = pytensor .function (list (inputvars (dist_icdf )), dist_icdf )
678+
679+ dist_logcdf = logcdf (dist , value )
680+ dist_logcdf_fn = compile_pymc (list (inputvars (dist_logcdf )), dist_logcdf )
681+
682+ domains = paramdomains .copy ()
683+ domains ["value" ] = Domain (np .linspace (0 , 1 , 10 ))
684+
685+ for point in product (domains , n_samples = n_samples ):
686+ point = dict (point )
687+ value = point .pop ("value" )
688+
689+ with pytensor .config .change_flags (mode = Mode ("py" )):
690+ npt .assert_almost_equal (
691+ value ,
692+ np .exp (dist_logcdf_fn (** point , value = dist_icdf_fn (** point , value = value ))),
693+ decimal = decimal ,
694+ err_msg = f"point: { point } , value: { value } " ,
695+ )
696+
697+
698+ def check_selfconsistency_discrete_icdf (
699+ distribution : Distribution ,
700+ domain : Domain ,
701+ paramdomains : Dict [str , Domain ],
702+ n_samples : int = 100 ,
703+ ) -> None :
704+ """
705+ Check that the icdf and logcdf functions of the distribution are
706+ consistent for a sample of values in the domain of the
707+ distribution.
708+ """
709+ dist = create_dist_from_paramdomains (distribution , paramdomains )
710+
711+ value = pt .TensorType (dtype = "float64" , shape = [])("value" )
712+
713+ dist_icdf = icdf (dist , value )
714+ dist_icdf_fn = pytensor .function (list (inputvars (dist_icdf )), dist_icdf )
715+
716+ dist_logcdf = logcdf (dist , value )
717+ dist_logcdf_fn = compile_pymc (list (inputvars (dist_logcdf )), dist_logcdf )
718+
719+ domains = paramdomains .copy ()
720+ domains ["value" ] = domain
721+
722+ for point in product (domains , n_samples = n_samples ):
723+ point = dict (point )
724+ value = point .pop ("value" )
725+
726+ with pytensor .config .change_flags (mode = Mode ("py" )):
727+ expected_value = value
728+ computed_value = dist_icdf_fn (
729+ ** point , value = np .exp (dist_logcdf_fn (** point , value = value ))
730+ )
731+ assert (
732+ expected_value == computed_value
733+ ), f"expected_value = { expected_value } , computed_value = { computed_value } , { point } "
734+
735+
660736def assert_moment_is_expected (model , expected , check_finite_logp = True ):
661737 fn = make_initial_point_fn (
662738 model = model ,
0 commit comments