Skip to content

Commit 7396b31

Browse files
committed
WIP implement scalar Scan Op
1 parent 33d4d36 commit 7396b31

File tree

4 files changed

+240
-73
lines changed

4 files changed

+240
-73
lines changed

pytensor/scalar/math.py

Lines changed: 165 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
isinf,
2828
log,
2929
log1p,
30+
same_out,
3031
switch,
3132
true_div,
3233
upcast,
3334
upgrade_to_float,
3435
upgrade_to_float64,
3536
upgrade_to_float_no_complex,
3637
)
38+
from pytensor.scalar.scan import ScalarScanOp
3739

3840

3941
class Erf(UnaryScalarOp):
@@ -751,87 +753,180 @@ def c_code(self, *args, **kwargs):
751753
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
752754

753755

754-
class GammaIncCDer(BinaryScalarOp):
755-
"""
756-
Gradient of the regularized upper gamma function (Q) wrt the first
757-
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
758-
"""
756+
# class GammaIncCDer(BinaryScalarOp):
757+
# """
758+
# Gradient of the regularized upper gamma function (Q) wrt the first
759+
# argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
760+
# """
761+
#
762+
# @staticmethod
763+
# def st_impl(k, x):
764+
# gamma_k = scipy.special.gamma(k)
765+
# digamma_k = scipy.special.digamma(k)
766+
# log_x = np.log(x)
767+
#
768+
# # asymptotic expansion http://dlmf.nist.gov/8.11#E2
769+
# if (x >= k) and (x >= 8):
770+
# S = 0
771+
# k_minus_one_minus_n = k - 1
772+
# fac = k_minus_one_minus_n
773+
# dfac = 1
774+
# xpow = x
775+
# delta = dfac / xpow
776+
#
777+
# for n in range(1, 10):
778+
# k_minus_one_minus_n -= 1
779+
# S += delta
780+
# xpow *= x
781+
# dfac = k_minus_one_minus_n * dfac + fac
782+
# fac *= k_minus_one_minus_n
783+
# delta = dfac / xpow
784+
# if np.isinf(delta):
785+
# warnings.warn(
786+
# "gammaincc_der did not converge",
787+
# RuntimeWarning,
788+
# )
789+
# return np.nan
790+
#
791+
# return (
792+
# scipy.special.gammaincc(k, x) * (log_x - digamma_k)
793+
# + np.exp(-x + (k - 1) * log_x) * S / gamma_k
794+
# )
795+
#
796+
# # gradient of series expansion http://dlmf.nist.gov/8.7#E3
797+
# else:
798+
# log_precision = np.log(1e-6)
799+
# max_iters = int(1e5)
800+
# S = 0
801+
# log_s = 0.0
802+
# s_sign = 1
803+
# log_delta = log_s - 2 * np.log(k)
804+
# for n in range(1, max_iters + 1):
805+
# S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
806+
# s_sign = -s_sign
807+
# log_s += log_x - np.log(n)
808+
# log_delta = log_s - 2 * np.log(n + k)
809+
#
810+
# if np.isinf(log_delta):
811+
# warnings.warn(
812+
# "gammaincc_der did not converge",
813+
# RuntimeWarning,
814+
# )
815+
# return np.nan
816+
#
817+
# if log_delta <= log_precision:
818+
# return (
819+
# scipy.special.gammainc(k, x) * (digamma_k - log_x)
820+
# + np.exp(k * log_x) * S / gamma_k
821+
# )
822+
#
823+
# warnings.warn(
824+
# f"gammaincc_der did not converge after {n} iterations",
825+
# RuntimeWarning,
826+
# )
827+
# return np.nan
828+
#
829+
# def impl(self, k, x):
830+
# return self.st_impl(k, x)
831+
#
832+
# def c_code(self, *args, **kwargs):
833+
# raise NotImplementedError()
834+
#
835+
#
836+
# gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
837+
838+
839+
class GammaIncCDerInnerScan1(ScalarScanOp):
    """Scalar scan for the asymptotic branch of ``gammaincc_der``.

    Encapsulates the inner loop of the asymptotic expansion
    (http://dlmf.nist.gov/8.11#E2) used for the gradient of the regularized
    upper incomplete gamma function Q(k, x) with respect to k, taken when
    ``x >= k`` and ``x >= 8``.

    NOTE(review): the original Python loop emitted a "did not converge"
    RuntimeWarning when ``delta`` became infinite; that check is not
    reproduced by this scan — confirm whether it should be.
    """

    # 6 carried states (S, delta, xpow, k_minus_one_minus_n, dfac, fac)
    # plus one constant (non-sequence) input: x.
    nin = 7
    nout = 6
    # The original implementation iterated ``for n in range(1, 10)``,
    # i.e. exactly 9 steps.
    n_steps = 9

    @property
    def fn(self):
        # One step of the series: accumulate the current term ``delta``,
        # then update the recurrences that produce the next term.
        def inner_fn(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x):
            S += delta
            xpow *= x
            k_minus_one_minus_n -= 1
            # ``dfac`` must be updated before ``fac`` (it reads the old fac).
            dfac = k_minus_one_minus_n * dfac + fac
            fac *= k_minus_one_minus_n
            delta = dfac / xpow
            return S, delta, xpow, k_minus_one_minus_n, dfac, fac

        return inner_fn
