Remove deprecated code (#2800)
SkafteNicki authored Oct 22, 2024
1 parent 02f050a commit c6e3956
Showing 6 changed files with 16 additions and 94 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -17,14 +17,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Changed naming and input order of arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))


### Removed

- Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671))


- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))


### Fixed

- Handled the `_modules` dict type change in PyTorch 2.5 that caused metric collections to fail ([#2793](https://github.com/Lightning-AI/torchmetrics/pull/2793))
18 changes: 0 additions & 18 deletions src/torchmetrics/functional/regression/kl_divergence.py
@@ -20,7 +20,6 @@

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_xlogy
from torchmetrics.utilities.prints import rank_zero_warn


def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> Tuple[Tensor, int]:
@@ -92,14 +91,6 @@ def kl_divergence(
over data and :math:`Q` is often a prior or approximation of :math:`P`. It should be noted that the KL divergence
is a non-symmetrical metric i.e. :math:`D_{KL}(P||Q) \neq D_{KL}(Q||P)`.
.. warning::
The input order and naming in metric ``kl_divergence`` is set to be deprecated in v1.4 and changed in v1.5.
Input argument ``p`` will be renamed to ``target`` and will be moved to be the second argument of the metric.
Input argument ``q`` will be renamed to ``preds`` and will be moved to the first argument of the metric.
Thus, ``kl_divergence(p, q)`` will equal ``kl_divergence(target=q, preds=p)`` in the future to be consistent
with the rest of ``torchmetrics``. From v1.4 the two new arguments will be added as keyword arguments and
from v1.5 the two old arguments will be removed.
Args:
p: data distribution with shape ``[N, d]``
q: prior or approximate distribution with shape ``[N, d]``
@@ -120,14 +111,5 @@
tensor(0.0853)
"""
rank_zero_warn(
"The input order and naming in metric `kl_divergence` is set to be deprecated in v1.4 and changed in v1.5."
"Input argument `p` will be renamed to `target` and will be moved to be the second argument of the metric."
"Input argument `q` will be renamed to `preds` and will be moved to the first argument of the metric."
"Thus, `kl_divergence(p, q)` will equal `kl_divergence(target=q, preds=p)` in the future to be consistent with"
" the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and from v1.5"
" the two old arguments will be removed.",
DeprecationWarning,
)
measures, total = _kld_update(p, q, log_prob)
return _kld_compute(measures, total, reduction)
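The removed warning spells out the migration for the functional metric: `p` becomes `target` and moves to the second position, while `q` becomes `preds` and moves to the first, matching the `(preds, target)` convention used elsewhere in `torchmetrics`. A minimal sketch of a post-change call, assuming that signature (the updated function definition itself is collapsed out of this diff):

```python
import torch
from torchmetrics.functional import kl_divergence

# two batches of probability distributions over d=3 categories
target = torch.softmax(torch.randn(10, 3), dim=-1)  # data distribution (old `p`)
preds = torch.softmax(torch.randn(10, 3), dim=-1)   # approximation (old `q`)

# old order (pre-change): kl_divergence(p, q)
# new order: preds first, target second -- assumed per the removed warning
score = kl_divergence(preds, target)
```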
18 changes: 0 additions & 18 deletions src/torchmetrics/regression/kl_divergence.py
@@ -22,7 +22,6 @@
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.utilities.prints import rank_zero_warn

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["KLDivergence.plot"]
@@ -47,14 +46,6 @@ class KLDivergence(Metric):
- ``kl_divergence`` (:class:`~torch.Tensor`): A tensor with the KL divergence
.. warning::
The input order and naming in metric ``KLDivergence`` is set to be deprecated in v1.4 and changed in v1.5.
Input argument ``p`` will be renamed to ``target`` and will be moved to be the second argument of the metric.
Input argument ``q`` will be renamed to ``preds`` and will be moved to the first argument of the metric.
Thus, ``KLDivergence(p, q)`` will equal ``KLDivergence(target=q, preds=p)`` in the future to be consistent
with the rest of ``torchmetrics``. From v1.4 the two new arguments will be added as keyword arguments and
from v1.5 the two old arguments will be removed.
Args:
log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities,
will normalize to make sure the distributes sum to 1.
@@ -102,15 +93,6 @@ def __init__(
reduction: Literal["mean", "sum", "none", None] = "mean",
**kwargs: Any,
) -> None:
rank_zero_warn(
"The input order and naming in metric `KLDivergence` is set to be deprecated in v1.4 and changed in v1.5."
"Input argument `p` will be renamed to `target` and will be moved to be the second argument of the metric."
"Input argument `q` will be renamed to `preds` and will be moved to the first argument of the metric."
"Thus, `KLDivergence(p, q)` will equal `KLDivergence(target=q, preds=p)` in the future to be consistent"
" with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and"
" from v1.5 the two old arguments will be removed.",
DeprecationWarning,
)
super().__init__(**kwargs)
if not isinstance(log_prob, bool):
raise TypeError(f"Expected argument `log_prob` to be bool but got {log_prob}")
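The module interface changes the same way: `update`/`forward` now expect `(preds, target)` rather than `(p, q)`. A hedged usage sketch under the same assumption; the constructor arguments `log_prob` and `reduction` are the ones visible in this diff:

```python
import torch
from torchmetrics.regression import KLDivergence

kld = KLDivergence(log_prob=False, reduction="mean")
target = torch.softmax(torch.randn(10, 3), dim=-1)
preds = torch.softmax(torch.randn(10, 3), dim=-1)
# preds first, target second -- assumed per the removed deprecation warning
score = kld(preds, target)
```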
11 changes: 0 additions & 11 deletions src/torchmetrics/regression/r2.py
@@ -17,7 +17,6 @@

from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

@@ -108,21 +107,11 @@ class R2Score(Metric):

def __init__(
self,
num_outputs: Optional[int] = None,
adjusted: int = 0,
multioutput: str = "uniform_average",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

if num_outputs is not None:
rank_zero_warn(
"Argument `num_outputs` in `R2Score` has been deprecated because it is no longer necessary and will be"
"removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape"
"of the input tensors.",
DeprecationWarning,
)

if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError("`adjusted` parameter should be an integer larger or equal to 0.")
self.adjusted = adjusted
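Per the removed warning, `num_outputs` is no longer needed because the number of outputs is inferred from the shape of the input tensors, so the constructor argument is dropped entirely. A small sketch of the post-change usage (tensor shapes and argument values are illustrative):

```python
import torch
from torchmetrics.regression import R2Score

# no `num_outputs` argument anymore; the output dimension is read from the inputs
r2 = R2Score(adjusted=0, multioutput="raw_values")
preds = torch.randn(100, 3)
target = torch.randn(100, 3)
score = r2(preds, target)  # with "raw_values", one R^2 value per output column
```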
36 changes: 12 additions & 24 deletions tests/unittests/regression/test_r2.py
@@ -60,28 +60,28 @@ def _multi_target_ref_wrapper(preds, target, adjusted, multioutput):
@pytest.mark.parametrize("adjusted", [0, 5, 10])
@pytest.mark.parametrize("multioutput", ["raw_values", "uniform_average", "variance_weighted"])
@pytest.mark.parametrize(
"preds, target, ref_metric, num_outputs",
"preds, target, ref_metric",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper, 1),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper, NUM_TARGETS),
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper),
],
)
class TestR2Score(MetricTester):
"""Test class for `R2Score` metric."""

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_r2(self, adjusted, multioutput, preds, target, ref_metric, num_outputs, ddp):
def test_r2(self, adjusted, multioutput, preds, target, ref_metric, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp,
preds,
target,
R2Score,
partial(ref_metric, adjusted=adjusted, multioutput=multioutput),
metric_args={"adjusted": adjusted, "multioutput": multioutput, "num_outputs": num_outputs},
metric_args={"adjusted": adjusted, "multioutput": multioutput},
)

def test_r2_functional(self, adjusted, multioutput, preds, target, ref_metric, num_outputs):
def test_r2_functional(self, adjusted, multioutput, preds, target, ref_metric):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds,
@@ -91,35 +91,23 @@ def test_r2_functional(self, adjusted, multioutput, preds, target, ref_metric, n
metric_args={"adjusted": adjusted, "multioutput": multioutput},
)

def test_r2_differentiability(self, adjusted, multioutput, preds, target, ref_metric, num_outputs):
def test_r2_differentiability(self, adjusted, multioutput, preds, target, ref_metric):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=partial(R2Score, num_outputs=num_outputs),
metric_functional=r2_score,
metric_args={"adjusted": adjusted, "multioutput": multioutput},
preds, target, R2Score, r2_score, {"adjusted": adjusted, "multioutput": multioutput}
)

def test_r2_half_cpu(self, adjusted, multioutput, preds, target, ref_metric, num_outputs):
def test_r2_half_cpu(self, adjusted, multioutput, preds, target, ref_metric):
"""Test dtype support of the metric on CPU."""
self.run_precision_test_cpu(
preds,
target,
partial(R2Score, num_outputs=num_outputs),
r2_score,
{"adjusted": adjusted, "multioutput": multioutput},
preds, target, R2Score, r2_score, {"adjusted": adjusted, "multioutput": multioutput}
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_r2_half_gpu(self, adjusted, multioutput, preds, target, ref_metric, num_outputs):
def test_r2_half_gpu(self, adjusted, multioutput, preds, target, ref_metric):
"""Test dtype support of the metric on GPU."""
self.run_precision_test_gpu(
preds,
target,
partial(R2Score, num_outputs=num_outputs),
r2_score,
{"adjusted": adjusted, "multioutput": multioutput},
preds, target, R2Score, r2_score, {"adjusted": adjusted, "multioutput": multioutput}
)


22 changes: 0 additions & 22 deletions tests/unittests/test_deprecated.py

This file was deleted.
