Skip to content

Commit

Permalink
Bump PyMC minimum version requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
HasnainRaz authored and ricardoV94 committed Nov 27, 2023
1 parent 79c4dc1 commit 150fb0f
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.9.0 # CI was failing to resolve
- pymc>=5.10.0 # CI was failing to resolve
- blackjax
- scikit-learn
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.9.0 # CI was failing to resolve
- pymc>=5.10.0 # CI was failing to resolve
- blackjax
- scikit-learn
2 changes: 0 additions & 2 deletions pymc_experimental/model/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,6 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
for i in range(len(marginalized_rv_domain))
]
else:
# Make sure this rewrite is registered
from pymc.pytensorf import local_remove_check_parameter

def logp_fn(marginalized_rv_const, *non_sequences):
return joint_logp_op(marginalized_rv_const, *non_sequences)
Expand Down
14 changes: 7 additions & 7 deletions pymc_experimental/utils/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc.logprob.transforms import RVTransform
from pymc.logprob.transforms import Transform


class ParamCfg(TypedDict):
name: str
transform: Optional[RVTransform]
transform: Optional[Transform]
dims: Optional[Union[str, Tuple[str]]]


Expand All @@ -44,14 +44,14 @@ class FlatInfo(TypedDict):
info: List[VarInfo]


def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tuple]] = None):
def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, Transform, str, Tuple]] = None):
if value is None:
cfg = ParamCfg(name=key, transform=None, dims=None)
elif isinstance(value, Tuple):
cfg = ParamCfg(name=key, transform=None, dims=value)
elif isinstance(value, str):
cfg = ParamCfg(name=value, transform=None, dims=None)
elif isinstance(value, RVTransform):
elif isinstance(value, Transform):
cfg = ParamCfg(name=key, transform=value, dims=None)
else:
cfg = value.copy()
Expand All @@ -62,7 +62,7 @@ def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tup


def _parse_args(
var_names: Sequence[str], **kwargs: Union[ParamCfg, RVTransform, str, Tuple]
var_names: Sequence[str], **kwargs: Union[ParamCfg, Transform, str, Tuple]
) -> Dict[str, ParamCfg]:
results = dict()
for var in var_names:
Expand Down Expand Up @@ -133,7 +133,7 @@ def prior_from_idata(
name="trace_prior_",
*,
var_names: Sequence[str] = (),
**kwargs: Union[ParamCfg, RVTransform, str, Tuple]
**kwargs: Union[ParamCfg, Transform, str, Tuple]
) -> Dict[str, pt.TensorVariable]:
"""
Create a prior from posterior using MvNormal approximation.
Expand All @@ -153,7 +153,7 @@ def prior_from_idata(
Inference data with posterior group
var_names: Sequence[str]
names of variables to take as is from the posterior
kwargs: Union[ParamCfg, RVTransform, str, Tuple]
kwargs: Union[ParamCfg, Transform, str, Tuple]
names of variables with additional configuration, see more in Examples
Examples
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pymc>=5.8.2
pymc>=5.10.0
scikit-learn

0 comments on commit 150fb0f

Please sign in to comment.