765856

766-
# asymptotic expansion http://dlmf.nist.gov/8.11#E2
767-
if (x >= k) and (x >= 8):
768-
S = 0
769-
k_minus_one_minus_n = k - 1
770-
fac = k_minus_one_minus_n
771-
dfac = 1
772-
xpow = x
773-
delta = dfac / xpow
774857

775-
for n in range(1, 10):
776-
k_minus_one_minus_n -= 1
777-
S += delta
778-
xpow *= x
779-
dfac = k_minus_one_minus_n * dfac + fac
780-
fac *= k_minus_one_minus_n
781-
delta = dfac / xpow
782-
if np.isinf(delta):
783-
warnings.warn(
784-
"gammaincc_der did not converge",
785-
RuntimeWarning,
786-
)
787-
return np.nan
858+
# Output types mirror the types of the corresponding carried inputs (the
# first ``nout`` inputs); the trailing non-sequence input (x) contributes
# no output.
_gammaincc_der_scan1 = GammaIncCDerInnerScan1(
    lambda *types: tuple(
        same_out(type)[0] for type in types[: GammaIncCDerInnerScan1.nout]
    )
)
788863

864+
865+
class GammaIncCDerInnerScan2(ScalarScanOp):
    """Scalar scan for the series-expansion branch of ``gammaincc_der``.

    Encapsulates the inner loop of the gradient of the series expansion
    (http://dlmf.nist.gov/8.7#E3), used when the asymptotic branch does
    not apply. Terminates early once ``log_delta`` drops below
    ``log_precision``.

    NOTE(review): the original loop returned once ``log_delta <=
    log_precision`` (non-strict) and warned on non-convergence / infinite
    ``log_delta``; here the ``until`` condition is strict ``<`` and no
    warning is emitted — confirm these differences are intended.
    """

    # 5 carried states (S, log_s, s_sign, log_delta, n) plus two constant
    # (non-sequence) inputs: k and log_x.
    nin = 7
    nout = 5
    n_steps = int(1e5)  # maximum number of iterations
    log_precision = np.log(1e-6)

    @property
    def fn(self):
        import pytensor.tensor as pt
        from pytensor.scan import until

        def inner_fn(S, log_s, s_sign, log_delta, n, k, log_x):
            # Alternating series: add or subtract the current term
            # depending on the sign carried in ``s_sign``.
            delta = pt.exp(log_delta)
            S += pt.switch(s_sign > 0, delta, -delta)
            s_sign = -s_sign
            log_s += log_x - pt.log(n)
            log_delta = log_s - 2 * pt.log(n + k)
            n += 1
            return (
                (S, log_s, s_sign, log_delta, n),
                {},
                until(pt.all(log_delta < self.log_precision)),
            )

        return inner_fn
