Skip to content

Commit

Permalink
Do not instantiate all transforms in find_measurable_transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 5, 2024
1 parent d4714d2 commit a3e2261
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,21 +463,6 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
transform_inputs: tuple[TensorVariable, ...] = (measurable_input,)
transform: Transform

transform_dict = {
Exp: ExpTransform(),
Log: LogTransform(),
Abs: AbsTransform(),
Sinh: SinhTransform(),
Cosh: CoshTransform(),
Tanh: TanhTransform(),
ArcSinh: ArcsinhTransform(),
ArcCosh: ArccoshTransform(),
ArcTanh: ArctanhTransform(),
Erf: ErfTransform(),
Erfc: ErfcTransform(),
Erfcx: ErfcxTransform(),
}
transform = transform_dict.get(type(scalar_op), None)
if isinstance(scalar_op, Pow):
# We only allow for the base to be measurable
if measurable_input_idx != 0:
Expand All @@ -495,11 +480,27 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
transform = LocTransform(
transform_args_fn=lambda *inputs: inputs[-1],
)
elif transform is None:
elif isinstance(scalar_op, Mul):
transform_inputs = (measurable_input, pt.mul(*other_inputs))
transform = ScaleTransform(
transform_args_fn=lambda *inputs: inputs[-1],
)
else:
transform = {
Exp: ExpTransform,
Log: LogTransform,
Abs: AbsTransform,
Sinh: SinhTransform,
Cosh: CoshTransform,
Tanh: TanhTransform,
ArcSinh: ArcsinhTransform,
ArcCosh: ArccoshTransform,
ArcTanh: ArctanhTransform,
Erf: ErfTransform,
Erfc: ErfcTransform,
Erfcx: ErfcxTransform,
}[type(scalar_op)]()

transform_op = MeasurableTransform(
scalar_op=scalar_op,
transform=transform,
Expand Down

0 comments on commit a3e2261

Please sign in to comment.