Skip to content

Commit

Permalink
Add type annotations to distribution.py (pytorch#87577)
Browse files Browse the repository at this point in the history
As title.
Pull Request resolved: pytorch#87577
Approved by: https://github.com/kit1980
  • Loading branch information
EPronovost authored and pytorchmergebot committed Oct 26, 2022
1 parent 16e35bd commit 585d715
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions torch/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import warnings
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from typing import Dict, Optional, Any
from torch.types import _size
from typing import Dict, Optional, Any, Tuple

__all__ = ['Distribution']

Expand All @@ -16,7 +17,7 @@ class Distribution(object):
_validate_args = __debug__

@staticmethod
def set_default_validate_args(value):
def set_default_validate_args(value: bool) -> None:
"""
Sets whether validation is enabled or disabled.
Expand All @@ -32,7 +33,12 @@ def set_default_validate_args(value):
raise ValueError
Distribution._validate_args = value

def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_args=None):
def __init__(
self,
batch_shape: torch.Size = torch.Size(),
event_shape: torch.Size = torch.Size(),
validate_args: Optional[bool] = None,
):
self._batch_shape = batch_shape
self._event_shape = event_shape
if validate_args is not None:
Expand Down Expand Up @@ -62,7 +68,7 @@ def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_
)
super(Distribution, self).__init__()

def expand(self, batch_shape, _instance=None):
def expand(self, batch_shape: torch.Size, _instance=None):
"""
Returns a new distribution instance (or populates an existing instance
provided by a derived class) with batch dimensions expanded to
Expand All @@ -84,14 +90,14 @@ def expand(self, batch_shape, _instance=None):
raise NotImplementedError

@property
def batch_shape(self):
def batch_shape(self) -> torch.Size:
"""
Returns the shape over which parameters are batched.
"""
return self._batch_shape

@property
def event_shape(self):
def event_shape(self) -> torch.Size:
"""
Returns the shape of a single sample (without batching).
"""
Expand All @@ -116,58 +122,58 @@ def support(self) -> Optional[Any]:
raise NotImplementedError

@property
def mean(self):
def mean(self) -> torch.Tensor:
"""
Returns the mean of the distribution.
"""
raise NotImplementedError

@property
def mode(self):
def mode(self) -> torch.Tensor:
"""
Returns the mode of the distribution.
"""
raise NotImplementedError(f"{self.__class__} does not implement mode")

@property
def variance(self):
def variance(self) -> torch.Tensor:
"""
Returns the variance of the distribution.
"""
raise NotImplementedError

@property
def stddev(self):
def stddev(self) -> torch.Tensor:
"""
Returns the standard deviation of the distribution.
"""
return self.variance.sqrt()

def sample(self, sample_shape=torch.Size()):
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
"""
Generates a sample_shape shaped sample or sample_shape shaped batch of
samples if the distribution parameters are batched.
"""
with torch.no_grad():
return self.rsample(sample_shape)

def rsample(self, sample_shape=torch.Size()):
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
"""
Generates a sample_shape shaped reparameterized sample or sample_shape
shaped batch of reparameterized samples if the distribution parameters
are batched.
"""
raise NotImplementedError

def sample_n(self, n):
def sample_n(self, n: int) -> torch.Tensor:
"""
Generates n samples or n batches of samples if the distribution
parameters are batched.
"""
warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning)
return self.sample(torch.Size((n,)))

def log_prob(self, value):
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Returns the log of the probability density/mass function evaluated at
`value`.
Expand All @@ -177,7 +183,7 @@ def log_prob(self, value):
"""
raise NotImplementedError

def cdf(self, value):
def cdf(self, value: torch.Tensor) -> torch.Tensor:
"""
Returns the cumulative density/mass function evaluated at
`value`.
Expand All @@ -187,7 +193,7 @@ def cdf(self, value):
"""
raise NotImplementedError

def icdf(self, value):
def icdf(self, value: torch.Tensor) -> torch.Tensor:
"""
Returns the inverse cumulative density/mass function evaluated at
`value`.
Expand All @@ -197,7 +203,7 @@ def icdf(self, value):
"""
raise NotImplementedError

def enumerate_support(self, expand=True):
def enumerate_support(self, expand: bool = True) -> torch.Tensor:
"""
Returns tensor containing all values supported by a discrete
distribution. The result will enumerate over dimension 0, so the shape
Expand All @@ -221,7 +227,7 @@ def enumerate_support(self, expand=True):
"""
raise NotImplementedError

def entropy(self):
def entropy(self) -> torch.Tensor:
"""
Returns entropy of distribution, batched over batch_shape.
Expand All @@ -230,7 +236,7 @@ def entropy(self):
"""
raise NotImplementedError

def perplexity(self):
def perplexity(self) -> torch.Tensor:
"""
Returns perplexity of distribution, batched over batch_shape.
Expand All @@ -239,7 +245,7 @@ def perplexity(self):
"""
return torch.exp(self.entropy())

def _extended_shape(self, sample_shape=torch.Size()):
def _extended_shape(self, sample_shape: _size = torch.Size()) -> Tuple[int, ...]:
"""
Returns the size of the sample returned by the distribution, given
a `sample_shape`. Note, that the batch and event shapes of a distribution
Expand All @@ -253,7 +259,7 @@ def _extended_shape(self, sample_shape=torch.Size()):
sample_shape = torch.Size(sample_shape)
return sample_shape + self._batch_shape + self._event_shape

def _validate_sample(self, value):
def _validate_sample(self, value: torch.Tensor) -> None:
"""
Argument validation for distribution methods such as `log_prob`,
`cdf` and `icdf`. The rightmost dimensions of a value to be
Expand Down Expand Up @@ -306,7 +312,7 @@ def _get_checked_instance(self, cls, _instance=None):
format(self.__class__.__name__, cls.__name__))
return self.__new__(type(self)) if _instance is None else _instance

def __repr__(self):
def __repr__(self) -> str:
param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
if self.__dict__[p].numel() == 1
Expand Down

0 comments on commit 585d715

Please sign in to comment.