820890

821-
warnings.warn(
822-
f"gammaincc_der did not converge after {n} iterations",
823-
RuntimeWarning,
824-
)
825-
return np.nan
826891

827-
def impl(self, k, x):
828-
return self.st_impl(k, x)
829-
830-
def c_code(self, *args, **kwargs):
831-
raise NotImplementedError()
892+
# Output types mirror the types of the corresponding carried inputs (the
# first ``nout`` inputs); the non-sequence inputs (k, log_x) contribute
# no outputs.
_gammaincc_der_scan2 = GammaIncCDerInnerScan2(
    lambda *types: tuple(
        same_out(type)[0] for type in types[: GammaIncCDerInnerScan2.nout]
    )
)
832897

833898

834-
gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
899+
def gammaincc_der(k, x):
    """Gradient of the regularized upper incomplete gamma function Q(k, x)
    with respect to its first argument ``k`` (a.k.a. alpha).

    Adapted from STAN ``grad_reg_inc_gamma.hpp``. Unlike the original
    eager implementation, both branches are built symbolically and the
    applicable one is selected with ``switch`` at the end.
    """
    gamma_k = gamma(k)
    digamma_k = psi(k)
    log_x = log(x)

    # Asymptotic expansion http://dlmf.nist.gov/8.11#E2
    # (selected below when x >= k and x >= 8).
    S = np.array(0.0, dtype="float64")
    dfac = np.array(1.0, dtype="float64")
    xpow = x
    k_minus_one_minus_n = k - 1
    fac = k_minus_one_minus_n
    delta = true_div(dfac, xpow)
    # FIX: the scan's step function expects (..., dfac, fac, x); ``fac``
    # and ``dfac`` were previously passed in swapped order, corrupting the
    # recurrence (initially dfac == 1 while fac == k - 1).
    S, *_ = _gammaincc_der_scan1(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x)
    res1 = (
        gammaincc(k, x) * (log_x - digamma_k) + exp(-x + (k - 1) * log_x) * S / gamma_k
    )

    # Gradient of series expansion http://dlmf.nist.gov/8.7#E3
    S = np.array(0.0, dtype="float64")
    log_s = np.array(0.0, dtype="float64")
    s_sign = np.array(1, dtype="int8")
    n = np.array(1, dtype="int64")
    log_delta = log_s - 2 * log(k)
    S, *_ = _gammaincc_der_scan2(S, log_s, s_sign, log_delta, n, k, log_x)
    res2 = gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * S / gamma_k

    return switch(
        (x >= k) & (x >= 8),
        res1,
        res2,
    )
835930

836931

837932
class GammaU(BinaryScalarOp):

pytensor/scalar/scan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from pytensor.scalar.basic import ScalarOp
2+
3+
4+
class ScalarScanOp(ScalarOp):
    """Dummy scalar Op that encapsulates a scalar scan operation.

    This Op is never supposed to be evaluated directly: ``impl`` always
    raises. It can safely be converted to an Elemwise, which is rewritten
    into a Scan node during compilation.

    Subclasses are expected to provide:

    * ``nin`` / ``nout`` — number of inputs and outputs. The first
      ``nout`` inputs are the carried states; the remaining
      ``nin - nout`` inputs are constant (non-sequence) inputs.
    * ``n_steps`` — (maximum) number of scan iterations.
    * ``fn`` — a property returning the scan step function.
    """

    def impl(self, *args, **kwargs):
        # Placeholder only: evaluation must go through the
        # Elemwise -> Scan rewrite.
        raise RuntimeError("Scalar Scan Ops should never be evaluated!")

