From a3e2261ac8525b17e6806a9e39b70c82de9b4cea Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 1 Aug 2024 17:34:54 +0200 Subject: [PATCH] Do not instantiate all transforms in find_measurable_transforms --- pymc/logprob/transforms.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 3824af860f..505b51cb7e 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -463,21 +463,6 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node] transform_inputs: tuple[TensorVariable, ...] = (measurable_input,) transform: Transform - transform_dict = { - Exp: ExpTransform(), - Log: LogTransform(), - Abs: AbsTransform(), - Sinh: SinhTransform(), - Cosh: CoshTransform(), - Tanh: TanhTransform(), - ArcSinh: ArcsinhTransform(), - ArcCosh: ArccoshTransform(), - ArcTanh: ArctanhTransform(), - Erf: ErfTransform(), - Erfc: ErfcTransform(), - Erfcx: ErfcxTransform(), - } - transform = transform_dict.get(type(scalar_op), None) if isinstance(scalar_op, Pow): # We only allow for the base to be measurable if measurable_input_idx != 0: @@ -495,11 +480,27 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node] transform = LocTransform( transform_args_fn=lambda *inputs: inputs[-1], ) - elif transform is None: + elif isinstance(scalar_op, Mul): transform_inputs = (measurable_input, pt.mul(*other_inputs)) transform = ScaleTransform( transform_args_fn=lambda *inputs: inputs[-1], ) + else: + transform = { + Exp: ExpTransform, + Log: LogTransform, + Abs: AbsTransform, + Sinh: SinhTransform, + Cosh: CoshTransform, + Tanh: TanhTransform, + ArcSinh: ArcsinhTransform, + ArcCosh: ArccoshTransform, + ArcTanh: ArctanhTransform, + Erf: ErfTransform, + Erfc: ErfcTransform, + Erfcx: ErfcxTransform, + }[type(scalar_op)]() + transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform,