|
27 | 27 | isinf,
|
28 | 28 | log,
|
29 | 29 | log1p,
|
| 30 | + same_out, |
30 | 31 | switch,
|
31 | 32 | true_div,
|
32 | 33 | upcast,
|
33 | 34 | upgrade_to_float,
|
34 | 35 | upgrade_to_float64,
|
35 | 36 | upgrade_to_float_no_complex,
|
36 | 37 | )
|
| 38 | +from pytensor.scalar.scan import ScalarScanOp |
37 | 39 |
|
38 | 40 |
|
39 | 41 | class Erf(UnaryScalarOp):
|
@@ -751,87 +753,180 @@ def c_code(self, *args, **kwargs):
|
751 | 753 | gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
|
752 | 754 |
|
753 | 755 |
|
754 |
| -class GammaIncCDer(BinaryScalarOp): |
755 |
| - """ |
756 |
| - Gradient of the the regularized upper gamma function (Q) wrt to the first |
757 |
| - argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp` |
758 |
| - """ |
| 756 | +# class GammaIncCDer(BinaryScalarOp): |
| 757 | +# """ |
| 758 | +# Gradient of the the regularized upper gamma function (Q) wrt to the first |
| 759 | +# argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp` |
| 760 | +# """ |
| 761 | +# |
| 762 | +# @staticmethod |
| 763 | +# def st_impl(k, x): |
| 764 | +# gamma_k = scipy.special.gamma(k) |
| 765 | +# digamma_k = scipy.special.digamma(k) |
| 766 | +# log_x = np.log(x) |
| 767 | +# |
| 768 | +# # asymptotic expansion http://dlmf.nist.gov/8.11#E2 |
| 769 | +# if (x >= k) and (x >= 8): |
| 770 | +# S = 0 |
| 771 | +# k_minus_one_minus_n = k - 1 |
| 772 | +# fac = k_minus_one_minus_n |
| 773 | +# dfac = 1 |
| 774 | +# xpow = x |
| 775 | +# delta = dfac / xpow |
| 776 | +# |
| 777 | +# for n in range(1, 10): |
| 778 | +# k_minus_one_minus_n -= 1 |
| 779 | +# S += delta |
| 780 | +# xpow *= x |
| 781 | +# dfac = k_minus_one_minus_n * dfac + fac |
| 782 | +# fac *= k_minus_one_minus_n |
| 783 | +# delta = dfac / xpow |
| 784 | +# if np.isinf(delta): |
| 785 | +# warnings.warn( |
| 786 | +# "gammaincc_der did not converge", |
| 787 | +# RuntimeWarning, |
| 788 | +# ) |
| 789 | +# return np.nan |
| 790 | +# |
| 791 | +# return ( |
| 792 | +# scipy.special.gammaincc(k, x) * (log_x - digamma_k) |
| 793 | +# + np.exp(-x + (k - 1) * log_x) * S / gamma_k |
| 794 | +# ) |
| 795 | +# |
| 796 | +# # gradient of series expansion http://dlmf.nist.gov/8.7#E3 |
| 797 | +# else: |
| 798 | +# log_precision = np.log(1e-6) |
| 799 | +# max_iters = int(1e5) |
| 800 | +# S = 0 |
| 801 | +# log_s = 0.0 |
| 802 | +# s_sign = 1 |
| 803 | +# log_delta = log_s - 2 * np.log(k) |
| 804 | +# for n in range(1, max_iters + 1): |
| 805 | +# S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta) |
| 806 | +# s_sign = -s_sign |
| 807 | +# log_s += log_x - np.log(n) |
| 808 | +# log_delta = log_s - 2 * np.log(n + k) |
| 809 | +# |
| 810 | +# if np.isinf(log_delta): |
| 811 | +# warnings.warn( |
| 812 | +# "gammaincc_der did not converge", |
| 813 | +# RuntimeWarning, |
| 814 | +# ) |
| 815 | +# return np.nan |
| 816 | +# |
| 817 | +# if log_delta <= log_precision: |
| 818 | +# return ( |
| 819 | +# scipy.special.gammainc(k, x) * (digamma_k - log_x) |
| 820 | +# + np.exp(k * log_x) * S / gamma_k |
| 821 | +# ) |
| 822 | +# |
| 823 | +# warnings.warn( |
| 824 | +# f"gammaincc_der did not converge after {n} iterations", |
| 825 | +# RuntimeWarning, |
| 826 | +# ) |
| 827 | +# return np.nan |
| 828 | +# |
| 829 | +# def impl(self, k, x): |
| 830 | +# return self.st_impl(k, x) |
| 831 | +# |
| 832 | +# def c_code(self, *args, **kwargs): |
| 833 | +# raise NotImplementedError() |
| 834 | +# |
| 835 | +# |
| 836 | +# gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der") |
| 837 | + |
| 838 | + |
| 839 | +class GammaIncCDerInnerScan1(ScalarScanOp): |
| 840 | + nin = 7 |
| 841 | + nout = 6 |
| 842 | + n_steps = 9 |
| 843 | + |
| 844 | + @property |
| 845 | + def fn(self): |
| 846 | + def inner_fn(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x): |
| 847 | + S += delta |
| 848 | + xpow *= x |
| 849 | + k_minus_one_minus_n -= 1 |
| 850 | + dfac = k_minus_one_minus_n * dfac + fac |
| 851 | + fac *= k_minus_one_minus_n |
| 852 | + delta = dfac / xpow |
| 853 | + return S, delta, xpow, k_minus_one_minus_n, dfac, fac |
759 | 854 |
|
760 |
| - @staticmethod |
761 |
| - def st_impl(k, x): |
762 |
| - gamma_k = scipy.special.gamma(k) |
763 |
| - digamma_k = scipy.special.digamma(k) |
764 |
| - log_x = np.log(x) |
| 855 | + return inner_fn |
765 | 856 |
|
766 |
| - # asymptotic expansion http://dlmf.nist.gov/8.11#E2 |
767 |
| - if (x >= k) and (x >= 8): |
768 |
| - S = 0 |
769 |
| - k_minus_one_minus_n = k - 1 |
770 |
| - fac = k_minus_one_minus_n |
771 |
| - dfac = 1 |
772 |
| - xpow = x |
773 |
| - delta = dfac / xpow |
774 | 857 |
|
775 |
| - for n in range(1, 10): |
776 |
| - k_minus_one_minus_n -= 1 |
777 |
| - S += delta |
778 |
| - xpow *= x |
779 |
| - dfac = k_minus_one_minus_n * dfac + fac |
780 |
| - fac *= k_minus_one_minus_n |
781 |
| - delta = dfac / xpow |
782 |
| - if np.isinf(delta): |
783 |
| - warnings.warn( |
784 |
| - "gammaincc_der did not converge", |
785 |
| - RuntimeWarning, |
786 |
| - ) |
787 |
| - return np.nan |
| 858 | +_gammaincc_der_scan1 = GammaIncCDerInnerScan1( |
| 859 | + lambda *types: tuple( |
| 860 | + same_out(type)[0] for type in types[: GammaIncCDerInnerScan1.nout] |
| 861 | + ) |
| 862 | +) |
788 | 863 |
|
| 864 | + |
| 865 | +class GammaIncCDerInnerScan2(ScalarScanOp): |
| 866 | + nin = 7 |
| 867 | + nout = 5 |
| 868 | + n_steps = int(1e5) # maximum number of iterations |
| 869 | + log_precision = np.log(1e-6) |
| 870 | + |
| 871 | + @property |
| 872 | + def fn(self): |
| 873 | + import pytensor.tensor as pt |
| 874 | + from pytensor.scan import until |
| 875 | + |
| 876 | + def inner_fn(S, log_s, s_sign, log_delta, n, k, log_x): |
| 877 | + delta = pt.exp(log_delta) |
| 878 | + S += pt.switch(s_sign > 0, delta, -delta) |
| 879 | + s_sign = -s_sign |
| 880 | + log_s += log_x - pt.log(n) |
| 881 | + log_delta = log_s - 2 * pt.log(n + k) |
| 882 | + n += 1 |
789 | 883 | return (
|
790 |
| - scipy.special.gammaincc(k, x) * (log_x - digamma_k) |
791 |
| - + np.exp(-x + (k - 1) * log_x) * S / gamma_k |
| 884 | + (S, log_s, s_sign, log_delta, n), |
| 885 | + {}, |
| 886 | + until(pt.all(log_delta < self.log_precision)), |
792 | 887 | )
|
793 | 888 |
|
794 |
| - # gradient of series expansion http://dlmf.nist.gov/8.7#E3 |
795 |
| - else: |
796 |
| - log_precision = np.log(1e-6) |
797 |
| - max_iters = int(1e5) |
798 |
| - S = 0 |
799 |
| - log_s = 0.0 |
800 |
| - s_sign = 1 |
801 |
| - log_delta = log_s - 2 * np.log(k) |
802 |
| - for n in range(1, max_iters + 1): |
803 |
| - S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta) |
804 |
| - s_sign = -s_sign |
805 |
| - log_s += log_x - np.log(n) |
806 |
| - log_delta = log_s - 2 * np.log(n + k) |
807 |
| - |
808 |
| - if np.isinf(log_delta): |
809 |
| - warnings.warn( |
810 |
| - "gammaincc_der did not converge", |
811 |
| - RuntimeWarning, |
812 |
| - ) |
813 |
| - return np.nan |
814 |
| - |
815 |
| - if log_delta <= log_precision: |
816 |
| - return ( |
817 |
| - scipy.special.gammainc(k, x) * (digamma_k - log_x) |
818 |
| - + np.exp(k * log_x) * S / gamma_k |
819 |
| - ) |
| 889 | + return inner_fn |
820 | 890 |
|
821 |
| - warnings.warn( |
822 |
| - f"gammaincc_der did not converge after {n} iterations", |
823 |
| - RuntimeWarning, |
824 |
| - ) |
825 |
| - return np.nan |
826 | 891 |
|
827 |
| - def impl(self, k, x): |
828 |
| - return self.st_impl(k, x) |
829 |
| - |
830 |
| - def c_code(self, *args, **kwargs): |
831 |
| - raise NotImplementedError() |
| 892 | +_gammaincc_der_scan2 = GammaIncCDerInnerScan2( |
| 893 | + lambda *types: tuple( |
| 894 | + same_out(type)[0] for type in types[: GammaIncCDerInnerScan2.nout] |
| 895 | + ) |
| 896 | +) |
832 | 897 |
|
833 | 898 |
|
834 |
| -gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der") |
| 899 | +def gammaincc_der(k, x): |
| 900 | + gamma_k = gamma(k) |
| 901 | + digamma_k = psi(k) |
| 902 | + log_x = log(x) |
| 903 | + |
| 904 | + # asymptotic expansion http://dlmf.nist.gov/8.11#E2 |
| 905 | + S = np.array(0.0, dtype="float64") |
| 906 | + dfac = np.array(1.0, dtype="float64") |
| 907 | + xpow = x |
| 908 | + k_minus_one_minus_n = k - 1 |
| 909 | + fac = k_minus_one_minus_n |
| 910 | + delta = true_div(dfac, xpow) |
| 911 | + S, *_ = _gammaincc_der_scan1(S, delta, xpow, k_minus_one_minus_n, fac, dfac, x) |
| 912 | + res1 = ( |
| 913 | + gammaincc(k, x) * (log_x - digamma_k) + exp(-x + (k - 1) * log_x) * S / gamma_k |
| 914 | + ) |
| 915 | + |
| 916 | + # gradient of series expansion http://dlmf.nist.gov/8.7#E3 |
| 917 | + S = np.array(0.0, dtype="float64") |
| 918 | + log_s = np.array(0.0, dtype="float64") |
| 919 | + s_sign = np.array(1, dtype="int8") |
| 920 | + n = np.array(1, dtype="int64") |
| 921 | + log_delta = log_s - 2 * log(k) |
| 922 | + S, *_ = _gammaincc_der_scan2(S, log_s, s_sign, log_delta, n, k, log_x) |
| 923 | + res2 = gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * S / gamma_k |
| 924 | + |
| 925 | + return switch( |
| 926 | + (x >= k) & (x >= 8), |
| 927 | + res1, |
| 928 | + res2, |
| 929 | + ) |
835 | 930 |
|
836 | 931 |
|
837 | 932 | class GammaU(BinaryScalarOp):
|
|
0 commit comments