pytensor/tensor/rewriting/elemwise.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytensor
88
import pytensor.scalar.basic as aes
99
from pytensor import compile
10-
from pytensor.compile.mode import get_target_language
10+
from pytensor.compile.mode import get_target_language, optdb
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Apply, Constant, io_toposort
1313
from pytensor.graph.features import ReplaceValidate
@@ -20,8 +20,10 @@
2020
)
2121
from pytensor.graph.rewriting.db import SequenceDB
2222
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
23+
from pytensor.scalar import ScalarScanOp
2324
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
2425
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
26+
from pytensor.tensor.extra_ops import broadcast_arrays
2527
from pytensor.tensor.exceptions import NotScalarConstantError
2628
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
2729
from pytensor.tensor.shape import shape_padleft
@@ -1025,3 +1027,41 @@ def local_careduce_fusion(fgraph, node):
10251027
"fusion",
10261028
position=49,
10271029
)
1030+
1031+
1032+
@node_rewriter([Elemwise])
def inline_elemwise_scan(fgraph, node):
    """Replace an Elemwise of a ``ScalarScanOp`` by an equivalent Scan.

    Returns the last value of each carried state, or ``None`` when the
    node's scalar op is not a ``ScalarScanOp``.
    """
    from pytensor.scan.basic import scan

    scalar_op = node.op.scalar_op

    if not isinstance(scalar_op, ScalarScanOp):
        return None

    # TODO: Add non-batched implementation?
    # The first ``nout`` inputs are the carried states; the remaining
    # ``nin - nout`` inputs are constant (non-sequence) inputs.
    # FIX: this was ``scalar_op.nout - scalar_op.nin``, a negative number
    # that only sliced correctly by accident (negative indexing) and would
    # break when nin == nout.
    n_carried_inputs = scalar_op.nout
    bcasted_inputs = broadcast_arrays(*node.inputs)
    ret, updates = scan(
        scalar_op.fn,
        outputs_info=bcasted_inputs[:n_carried_inputs],
        non_sequences=bcasted_inputs[n_carried_inputs:],
        n_steps=scalar_op.n_steps,
        sequences=None,
        strict=True,
    )
    if updates:
        raise ValueError("Scalar scan should never return updates")
    # With a single output, ``scan`` returns a bare variable; normalize to
    # a tuple before indexing.
    if scalar_op.nout == 1:
        ret = (ret,)
    # Only the final step of each carried state is of interest.
    return [r[-1] for r in ret]
1057+
1058+
1059+
# We want to run this before the first merge optimizer
# and before the first scan optimizer.
# NOTE(review): position=-0.01 places this rewrite at the very start of
# the optimization database — confirm this ordering vs. other -x entries.
optdb.register(
    "inline_elemwise_scan",
    in2out(inline_elemwise_scan),
    "fast_compile",
    "fast_run",
    position=-0.01,
)

tests/tensor/test_math_scipy.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from pytensor.gradient import verify_grad
45

56
scipy = pytest.importorskip("scipy")
67

@@ -9,11 +10,11 @@
910
import scipy.special
1011
import scipy.stats
1112

12-
from pytensor import function
13+
from pytensor import function, grad
1314
from pytensor import tensor as at
1415
from pytensor.compile.mode import get_default_mode
1516
from pytensor.configdefaults import config
16-
from pytensor.tensor import inplace
17+
from pytensor.tensor import inplace, vector, gammaincc
1718
from tests import unittest_tools as utt
1819
from tests.tensor.utils import (
1920
_good_broadcast_unary_chi2sf,
@@ -422,6 +423,23 @@ def test_gammainc_ddk_tabulated_values():
422423
)
423424

424425

426+
def test_gammaincc_ddk_performance(benchmark):
    # ``benchmark`` is presumably the pytest-benchmark fixture — TODO
    # confirm the plugin is a test dependency.
    rng = np.random.default_rng(1)
    k = vector("k")
    x = vector("x")

    out = gammaincc(k, x)
    grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
    vals = [
        # Values that hit the second branch of the gradient
        np.full((1000,), 3.2),
        np.full((1000,), 0.01),
    ]

    # Check the gradient numerically before timing the compiled function.
    verify_grad(gammaincc, vals, rng=rng)
    benchmark(grad_fn, *vals)
441+
442+
425443
TestGammaUBroadcast = makeBroadcastTester(
426444
op=at.gammau,
427445
expected=expected_gammau,

0 commit comments

Comments
 (0)