|
10 | 10 | from aesara.graph.fg import FunctionGraph
|
11 | 11 | from aesara.graph.op import Op
|
12 | 12 | from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
|
13 |
| -from aesara.tensor.math import add, exp, log, mul, reciprocal, sub, true_div |
| 13 | +from aesara.tensor.math import add, exp, log, mul, pow, reciprocal, sub, true_div |
14 | 14 | from aesara.tensor.rewriting.basic import (
|
15 | 15 | register_specialize,
|
16 | 16 | register_stabilize,
|
@@ -422,8 +422,20 @@ def transform(measurable_input, *other_inputs):
|
422 | 422 | def measurable_reciprocal(fgraph, node):
|
423 | 423 | """Rewrite a `reciprocal` node to a `MeasurableVariable`."""
|
424 | 424 |
|
425 |
| - def transform(measurable_input, *other_inputs): |
426 |
| - return ReciprocalTransform(), (measurable_input,) |
| 425 | + new_node = at.power(node.inputs[0], at.as_tensor(-1)).owner |
| 426 | + return measurable_pow.transform(fgraph, new_node) |
| 427 | + |
| 428 | + |
| 429 | +@register_measurable_ir |
| 430 | +@node_rewriter([pow]) |
| 431 | +def measurable_pow(fgraph, node): |
| 432 | + """Rewrite a `pow` node to a `MeasurableVariable`.""" |
| 433 | + |
| 434 | + def transform(measurable_input, *args): |
| 435 | + return PowerTransform(transform_args_fn=lambda *inputs: inputs[-1]), ( |
| 436 | + measurable_input, |
| 437 | + *args, |
| 438 | + ) |
427 | 439 |
|
428 | 440 | return construct_elemwise_transform(fgraph, node, transform)
|
429 | 441 |
|
@@ -579,17 +591,31 @@ def log_jac_det(self, value, *inputs):
|
579 | 591 | return -at.log(value)
|
580 | 592 |
|
581 | 593 |
|
582 |
| -class ReciprocalTransform(RVTransform): |
583 |
| - name = "reciprocal" |
| 594 | +class PowerTransform(RVTransform): |
| 595 | + name = "power" |
| 596 | + |
| 597 | + def __init__(self, transform_args_fn): |
| 598 | + self.transform_args_fn = transform_args_fn |
584 | 599 |
|
585 | 600 | def forward(self, value, *inputs):
|
586 |
| - return at.reciprocal(value) |
| 601 | + power = self.transform_args_fn(*inputs) |
| 602 | + return at.power(value, power) |
587 | 603 |
|
588 | 604 | def backward(self, value, *inputs):
|
589 |
| - return at.reciprocal(value) |
| 605 | + power = self.transform_args_fn(*inputs) |
| 606 | + |
| 607 | + inv_power = at.reciprocal(power) |
| 608 | + return at.switch( |
| 609 | + at.eq(at.mod(power, 2), 0), |
| 610 | + at.power(value, inv_power), |
| 611 | + at.sgn(value) * at.power(at.abs(value), inv_power), |
| 612 | + ) |
590 | 613 |
|
591 | 614 | def log_jac_det(self, value, *inputs):
|
592 |
| - return -2 * at.log(value) |
| 615 | + from aeppl.logprob import xlogy0 |
| 616 | + |
| 617 | + power = self.transform_args_fn(*inputs) |
| 618 | + return at.log(at.abs(power)) + xlogy0((power - 1), at.abs(value)) |
593 | 619 |
|
594 | 620 |
|
595 | 621 | class IntervalTransform(RVTransform):
|
|
0 commit comments