Skip to content

Remove deprecated Distribution kwargs #7488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 0 additions & 44 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,6 @@ class DistributionMeta(ABCMeta):
"""

def __new__(cls, name, bases, clsdict):
# Forcefully deprecate old v3 `Distribution`s
if "random" in clsdict:

def _random(*args, **kwargs):
warnings.warn(
"The old `Distribution.random` interface is deprecated.",
FutureWarning,
stacklevel=2,
)
return clsdict["random"](*args, **kwargs)

clsdict["random"] = _random

rv_op = clsdict.setdefault("rv_op", None)
rv_type = clsdict.setdefault("rv_type", None)

Expand Down Expand Up @@ -206,13 +193,6 @@ def support_point(op, rv, *dist_params):
return new_cls


def _make_nice_attr_error(oldcode: str, newcode: str):
def fn(*args, **kwargs):
raise AttributeError(f"The `{oldcode}` method was removed. Instead use `{newcode}`.`")

return fn


class _class_or_instancemethod(classmethod):
"""Allow a method to be called both as a classmethod and an instancemethod,
giving priority to the instancemethod.
Expand Down Expand Up @@ -510,14 +490,6 @@ def __new__(
"for a standalone distribution."
)

if "testval" in kwargs:
initval = kwargs.pop("testval")
warnings.warn(
"The `testval` argument is deprecated; use `initval`.",
FutureWarning,
stacklevel=2,
)

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

Expand Down Expand Up @@ -551,10 +523,6 @@ def __new__(
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_dist, formatting="latex"), rv_out
)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
return rv_out

@classmethod
Expand Down Expand Up @@ -582,15 +550,6 @@ def dist(
rv : TensorVariable
The created random variable tensor.
"""
if "testval" in kwargs:
kwargs.pop("testval")
warnings.warn(
"The `.dist(testval=...)` argument is deprecated and has no effect. "
"Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
"For using PyTensor's test value features, you must assign the `.tag.test_value` yourself.",
FutureWarning,
stacklevel=2,
)
if "initval" in kwargs:
raise TypeError(
"Unexpected keyword argument `initval`. "
Expand All @@ -617,9 +576,6 @@ def dist(
create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
_add_future_warning_tag(rv_out)
return rv_out

Expand Down
25 changes: 0 additions & 25 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,6 @@ def test_issue_4499(self):
npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10)


@pytest.mark.parametrize(
"method,newcode",
[
("logp", r"pm.logp\(rv, x\)"),
("logcdf", r"pm.logcdf\(rv, x\)"),
("random", r"pm.draw\(rv\)"),
],
)
def test_logp_gives_migration_instructions(method, newcode):
rv = pm.Normal.dist()
f = getattr(rv, method)
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
f()

# A dim-induced resize of the rv created by the `.dist()` API,
# happening in Distribution.__new__ would make us loose the monkeypatches.
# So this triggers it to test if the monkeypatch still works.
with pm.Model(coords={"year": [2019, 2021, 2022]}):
rv = pm.Normal("n", dims="year")
f = getattr(rv, method)
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
f()
pass


def test_all_distributions_have_support_points():
import pymc.distributions as dist_module

Expand Down
4 changes: 2 additions & 2 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,8 @@ def test_initial_point():

b_initval = np.array(0.3, dtype=pytensor.config.floatX)

with pytest.warns(FutureWarning), model:
b = pm.Uniform("b", testval=b_initval)
with model:
b = pm.Uniform("b", initval=b_initval)

b_initval_trans = model.rvs_to_transforms[b].forward(b_initval, *b.owner.inputs).eval()

Expand Down
29 changes: 0 additions & 29 deletions tests/test_initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import cloudpickle
import numpy as np
import numpy.testing as npt
import pytensor
import pytensor.tensor as pt
import pytest
Expand All @@ -34,34 +33,6 @@ def transform_back(rv, transformed, model) -> np.ndarray:
return model.rvs_to_transforms[rv].backward(transformed, *rv.owner.inputs).eval()


class TestInitvalAssignment:
def test_dist_warnings_and_errors(self):
with pytest.warns(FutureWarning, match="argument is deprecated and has no effect"):
rv = pm.Exponential.dist(lam=1, testval=0.5)
assert not hasattr(rv.tag, "test_value")

with pytest.raises(TypeError, match="Unexpected keyword argument `initval`."):
pm.Normal.dist(1, 2, initval=None)
pass

def test_new_warnings(self):
with pm.Model() as pmodel:
with pytest.warns(FutureWarning, match="`testval` argument is deprecated"):
rv = pm.Uniform("u", 0, 1, testval=0.75)
initial_point = pmodel.initial_point(random_seed=0)
npt.assert_allclose(
initial_point["u_interval__"], transform_fwd(rv, 0.75, model=pmodel)
)
assert not hasattr(rv.tag, "test_value")
pass

def test_valid_string_strategy(self):
with pm.Model() as pmodel:
pm.Uniform("x", 0, 1, size=2, initval="unknown")
with pytest.raises(ValueError, match="Invalid string strategy: unknown"):
pmodel.initial_point(random_seed=0)


class TestInitvalEvaluation:
def test_make_initial_point_fns_per_chain_checks_kwargs(self):
with pm.Model() as pmodel:
Expand Down
Loading