Implement Scalar Scan (as dummy Op) #174

Closed · wants to merge 2 commits
230 changes: 158 additions & 72 deletions pytensor/scalar/math.py
@@ -34,6 +34,7 @@
upgrade_to_float64,
upgrade_to_float_no_complex,
)
from pytensor.scalar.scan import ScalarScanOp


class Erf(UnaryScalarOp):
@@ -751,87 +752,172 @@ def c_code(self, *args, **kwargs):
gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")


# class GammaIncCDer(BinaryScalarOp):
# """
# Gradient of the regularized upper incomplete gamma function (Q) wrt the first
# argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
# """
#
# @staticmethod
# def st_impl(k, x):
# gamma_k = scipy.special.gamma(k)
# digamma_k = scipy.special.digamma(k)
# log_x = np.log(x)
#
# # asymptotic expansion http://dlmf.nist.gov/8.11#E2
# if (x >= k) and (x >= 8):
# S = 0
# k_minus_one_minus_n = k - 1
# fac = k_minus_one_minus_n
# dfac = 1
# xpow = x
# delta = dfac / xpow
#
# for n in range(1, 10):
# k_minus_one_minus_n -= 1
# S += delta
# xpow *= x
# dfac = k_minus_one_minus_n * dfac + fac
# fac *= k_minus_one_minus_n
# delta = dfac / xpow
# if np.isinf(delta):
# warnings.warn(
# "gammaincc_der did not converge",
# RuntimeWarning,
# )
# return np.nan
#
# return (
# scipy.special.gammaincc(k, x) * (log_x - digamma_k)
# + np.exp(-x + (k - 1) * log_x) * S / gamma_k
# )
#
# # gradient of series expansion http://dlmf.nist.gov/8.7#E3
# else:
# log_precision = np.log(1e-6)
# max_iters = int(1e5)
# S = 0
# log_s = 0.0
# s_sign = 1
# log_delta = log_s - 2 * np.log(k)
# for n in range(1, max_iters + 1):
# S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
# s_sign = -s_sign
# log_s += log_x - np.log(n)
# log_delta = log_s - 2 * np.log(n + k)
#
# if np.isinf(log_delta):
# warnings.warn(
# "gammaincc_der did not converge",
# RuntimeWarning,
# )
# return np.nan
#
# if log_delta <= log_precision:
# return (
# scipy.special.gammainc(k, x) * (digamma_k - log_x)
# + np.exp(k * log_x) * S / gamma_k
# )
#
# warnings.warn(
# f"gammaincc_der did not converge after {n} iterations",
# RuntimeWarning,
# )
# return np.nan
#
# def impl(self, k, x):
# return self.st_impl(k, x)
#
# def c_code(self, *args, **kwargs):
# raise NotImplementedError()
#
#
# gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")


class GammaIncCDerInnerScan1(ScalarScanOp):
nin = 7
nout = 6
n_steps = 9

@property
def fn(self):
def inner_fn(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x):
S += delta
xpow *= x
k_minus_one_minus_n -= 1
dfac = k_minus_one_minus_n * dfac + fac
fac *= k_minus_one_minus_n
delta = dfac / xpow
return S, delta, xpow, k_minus_one_minus_n, dfac, fac

return inner_fn


_gammaincc_der_scan1 = GammaIncCDerInnerScan1()


class GammaIncCDerInnerScan2(ScalarScanOp):
nin = 7
nout = 5
n_steps = int(1e5) # maximum number of iterations
log_precision = np.log(1e-6)

@property
def fn(self):
import pytensor.tensor as pt
from pytensor.scan import until

def inner_fn(S, log_s, s_sign, log_delta, n, k, log_x):
delta = pt.exp(log_delta)
S += pt.switch(s_sign > 0, delta, -delta)
s_sign = -s_sign
log_s += log_x - pt.log(n)
log_delta = log_s - 2 * pt.log(n + k)
n += 1
return (
(S, log_s, s_sign, log_delta, n),
{},
until(pt.all(log_delta < self.log_precision)),
)

return inner_fn


_gammaincc_der_scan2 = GammaIncCDerInnerScan2()
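
Not part of the diff: the inner function above follows the standard scan convention of returning (outputs, updates, until(condition)), which lets the loop stop before n_steps once the series has converged. A standalone sketch of that convention with a plain tensor-level scan, using a hypothetical exp(x) power series:

import pytensor
import pytensor.tensor as pt
from pytensor.scan import until
from pytensor.scan.basic import scan

x = pt.dscalar("x")


def step(n, acc, x):
    # n-th term of the exp(x) power series; stop once it falls below tolerance
    term = x**n / pt.gamma(n + 1.0)
    return (n + 1.0, acc + term), {}, until(term < 1e-12)


(_, acc_trace), _ = scan(
    step,
    outputs_info=[pt.constant(0.0, dtype="float64"), pt.constant(0.0, dtype="float64")],
    non_sequences=[x],
    n_steps=500,
)
exp_x = acc_trace[-1]
print(pytensor.function([x], exp_x)(1.0))  # ~2.718281828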


def gammaincc_der(k, x):
gamma_k = gamma(k)
digamma_k = psi(k)
log_x = log(x)

# asymptotic expansion http://dlmf.nist.gov/8.11#E2
S = np.array(0.0, dtype="float64")
dfac = np.array(1.0, dtype="float64")
xpow = x
k_minus_one_minus_n = k - 1
fac = k_minus_one_minus_n
delta = true_div(dfac, xpow)
S, *_ = _gammaincc_der_scan1(S, delta, xpow, k_minus_one_minus_n, dfac, fac, x)
res1 = (
gammaincc(k, x) * (log_x - digamma_k) + exp(-x + (k - 1) * log_x) * S / gamma_k
)

# gradient of series expansion http://dlmf.nist.gov/8.7#E3
S = np.array(0.0, dtype="float64")
log_s = np.array(0.0, dtype="float64")
s_sign = np.array(1, dtype="int8")
n = np.array(1, dtype="int64")
log_delta = log_s - 2 * log(k)
S, *_ = _gammaincc_der_scan2(S, log_s, s_sign, log_delta, n, k, log_x)
res2 = gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * S / gamma_k

return switch(
(x >= k) & (x >= 8),
res1,
res2,
)
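
Editorial aside, not part of the diff: the two branches above implement the derivative of Q(k, x) = gammaincc(k, x) with respect to k. A quick numerical sanity check for either branch is a central finite difference on SciPy's gammaincc; the helper name below is hypothetical and only meant as a reference check:

from scipy import special


def gammaincc_der_fd(k, x, eps=1e-6):
    # Central finite difference of Q(k, x) = gammaincc(k, x) with respect to k
    return (special.gammaincc(k + eps, x) - special.gammaincc(k - eps, x)) / (2 * eps)


# One point per branch: asymptotic expansion (x >= k and x >= 8) and series expansion
for k, x in [(2.0, 10.0), (3.5, 1.2)]:
    print(k, x, gammaincc_der_fd(k, x))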


class GammaU(BinaryScalarOp):
23 changes: 23 additions & 0 deletions pytensor/scalar/scan.py
@@ -0,0 +1,23 @@
from pytensor.scalar.basic import ScalarOp, same_out


class ScalarScanOp(ScalarOp):
"""Dummy Scalar Op that encapsulates a scalar scan operation.

This Op is never supposed to be evaluated. It can safely be converted
to an Elemwise which is rewritten into a Scan node during compilation.

TODO: FINISH DOCSTRINGS
TODO: ABC for fn property
"""

def __init__(self, output_types_preference=None, **kwargs):
if output_types_preference is None:

def output_types_preference(*types):
return tuple(same_out(type)[0] for type in types[: self.nout])

super().__init__(output_types_preference=output_types_preference, **kwargs)

def impl(self, *args, **kwargs):
raise RuntimeError("Scalar Scan Ops should never be evaluated!")
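
For illustration only, not part of the diff: under the interface sketched above, a concrete ScalarScanOp subclass mirrors GammaIncCDerInnerScan1 from math.py: class attributes declare the number of inputs, outputs, and steps, and the fn property returns the inner step function. The class below is hypothetical:

from pytensor.scalar.scan import ScalarScanOp


class GeometricSeriesScan(ScalarScanOp):
    # Hypothetical example: accumulate S = 1 + r + r**2 + ... for a fixed
    # number of steps, carrying (S, term) and passing r as a non-sequence.
    nin = 3    # two carried states plus one non-sequence
    nout = 2   # the carried states are also the outputs
    n_steps = 10

    @property
    def fn(self):
        def inner_fn(S, term, r):
            S += term
            term *= r
            return S, term

        return inner_fn


_geometric_series_scan = GeometricSeriesScan()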
3 changes: 3 additions & 0 deletions pytensor/tensor/elemwise.py
@@ -638,6 +638,9 @@ def transform(r):
return DimShuffle((), ["x"] * nd)(res)

new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
if isinstance(new_r, (list, tuple)):
# Scalar Op with multiple outputs
new_r = new_r[r.owner.outputs.index(r)]
return new_r

ret = []
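
Note, not part of the diff: Elemwise returns a single variable when its scalar Op has one output but a list when it has several, which is why the new branch indexes into new_r. A minimal sketch, assuming the existing multi-output scalar Composite behaves as in current pytensor (names are hypothetical):

import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.tensor.elemwise import Elemwise

xs = ps.float64("xs")
two_outputs = ps.Composite([xs], [xs + 1, xs * 2])  # scalar Op with two outputs

xt = pt.dvector("xt")
outs = Elemwise(two_outputs)(xt)
print(type(outs), len(outs))  # a list of two tensor variables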
59 changes: 58 additions & 1 deletion pytensor/tensor/rewriting/elemwise.py
@@ -7,7 +7,7 @@
import pytensor
import pytensor.scalar.basic as aes
from pytensor import compile
from pytensor.compile.mode import get_target_language
from pytensor.compile.mode import get_target_language, optdb
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, io_toposort
from pytensor.graph.features import ReplaceValidate
@@ -20,11 +20,14 @@
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from pytensor.scalar import ScalarScanOp
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.subtensor import IncSubtensor
from pytensor.tensor.var import TensorConstant


@@ -1025,3 +1028,57 @@ def local_careduce_fusion(fgraph, node):
"fusion",
position=49,
)


@node_rewriter([Elemwise])
def inline_elemwise_scan(fgraph, node):
from pytensor.scan.basic import scan
from pytensor.scan.utils import expand_empty

scalar_op = node.op.scalar_op

if not isinstance(scalar_op, ScalarScanOp):
return None

# TODO: Add a non-batched implementation? That should be better for scans with a big difference in required n_steps
bcasted_inputs = broadcast_arrays(*node.inputs)
ret, updates = scan(
scalar_op.fn,
outputs_info=bcasted_inputs[: scalar_op.nout],
non_sequences=bcasted_inputs[scalar_op.nout :],
n_steps=scalar_op.n_steps,
sequences=None,
strict=True,
)
if updates:
raise ValueError("Scalar scan should never return updates")
if scalar_op.nout == 1:
ret = (ret,)

# Scan output size is given by the size of the input leading dimension; by default it is n_steps + 1.
# If we only want to store the last elements, we can shorten the leading dimension to 1.
scan_node = ret[0].owner.inputs[0].owner
scan_inputs = scan_node.inputs
n_steps = scan_inputs[0]
n_non_seqs = scan_node.op.info.n_non_seqs
carried_inputs = scan_inputs[1 : len(scan_inputs) - n_non_seqs]
constant_inputs = scan_inputs[len(scan_inputs) - n_non_seqs :]
new_carried_inputs = []
for carried_input in carried_inputs:
assert isinstance(carried_input.owner.op, IncSubtensor)
fill_value = carried_input.owner.inputs[1]
# TODO: Check for the global flag where this is controlled
new_carried_inputs.append(expand_empty(fill_value, 1))
ret = scan_node.op.make_node(n_steps, *new_carried_inputs, *constant_inputs).outputs
Comment on lines +1058 to +1072, @ricardoV94 (Member, Author), Jan 9, 2023:
One shouldn't have to hack into scan internals to avoid saving the intermediate results... This is needed because of #178


return [r[1] for r in ret]


# We want to run this after the scan save-mem rewrite, since we have already applied its effect here
optdb.register(
"inline_elemwise_scan",
in2out(inline_elemwise_scan),
"fast_compile",
"fast_run",
position=1.62,
)
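
Editorial sketch, not part of the diff: conceptually, inline_elemwise_scan replaces Elemwise(ScalarScanOp)(*inputs) with a tensor-level Scan over the broadcast inputs and keeps only the final value of each carried state. A hand-written equivalent for the hypothetical geometric-series Op from the scan.py example above:

import pytensor
import pytensor.tensor as pt
from pytensor.scan.basic import scan

r = pt.dvector("r")
S0 = pt.zeros_like(r)
term0 = pt.ones_like(r)

# One carried state per ScalarScanOp output; non-sequences are passed through,
# and only the last step of each state is kept.
(S_trace, term_trace), updates = scan(
    lambda S, term, r: (S + term, term * r),
    outputs_info=[S0, term0],
    non_sequences=[r],
    n_steps=10,
    strict=True,
)
geometric_sum = S_trace[-1]

f = pytensor.function([r], geometric_sum)
print(f([0.5, 2.0]))  # [~1.998, 1023.0]: sums of r**0 + ... + r**9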