Skip to content

Commit aad7681

Browse files
committed
WIP implement scalar Scan Op
1 parent 33d4d36 commit aad7681

File tree

4 files changed

+259
-75
lines changed

4 files changed

+259
-75
lines changed

pytensor/scalar/math.py

Lines changed: 158 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
upgrade_to_float64,
3535
upgrade_to_float_no_complex,
3636
)
37+
from pytensor.scalar.scan import ScalarScanOp
3738

3839

3940
class Erf(UnaryScalarOp):
@@ -751,87 +752,172 @@ def c_code(self, *args, **kwargs):
751752
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
752753

753754

754-
class GammaIncCDer(BinaryScalarOp):
755-
"""
756-
Gradient of the the regularized upper gamma function (Q) wrt to the first
757-
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
758-
"""
759-
760-
@staticmethod
761-
def st_impl(k, x):
762-
gamma_k = scipy.special.gamma(k)
763-
digamma_k = scipy.special.digamma(k)
764-
log_x = np.log(x)
765-
766-
# asymptotic expansion http://dlmf.nist.gov/8.11#E2
767-
if (x >= k) and (x >= 8):
768-
S = 0
769-
k_minus_one_minus_n = k - 1
770-
fac = k_minus_one_minus_n
771-
dfac = 1
772-
xpow = x
755+
# class GammaIncCDer(BinaryScalarOp):
756+
# """
757+
# Gradient of the regularized upper gamma function (Q) with respect to the first
758+
# argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
759+
# """
760+
#
761+
# @staticmethod
762+
# def st_impl(k, x):
763+
# gamma_k = scipy.special.gamma(k)
764+
# digamma_k = scipy.special.digamma(k)
765+
# log_x = np.log(x)
766+
#
767+
# # asymptotic expansion http://dlmf.nist.gov/8.11#E2
768+
# if (x >= k) and (x >= 8):
769+
# S = 0
770+
# k_minus_one_minus_n = k - 1
771+
# fac = k_minus_one_minus_n
772+
# dfac = 1
773+
# xpow = x
774+
# delta = dfac / xpow
775+
#
776+
# for n in range(1, 10):
777+
# k_minus_one_minus_n -= 1
778+
# S += delta
779+
# xpow *= x
780+
# dfac = k_minus_one_minus_n * dfac + fac
781+
# fac *= k_minus_one_minus_n
782+
# delta = dfac / xpow
783+
# if np.isinf(delta):
784+
# warnings.warn(
785+
# "gammaincc_der did not converge",
786+
# RuntimeWarning,
787+
# )
788+
# return np.nan
789+
#
790+
# return (
791+
# scipy.special.gammaincc(k, x) * (log_x - digamma_k)
792+
# + np.exp(-x + (k - 1) * log_x) * S / gamma_k
793+
# )
794+
#
795+
# # gradient of series expansion http://dlmf.nist.gov/8.7#E3
796+
# else:
797+
# log_precision = np.log(1e-6)
798+
# max_iters = int(1e5)
799+
# S = 0
800+
# log_s = 0.0
801+
# s_sign = 1
802+
# log_delta = log_s - 2 * np.log(k)
803+
# for n in range(1, max_iters + 1):
804+
# S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
805+
# s_sign = -s_sign
806+
# log_s += log_x - np.log(n)
807+
# log_delta = log_s - 2 * np.log(n + k)
808+
#
809+
# if np.isinf(log_delta):
810+
# warnings.warn(
811+
# "gammaincc_der did not converge",
812+
# RuntimeWarning,
813+
# )
814+
# return np.nan
815+
#
816+
# if log_delta <= log_precision:
817+
# return (
818+
# scipy.special.gammainc(k, x) * (digamma_k - log_x)
819+
# + np.exp(k * log_x) * S / gamma_k
820+
# )
821+
#
822+
# warnings.warn(
823+
# f"gammaincc_der did not converge after {n} iterations",
824+
# RuntimeWarning,
825+
# )
826+
# return np.nan
827+
#
828+
# def impl(self, k, x):
829+
# return self.st_impl(k, x)
830+
#
831+
# def c_code(self, *args, **kwargs):
832+
# raise NotImplementedError()
833+
#
834+
#
835+
# gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
836+
837+
838+
class GammaIncCDerInnerScan1(ScalarScanOp):
    """Inner scan for the asymptotic expansion of ``gammaincc_der``.

    Implements nine steps of the recurrence from the asymptotic expansion
    http://dlmf.nist.gov/8.11#E2 (adapted from STAN's
    ``grad_reg_inc_gamma.hpp``). Six carried states plus one constant input
    (``x``); all six states are returned each step.
    """

    nin = 7
    nout = 6
    n_steps = 9

    @property
    def fn(self):
        # Carried state: (S, delta, xpow, k_minus_one_minus_n, dfac, fac);
        # `x` is the single non-sequence input.
        def inner_fn(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x):
            new_S = S + delta
            new_xpow = xpow * x
            new_k = k_minus_one_minus_n - 1
            new_dfac = new_k * dfac + fac
            new_fac = fac * new_k
            new_delta = new_dfac / new_xpow
            return new_S, new_delta, new_xpow, new_k, new_dfac, new_fac

        return inner_fn


