Skip to content

Commit ed2161d

Browse files
Add support for power transforms
1 parent 473c1e6 commit ed2161d

File tree

2 files changed

+53
-8
lines changed

2 files changed

+53
-8
lines changed

aeppl/transforms.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from aesara.graph.fg import FunctionGraph
1111
from aesara.graph.op import Op
1212
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
13-
from aesara.tensor.math import add, exp, log, mul, reciprocal, sub, true_div
13+
from aesara.tensor.math import add, exp, log, mul, pow, reciprocal, sub, true_div
1414
from aesara.tensor.rewriting.basic import (
1515
register_specialize,
1616
register_stabilize,
@@ -422,8 +422,20 @@ def transform(measurable_input, *other_inputs):
422422
def measurable_reciprocal(fgraph, node):
423423
"""Rewrite a `reciprocal` node to a `MeasurableVariable`."""
424424

425-
def transform(measurable_input, *other_inputs):
426-
return ReciprocalTransform(), (measurable_input,)
425+
new_node = at.power(node.inputs[0], at.as_tensor(-1)).owner
426+
return measurable_pow.transform(fgraph, new_node)
427+
428+
429+
@register_measurable_ir
430+
@node_rewriter([pow])
431+
def measurable_pow(fgraph, node):
432+
"""Rewrite a `pow` node to a `MeasurableVariable`."""
433+
434+
def transform(measurable_input, *args):
435+
return PowerTransform(transform_args_fn=lambda *inputs: inputs[-1]), (
436+
measurable_input,
437+
*args,
438+
)
427439

428440
return construct_elemwise_transform(fgraph, node, transform)
429441

@@ -579,17 +591,31 @@ def log_jac_det(self, value, *inputs):
579591
return -at.log(value)
580592

581593

582-
class ReciprocalTransform(RVTransform):
583-
name = "reciprocal"
594+
class PowerTransform(RVTransform):
595+
name = "power"
596+
597+
def __init__(self, transform_args_fn):
598+
self.transform_args_fn = transform_args_fn
584599

585600
def forward(self, value, *inputs):
586-
return at.reciprocal(value)
601+
power = self.transform_args_fn(*inputs)
602+
return at.power(value, power)
587603

588604
def backward(self, value, *inputs):
589-
return at.reciprocal(value)
605+
power = self.transform_args_fn(*inputs)
606+
607+
inv_power = at.reciprocal(power)
608+
return at.switch(
609+
at.eq(at.mod(power, 2), 0),
610+
at.power(value, inv_power),
611+
at.sgn(value) * at.power(at.abs(value), inv_power),
612+
)
590613

591614
def log_jac_det(self, value, *inputs):
592-
return -2 * at.log(value)
615+
from aeppl.logprob import xlogy0
616+
617+
power = self.transform_args_fn(*inputs)
618+
return at.log(at.abs(power)) + xlogy0((power - 1), at.abs(value))
593619

594620

595621
class IntervalTransform(RVTransform):

tests/test_transforms.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,22 @@ def test_transform_measurable_sub():
763763

764764
with pytest.raises(RuntimeError, match="The logprob terms"):
765765
joint_logprob(Z_rv, X_rv)
766+
767+
768+
@pytest.mark.parametrize(
769+
"pow_fn, exp_val_fn",
770+
[
771+
(lambda x: x**2, lambda z: sp.stats.chi2(df=1).logpdf(z))
772+
# TODO: Add more cases.
773+
],
774+
)
775+
def test_transform_measurable_pow(pow_fn, exp_val_fn):
776+
X_rv = at.random.normal(0, 1, name="X")
777+
Z_rv = pow_fn(X_rv)
778+
Z_rv.name = "Z"
779+
780+
z_logp, (z_vv,) = conditional_logprob(Z_rv)
781+
z_logp_fn = aesara.function([z_vv], z_logp[Z_rv])
782+
783+
z_val = 0.5
784+
assert np.allclose(z_logp_fn(z_val), exp_val_fn(z_val))

0 commit comments

Comments
 (0)