diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 51ec6ebcf1..0384f90470 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -50,7 +50,7 @@ rv_size_is_none, shape_from_dims, ) -from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob +from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob from pymc.logprob.basic import logp from pymc.logprob.rewriting import logprob_rewrites_db from pymc.printing import str_for_dist @@ -228,7 +228,7 @@ def __get__(self, instance, type_): return descr_get(instance, type_) -class SymbolicRandomVariable(OpFromGraph): +class SymbolicRandomVariable(MeasurableOp, OpFromGraph): """Symbolic Random Variable This is a subclasse of `OpFromGraph` which is used to encapsulate the symbolic @@ -624,10 +624,6 @@ def dist( return rv_out -# Let PyMC know that the SymbolicRandomVariable has a logprob. -MeasurableVariable.register(SymbolicRandomVariable) - - @node_rewriter([SymbolicRandomVariable]) def inline_symbolic_random_variable(fgraph, node): """ diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 41c92e422d..38d06505d0 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -35,6 +35,7 @@ # SOFTWARE. import abc +import warnings from collections.abc import Sequence from functools import singledispatch @@ -46,6 +47,17 @@ from pytensor.tensor.random.op import RandomVariable +def __getattr__(name): + if name == "MeasurableVariable": + warnings.warn( + f"{name} has been deprecated in favor of MeasurableOp. Importing will fail in a future release.", + FutureWarning, + ) + return MeasurableOpMixin + + raise AttributeError(f"module {__name__} has no attribute {name}") + + @singledispatch def _logprob( op: Op, @@ -131,14 +143,21 @@ def _icdf_helper(rv, value, **kwargs): return rv_icdf -class MeasurableVariable(abc.ABC): - """A variable that can be assigned a measure/log-probability""" +class MeasurableOp(abc.ABC): + """An operation whose outputs can be assigned a measure/log-probability""" + +MeasurableOp.register(RandomVariable) -MeasurableVariable.register(RandomVariable) +class MeasurableOpMixin(MeasurableOp): + """MeasurableOp Mixin with a distinctive string representation""" -class MeasurableElemwise(Elemwise): + def __str__(self): + return f"Measurable{super().__str__()}" + + +class MeasurableElemwise(MeasurableOpMixin, Elemwise): """Base class for Measurable Elemwise variables""" valid_scalar_types: tuple[MetaType, ...] = () @@ -150,6 +169,3 @@ def __init__(self, scalar_op, *args, **kwargs): f"Acceptable types are {self.valid_scalar_types}" ) super().__init__(scalar_op, *args, **kwargs) - - -MeasurableVariable.register(MeasurableElemwise) diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 3cf91479e8..c945baa751 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -56,7 +56,7 @@ from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import ( - MeasurableVariable, + MeasurableOp, _icdf_helper, _logcdf_helper, _logprob, @@ -522,7 +522,7 @@ def conditional_logp( while q: node = q.popleft() - if not isinstance(node.op, MeasurableVariable): + if not isinstance(node.op, MeasurableOp): continue q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values] diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index f7b483e599..fdee9d689a 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -42,18 +42,15 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.shape import SpecifyShape -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob, _logprob_helper from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import replace_rvs_by_values -class MeasurableSpecifyShape(SpecifyShape): +class MeasurableSpecifyShape(MeasurableOpMixin, SpecifyShape): """A placeholder used to specify a log-likelihood for a specify-shape sub-graph.""" -MeasurableVariable.register(MeasurableSpecifyShape) - - @_logprob.register(MeasurableSpecifyShape) def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): (value,) = values @@ -80,7 +77,7 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None: if not ( base_rv.owner - and isinstance(base_rv.owner.op, MeasurableVariable) + and isinstance(base_rv.owner.op, MeasurableOp) and base_rv not in rv_map_feature.rv_values ): return None # pragma: no cover @@ -99,13 +96,10 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None: ) -class MeasurableCheckAndRaise(CheckAndRaise): +class MeasurableCheckAndRaise(MeasurableOpMixin, CheckAndRaise): """A placeholder used to specify a log-likelihood for an assert sub-graph.""" -MeasurableVariable.register(MeasurableCheckAndRaise) - - @_logprob.register(MeasurableCheckAndRaise) def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): (value,) = values diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index 3c1c9d3e72..777cb05da5 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -41,17 +41,14 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.extra_ops import CumOp -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import MeasurableOpMixin, _logprob, _logprob_helper from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -class MeasurableCumsum(CumOp): +class MeasurableCumsum(MeasurableOpMixin, CumOp): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" -MeasurableVariable.register(MeasurableCumsum) - - @_logprob.register(MeasurableCumsum) def logprob_cumsum(op, values, base_rv, **kwargs): """Compute the log-likelihood graph for a `Cumsum`.""" diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 08e102f805..15e4e2a82e 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -67,7 +67,8 @@ from pymc.logprob.abstract import ( MeasurableElemwise, - MeasurableVariable, + MeasurableOp, + MeasurableOpMixin, _logprob, _logprob_helper, ) @@ -217,7 +218,7 @@ def rv_pull_down(x: TensorVariable) -> TensorVariable: return fgraph.outputs[0] -class MixtureRV(Op): +class MixtureRV(MeasurableOpMixin, Op): """A placeholder used to specify a log-likelihood for a mixture sub-graph.""" __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable") @@ -235,9 +236,6 @@ def perform(self, node, inputs, outputs): raise NotImplementedError("This is a stand-in Op.") # pragma: no cover -MeasurableVariable.register(MixtureRV) - - def get_stack_mixture_vars( node: Apply, ) -> tuple[list[TensorVariable] | None, int | None]: @@ -457,13 +455,10 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa ) -class MeasurableIfElse(IfElse): +class MeasurableIfElse(MeasurableOpMixin, IfElse): """Measurable subclass of IfElse operator.""" -MeasurableVariable.register(MeasurableIfElse) - - @node_rewriter([IfElse]) def useless_ifelse_outputs(fgraph, node): """Remove outputs that are shared across the IfElse branches.""" @@ -512,7 +507,7 @@ def find_measurable_ifelse_mixture(fgraph, node): base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs) if len(base_rvs) != op.n_outs * 2: return None - if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs): + if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_rvs): return None return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index f9fa8cbe0d..a765756be9 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -48,7 +48,7 @@ from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import ( - MeasurableVariable, + MeasurableOpMixin, _logcdf_helper, _logprob, _logprob_helper, @@ -59,20 +59,14 @@ from pymc.pytensorf import constant_fold -class MeasurableMax(Max): +class MeasurableMax(MeasurableOpMixin, Max): """A placeholder used to specify a log-likelihood for a max sub-graph.""" -MeasurableVariable.register(MeasurableMax) - - -class MeasurableMaxDiscrete(Max): +class MeasurableMaxDiscrete(MeasurableOpMixin, Max): """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables""" -MeasurableVariable.register(MeasurableMaxDiscrete) - - @node_rewriter([Max]) def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) @@ -162,21 +156,15 @@ def max_logprob_discrete(op, values, base_rv, **kwargs): return logprob -class MeasurableMaxNeg(Max): +class MeasurableMaxNeg(MeasurableOpMixin, Max): """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph. This shows up in the graph of min, which is (neg(max(neg(x))).""" -MeasurableVariable.register(MeasurableMaxNeg) - - -class MeasurableDiscreteMaxNeg(Max): +class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max): """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables""" -MeasurableVariable.register(MeasurableDiscreteMaxNeg) - - @node_rewriter(tracks=[Max]) def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index a7dca970d5..aa3586c21e 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -81,7 +81,7 @@ ) from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.abstract import MeasurableOp from pymc.logprob.utils import DiracDelta inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1) @@ -139,7 +139,7 @@ def apply(self, fgraph): continue # This is where we filter only those nodes we care about: # Nodes that have variables that we want to measure and are not yet measurable - if isinstance(node.op, MeasurableVariable): + if isinstance(node.op, MeasurableOp): continue if not any(out in rv_map_feature.needs_measuring for out in node.outputs): continue @@ -155,7 +155,7 @@ def apply(self, fgraph): node_rewriter, "__name__", "" ) # If we converted to a MeasurableVariable we're done here! - if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableVariable): + if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableOp): # go to next node break @@ -274,7 +274,7 @@ def request_measurable(self, vars: Sequence[Variable]) -> list[Variable]: # Input vars or valued vars can't be measured for derived expressions if not var.owner or var in self.rv_values: continue - if isinstance(var.owner.op, MeasurableVariable): + if isinstance(var.owner.op, MeasurableOp): measurable.append(var) else: self.needs_measuring.add(var) diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 84b2722b1a..2adba3297c 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -54,7 +54,7 @@ from pytensor.tensor.variable import TensorVariable from pytensor.updates import OrderedUpdates -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import ( PreserveRVMappings, @@ -66,16 +66,13 @@ from pymc.logprob.utils import replace_rvs_by_values -class MeasurableScan(Scan): +class MeasurableScan(MeasurableOpMixin, Scan): """A placeholder used to specify a log-likelihood for a scan sub-graph.""" def __str__(self): return f"Measurable({super().__str__()})" -MeasurableVariable.register(MeasurableScan) - - def convert_outer_out_to_in( input_scan_args: ScanArgs, outer_out_vars: Iterable[TensorVariable], @@ -288,7 +285,7 @@ def get_random_outer_outputs( io_type = oo_info.name[(oo_info.name.index("_", 6) + 1) :] inner_out_type = f"inner_out_{io_type}" io_var = getattr(scan_args, inner_out_type)[oo_info.index] - if io_var.owner and isinstance(io_var.owner.op, MeasurableVariable): + if io_var.owner and isinstance(io_var.owner.op, MeasurableOp): rv_vars.append((n, oo_var, io_var)) return rv_vars diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index c709013cc6..d6d946cdf4 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -52,7 +52,7 @@ local_rv_size_lift, ) -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob, _logprob_helper from pymc.logprob.rewriting import ( PreserveRVMappings, assume_measured_ir_outputs, @@ -124,13 +124,10 @@ def naive_bcast_rv_lift(fgraph: FunctionGraph, node): return [bcasted_node.outputs[1]] -class MeasurableMakeVector(MakeVector): +class MeasurableMakeVector(MeasurableOpMixin, MakeVector): """A placeholder used to specify a log-likelihood for a cumsum sub-graph.""" -MeasurableVariable.register(MeasurableMakeVector) - - @_logprob.register(MeasurableMakeVector) def logprob_make_vector(op, values, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `MeasurableMakeVector`.""" @@ -151,13 +148,10 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs): return pt.stack(logps) -class MeasurableJoin(Join): +class MeasurableJoin(MeasurableOpMixin, Join): """A placeholder used to specify a log-likelihood for a join sub-graph.""" -MeasurableVariable.register(MeasurableJoin) - - @_logprob.register(MeasurableJoin) def logprob_join(op, values, axis, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `Join`.""" @@ -222,7 +216,7 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None: return None base_vars = assume_measured_ir_outputs(valued_rvs, base_vars) - if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_vars): + if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_vars): return None if is_join: @@ -234,7 +228,7 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None: return [measurable_stack] -class MeasurableDimShuffle(DimShuffle): +class MeasurableDimShuffle(MeasurableOpMixin, DimShuffle): """A placeholder used to specify a log-likelihood for a dimshuffle sub-graph.""" # Need to get the absolute path of `c_func_file`, otherwise it tries to @@ -242,9 +236,6 @@ class MeasurableDimShuffle(DimShuffle): c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file))) -MeasurableVariable.register(MeasurableDimShuffle) - - @_logprob.register(MeasurableDimShuffle) def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs): """Compute the log-likelihood graph for a `MeasurableDimShuffle`.""" diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py index 00422e6529..2523a9b6db 100644 --- a/pymc/logprob/transform_value.py +++ b/pymc/logprob/transform_value.py @@ -24,7 +24,7 @@ from pytensor.scan.op import Scan from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db from pymc.logprob.transforms import Transform @@ -50,7 +50,7 @@ def infer_shape(self, fgraph, node, input_shapes): transformed_value = TransformedValue() -class TransformedValueRV(Op): +class TransformedValueRV(MeasurableOp, Op): """A no-op that identifies RVs whose values were transformed. This is introduced by the `TransformValuesRewrite` @@ -76,9 +76,6 @@ def infer_shape(self, fgraph, node, input_shapes): return input_shapes -MeasurableVariable.register(TransformedValueRV) - - @_logprob.register(TransformedValueRV) def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs): """Compute the log-probability graph for a `TransformedRV`. diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 505b51cb7e..b2a9717f02 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -108,7 +108,7 @@ from pymc.logprob.abstract import ( MeasurableElemwise, - MeasurableVariable, + MeasurableOp, _icdf, _icdf_helper, _logcdf, @@ -427,7 +427,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node] """Find measurable transformations from Elemwise operators.""" # Node was already converted - if isinstance(node.op, MeasurableVariable): + if isinstance(node.op, MeasurableOp): return None # pragma: no cover rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index b75e633f73..e5e878f0b5 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -55,7 +55,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.pytensorf import replace_vars_in_graphs from pymc.util import makeiter @@ -147,7 +147,7 @@ def expand(r): return { node for node in walk(makeiter(vars), expand, False) - if node.owner and isinstance(node.owner.op, RandomVariable | MeasurableVariable) + if node.owner and isinstance(node.owner.op, RandomVariable | MeasurableOp) } @@ -179,7 +179,7 @@ def check_potential_measurability( def expand_fn(var): # expand_fn does not go beyond valued_rvs or any MeasurableVariable - if var.owner and not isinstance(var.owner.op, MeasurableVariable) and var not in valued_rvs: + if var.owner and not isinstance(var.owner.op, MeasurableOp) and var not in valued_rvs: return reversed(var.owner.inputs) else: return [] @@ -189,7 +189,7 @@ def expand_fn(var): for ancestor_var in walk(inputs, expand=expand_fn, bfs=False) if ( ancestor_var.owner - and isinstance(ancestor_var.owner.op, MeasurableVariable) + and isinstance(ancestor_var.owner.op, MeasurableOp) and ancestor_var not in valued_rvs ) ): @@ -259,7 +259,7 @@ def local_check_parameter_to_ninf_switch(fgraph, node): ) -class DiracDelta(Op): +class DiracDelta(MeasurableOp, Op): """An `Op` that represents a Dirac-delta distribution.""" __props__ = ("rtol", "atol") @@ -287,9 +287,6 @@ def infer_shape(self, fgraph, node, input_shapes): return input_shapes -MeasurableVariable.register(DiracDelta) - - dirac_delta = DiracDelta() diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py index 435cac9fbb..864825910d 100644 --- a/pymc/variational/minibatch_rv.py +++ b/pymc/variational/minibatch_rv.py @@ -20,10 +20,10 @@ from pytensor.graph import Apply, Op from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper -class MinibatchRandomVariable(Op): +class MinibatchRandomVariable(MeasurableOp, Op): """RV whose logprob should be rescaled to match total_size""" __props__ = () @@ -95,9 +95,6 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor return pt.cast(coef, dtype=config.floatX) -MeasurableVariable.register(MinibatchRandomVariable) - - @_logprob.register(MinibatchRandomVariable) def minibatch_rv_logprob(op, values, *inputs, **kwargs): [value] = values diff --git a/tests/logprob/test_abstract.py b/tests/logprob/test_abstract.py index 7a0bc61e78..3976066e60 100644 --- a/tests/logprob/test_abstract.py +++ b/tests/logprob/test_abstract.py @@ -45,7 +45,7 @@ import pymc as pm -from pymc.logprob.abstract import MeasurableElemwise, MeasurableVariable, _logcdf_helper +from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logcdf_helper from pymc.logprob.basic import logcdf @@ -66,7 +66,7 @@ class TestMeasurableElemwise(MeasurableElemwise): measurable_exp_op = TestMeasurableElemwise(scalar_op=exp) measurable_exp = measurable_exp_op(0.0) - assert isinstance(measurable_exp.owner.op, MeasurableVariable) + assert isinstance(measurable_exp.owner.op, MeasurableOp) def test_logcdf_helper(): diff --git a/tests/logprob/test_composite_logprob.py b/tests/logprob/test_composite_logprob.py index e4cdfc7dc3..3653830ef9 100644 --- a/tests/logprob/test_composite_logprob.py +++ b/tests/logprob/test_composite_logprob.py @@ -41,7 +41,7 @@ import scipy.stats as st from pymc import draw, logp -from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.abstract import MeasurableOp from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import construct_ir_fgraph from pymc.testing import assert_no_rvs @@ -138,7 +138,7 @@ def test_unvalued_ir_reversion(nested): # assert len(z_fgraph.preserve_rv_mappings.measurable_conversions) == 1 assert ( - sum(isinstance(node.op, MeasurableVariable) for node in z_fgraph.apply_nodes) == 2 + sum(isinstance(node.op, MeasurableOp) for node in z_fgraph.apply_nodes) == 2 ) # Just the 2 rvs diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index fa0c53831e..1d09e844fd 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,7 +52,7 @@ as_index_constant, ) -from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.abstract import MeasurableOp from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph @@ -993,16 +993,16 @@ def test_switch_mixture_invalid_bcast(): valid_mix = pt.switch(valid_switch_cond, valid_true_branch, valid_false_branch) fgraph, _, _ = construct_ir_fgraph({valid_mix: valid_mix.type()}) - assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + assert isinstance(fgraph.outputs[0].owner.op, MeasurableOp) assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) invalid_mix = pt.switch(invalid_switch_cond, valid_true_branch, valid_false_branch) fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()}) - assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + assert not isinstance(fgraph.outputs[0].owner.op, MeasurableOp) invalid_mix = pt.switch(valid_switch_cond, valid_true_branch, invalid_false_branch) fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()}) - assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + assert not isinstance(fgraph.outputs[0].owner.op, MeasurableOp) def test_ifelse_mixture_one_component(): diff --git a/tests/logprob/test_transform_value.py b/tests/logprob/test_transform_value.py index 2490ab61e7..3832b44143 100644 --- a/tests/logprob/test_transform_value.py +++ b/tests/logprob/test_transform_value.py @@ -30,7 +30,7 @@ from pymc.distributions.transforms import _default_transform, log, logodds from pymc.logprob import conditional_logp -from pymc.logprob.abstract import MeasurableVariable, _logprob +from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.transform_value import TransformValuesMapping, TransformValuesRewrite from pymc.logprob.transforms import ExpTransform, LogOddsTransform, LogTransform from pymc.testing import assert_no_rvs @@ -42,14 +42,12 @@ def multiout_measurable_op(): # Create a dummy Op that just returns the two inputs mu1, mu2 = pt.scalars("mu1", "mu2") - class TestOpFromGraph(OpFromGraph): + class TestOpFromGraph(MeasurableOp, OpFromGraph): def do_constant_folding(self, fgraph, node): False multiout_op = TestOpFromGraph([mu1, mu2], [mu1 + 0.0, mu2 + 0.0]) - MeasurableVariable.register(TestOpFromGraph) - @_logprob.register(TestOpFromGraph) def logp_multiout(op, values, mu1, mu2): value1, value2 = values diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py index 3192d0c586..d337e0317e 100644 --- a/tests/logprob/test_utils.py +++ b/tests/logprob/test_utils.py @@ -49,7 +49,7 @@ from pymc import SymbolicRandomVariable, inputvars from pymc.distributions.transforms import Interval -from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.abstract import MeasurableOp from pymc.logprob.basic import logp from pymc.logprob.utils import ( ParameterValueError, @@ -151,13 +151,7 @@ def test_intermediate_rv(self): res_ancestors = list(ancestors((res,))) assert ( - len( - list( - n - for n in res_ancestors - if n.owner and isinstance(n.owner.op, MeasurableVariable) - ) - ) + len(list(n for n in res_ancestors if n.owner and isinstance(n.owner.op, MeasurableOp))) == 1 )