_gammaincc_der_scan1 = GammaIncCDerInnerScan1()
820858

821-
warnings.warn(
822-
f"gammaincc_der did not converge after {n} iterations",
823-
RuntimeWarning,
824-
)
825-
return np.nan
826859

827-
def impl(self, k, x):
828-
return self.st_impl(k, x)
860+
class GammaIncCDerInnerScan2(ScalarScanOp):
    """Inner scan for the series-expansion branch of ``gammaincc_der``.

    Implements the gradient of the series expansion
    http://dlmf.nist.gov/8.7#E3 (adapted from STAN's
    ``grad_reg_inc_gamma.hpp``). Five carried states plus two constant
    inputs (``k``, ``log_x``); iteration stops early once ``log_delta``
    drops below ``log_precision``.
    """

    nin = 7
    nout = 5
    n_steps = int(1e5)  # maximum number of iterations
    log_precision = np.log(1e-6)

    @property
    def fn(self):
        import pytensor.tensor as pt
        from pytensor.scan import until

        # Carried state: (S, log_s, s_sign, log_delta, n);
        # `k` and `log_x` are non-sequence inputs.
        def inner_fn(S, log_s, s_sign, log_delta, n, k, log_x):
            # Accumulate the current (signed) series term before updating it.
            delta = pt.exp(log_delta)
            next_S = S + pt.switch(s_sign > 0, delta, -delta)
            next_sign = -s_sign
            next_log_s = log_s + (log_x - pt.log(n))
            next_log_delta = next_log_s - 2 * pt.log(n + k)
            next_n = n + 1
            outputs = (next_S, next_log_s, next_sign, next_log_delta, next_n)
            stop = until(pt.all(next_log_delta < self.log_precision))
            return outputs, {}, stop

        return inner_fn


_gammaincc_der_scan2 = GammaIncCDerInnerScan2()
888+
889+
890+
def gammaincc_der(k, x):
    """Gradient of the regularized upper incomplete gamma function (Q) with
    respect to the first argument (``k``, a.k.a. ``alpha``).

    Adapted from STAN's ``grad_reg_inc_gamma.hpp``. Both branches are always
    built symbolically and selected with ``switch`` at the end:
    the asymptotic expansion (http://dlmf.nist.gov/8.11#E2) for
    ``x >= k`` and ``x >= 8``, and the gradient of the series expansion
    (http://dlmf.nist.gov/8.7#E3) otherwise.
    """
    gamma_k = gamma(k)
    digamma_k = psi(k)
    log_x = log(x)

    # asymptotic expansion http://dlmf.nist.gov/8.11#E2
    S = np.array(0.0, dtype="float64")
    dfac = np.array(1.0, dtype="float64")
    xpow = x
    k_minus_one_minus_n = k - 1
    fac = k_minus_one_minus_n
    delta = true_div(dfac, xpow)
    # BUGFIX: argument order must match the scan's inner function signature
    # (S, delta, xpow, k_minus_one_minus_n, dfac, fac, x); `fac` and `dfac`
    # were previously passed swapped.
    S, *_ = _gammaincc_der_scan1(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x)
    res1 = (
        gammaincc(k, x) * (log_x - digamma_k) + exp(-x + (k - 1) * log_x) * S / gamma_k
    )

    # gradient of series expansion http://dlmf.nist.gov/8.7#E3
    S = np.array(0.0, dtype="float64")
    log_s = np.array(0.0, dtype="float64")
    s_sign = np.array(1, dtype="int8")
    n = np.array(1, dtype="int64")
    log_delta = log_s - 2 * log(k)
    S, *_ = _gammaincc_der_scan2(S, log_s, s_sign, log_delta, n, k, log_x)
    res2 = gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * S / gamma_k

    return switch(
        (x >= k) & (x >= 8),
        res1,
        res2,
    )
835921

836922

837923
class GammaU(BinaryScalarOp):

