From 585d71513de98f02659835b08785de845bc6d348 Mon Sep 17 00:00:00 2001 From: Ethan Pronovost Date: Wed, 26 Oct 2022 18:50:48 +0000 Subject: [PATCH] Add type annotations to distribution.py (#87577) As title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87577 Approved by: https://github.com/kit1980 --- torch/distributions/distribution.py | 50 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 66bd158bd87b6b..4159f34d7748a4 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -2,7 +2,8 @@ import warnings from torch.distributions import constraints from torch.distributions.utils import lazy_property -from typing import Dict, Optional, Any +from torch.types import _size +from typing import Dict, Optional, Any, Tuple __all__ = ['Distribution'] @@ -16,7 +17,7 @@ class Distribution(object): _validate_args = __debug__ @staticmethod - def set_default_validate_args(value): + def set_default_validate_args(value: bool) -> None: """ Sets whether validation is enabled or disabled. @@ -32,7 +33,12 @@ def set_default_validate_args(value): raise ValueError Distribution._validate_args = value - def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_args=None): + def __init__( + self, + batch_shape: torch.Size = torch.Size(), + event_shape: torch.Size = torch.Size(), + validate_args: Optional[bool] = None, + ): self._batch_shape = batch_shape self._event_shape = event_shape if validate_args is not None: @@ -62,7 +68,7 @@ def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_ ) super(Distribution, self).__init__() - def expand(self, batch_shape, _instance=None): + def expand(self, batch_shape: torch.Size, _instance=None): """ Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to @@ -84,14 +90,14 @@ def expand(self, batch_shape, _instance=None): raise NotImplementedError @property - def batch_shape(self): + def batch_shape(self) -> torch.Size: """ Returns the shape over which parameters are batched. """ return self._batch_shape @property - def event_shape(self): + def event_shape(self) -> torch.Size: """ Returns the shape of a single sample (without batching). """ @@ -116,34 +122,34 @@ def support(self) -> Optional[Any]: raise NotImplementedError @property - def mean(self): + def mean(self) -> torch.Tensor: """ Returns the mean of the distribution. """ raise NotImplementedError @property - def mode(self): + def mode(self) -> torch.Tensor: """ Returns the mode of the distribution. """ raise NotImplementedError(f"{self.__class__} does not implement mode") @property - def variance(self): + def variance(self) -> torch.Tensor: """ Returns the variance of the distribution. """ raise NotImplementedError @property - def stddev(self): + def stddev(self) -> torch.Tensor: """ Returns the standard deviation of the distribution. """ return self.variance.sqrt() - def sample(self, sample_shape=torch.Size()): + def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """ Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. @@ -151,7 +157,7 @@ def sample(self, sample_shape=torch.Size()): with torch.no_grad(): return self.rsample(sample_shape) - def rsample(self, sample_shape=torch.Size()): + def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """ Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters @@ -159,7 +165,7 @@ def rsample(self, sample_shape=torch.Size()): """ raise NotImplementedError - def sample_n(self, n): + def sample_n(self, n: int) -> torch.Tensor: """ Generates n samples or n batches of samples if the distribution parameters are batched. @@ -167,7 +173,7 @@ def sample_n(self, n): warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning) return self.sample(torch.Size((n,))) - def log_prob(self, value): + def log_prob(self, value: torch.Tensor) -> torch.Tensor: """ Returns the log of the probability density/mass function evaluated at `value`. @@ -177,7 +183,7 @@ def log_prob(self, value): """ raise NotImplementedError - def cdf(self, value): + def cdf(self, value: torch.Tensor) -> torch.Tensor: """ Returns the cumulative density/mass function evaluated at `value`. @@ -187,7 +193,7 @@ def cdf(self, value): """ raise NotImplementedError - def icdf(self, value): + def icdf(self, value: torch.Tensor) -> torch.Tensor: """ Returns the inverse cumulative density/mass function evaluated at `value`. @@ -197,7 +203,7 @@ def icdf(self, value): """ raise NotImplementedError - def enumerate_support(self, expand=True): + def enumerate_support(self, expand: bool = True) -> torch.Tensor: """ Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape @@ -221,7 +227,7 @@ def enumerate_support(self, expand=True): """ raise NotImplementedError - def entropy(self): + def entropy(self) -> torch.Tensor: """ Returns entropy of distribution, batched over batch_shape. @@ -230,7 +236,7 @@ def entropy(self): """ raise NotImplementedError - def perplexity(self): + def perplexity(self) -> torch.Tensor: """ Returns perplexity of distribution, batched over batch_shape. @@ -239,7 +245,7 @@ def perplexity(self): """ return torch.exp(self.entropy()) - def _extended_shape(self, sample_shape=torch.Size()): + def _extended_shape(self, sample_shape: _size = torch.Size()) -> Tuple[int, ...]: """ Returns the size of the sample returned by the distribution, given a `sample_shape`. Note, that the batch and event shapes of a distribution @@ -253,7 +259,7 @@ def _extended_shape(self, sample_shape=torch.Size()): sample_shape = torch.Size(sample_shape) return sample_shape + self._batch_shape + self._event_shape - def _validate_sample(self, value): + def _validate_sample(self, value: torch.Tensor) -> None: """ Argument validation for distribution methods such as `log_prob`, `cdf` and `icdf`. The rightmost dimensions of a value to be @@ -306,7 +312,7 @@ def _get_checked_instance(self, cls, _instance=None): format(self.__class__.__name__, cls.__name__)) return self.__new__(type(self)) if _instance is None else _instance - def __repr__(self): + def __repr__(self) -> str: param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] if self.__dict__[p].numel() == 1