pytensor/scalar/scan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from pytensor.scalar.basic import ScalarOp, same_out
2+
3+
4+
class ScalarScanOp(ScalarOp):
    """Dummy scalar Op that encapsulates a scalar scan operation.

    This Op is never supposed to be evaluated directly. It can safely be
    converted to an ``Elemwise``, which is then rewritten into a ``Scan``
    node during compilation.

    Subclasses are expected to define ``nin``, ``nout``, ``n_steps`` and a
    ``fn`` property that returns the scan inner function.

    TODO: ABC for fn property
    """

    def __init__(self, output_types_preference=None, **kwargs):
        if output_types_preference is None:
            # Default: the first `nout` input types carry through unchanged,
            # since the scan returns its carried states.
            def output_types_preference(*types):
                # `in_type` (not `type`) to avoid shadowing the builtin
                return tuple(same_out(in_type)[0] for in_type in types[: self.nout])

        super().__init__(output_types_preference=output_types_preference, **kwargs)

    def impl(self, *args, **kwargs):
        # Evaluation must go through the rewritten Scan graph instead.
        raise RuntimeError("Scalar Scan Ops should never be evaluated!")

pytensor/tensor/rewriting/elemwise.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytensor
88
import pytensor.scalar.basic as aes
99
from pytensor import compile
10-
from pytensor.compile.mode import get_target_language
10+
from pytensor.compile.mode import get_target_language, optdb
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Apply, Constant, io_toposort
1313
from pytensor.graph.features import ReplaceValidate
@@ -20,11 +20,14 @@
2020
)
2121
from pytensor.graph.rewriting.db import SequenceDB
2222
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
23+
from pytensor.scalar import ScalarScanOp
2324
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
2425
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
2526
from pytensor.tensor.exceptions import NotScalarConstantError
27+
from pytensor.tensor.extra_ops import broadcast_arrays
2628
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
2729
from pytensor.tensor.shape import shape_padleft
30+
from pytensor.tensor.subtensor import IncSubtensor
2831
from pytensor.tensor.var import TensorConstant
2932

3033

@@ -1025,3 +1028,57 @@ def local_careduce_fusion(fgraph, node):
10251028
"fusion",
10261029
position=49,
10271030
)
1031+
1032+
1033+
@node_rewriter([Elemwise])
def inline_elemwise_scan(fgraph, node):
    """Replace an ``Elemwise`` wrapping a ``ScalarScanOp`` by a ``Scan`` node.

    The scalar op's ``fn`` becomes the scan inner function; the first
    ``nout`` (broadcasted) inputs are the initial carried states and the
    rest are non-sequences. The scan's output buffers are then truncated so
    only the last step is stored.
    """
    from pytensor.scan.basic import scan
    from pytensor.scan.utils import expand_empty

    scalar_op = node.op.scalar_op

    if not isinstance(scalar_op, ScalarScanOp):
        return None

    # TODO: Add non-batched implementation? That should be better for scans with big difference in required n_steps
    bcasted_inputs = broadcast_arrays(*node.inputs)
    ret, updates = scan(
        scalar_op.fn,
        outputs_info=bcasted_inputs[: scalar_op.nout],
        non_sequences=bcasted_inputs[scalar_op.nout :],
        n_steps=scalar_op.n_steps,
        sequences=None,
        strict=True,
    )
    if updates:
        raise ValueError("Scalar scan should never return updates")
    if scalar_op.nout == 1:
        # Normalize to a tuple so the single- and multi-output cases share code.
        ret = (ret,)

    # Scan output size is given by the size of the input leading dimension, by default its n_steps + 1.
    # If we only want to store the last elements we can shorten the leading dimension to 1
    # NOTE(review): assumes each scan output is `subtensor(scan_out)` and each
    # carried input is built by an IncSubtensor writing the initial state into
    # an empty buffer — confirm this matches how `scan` builds its graph.
    scan_node = ret[0].owner.inputs[0].owner
    scan_inputs = scan_node.inputs
    n_steps = scan_inputs[0]
    n_non_seqs = scan_node.op.info.n_non_seqs
    carried_inputs = scan_inputs[1 : len(scan_inputs) - n_non_seqs :]
    constant_inputs = scan_inputs[len(scan_inputs) - n_non_seqs :]
    new_carried_inputs = []
    for carried_input in carried_inputs:
        assert isinstance(carried_input.owner.op, IncSubtensor)
        # The value being written into the buffer is the initial state.
        fill_value = carried_input.owner.inputs[1]
        # TODO: Check for the global flag where this is controlled
        # Rebuild the carried input with a leading dimension of 1 so only the
        # last step is kept.
        new_carried_inputs.append(expand_empty(fill_value, 1))
    ret = scan_node.op.make_node(n_steps, *new_carried_inputs, *constant_inputs).outputs

    # Each output buffer has length 1 now; index the (only) stored step.
    return [r[1] for r in ret]


# We want to run this after the scan save mem rewrite, as we already applied it here
optdb.register(
    "inline_elemwise_scan",
    in2out(inline_elemwise_scan),
    "fast_compile",
    "fast_run",
    position=1.62,
)

0 commit comments

Comments
 (0)