Skip to content

Commit

Permalink
Address review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Oct 30, 2022
1 parent 2da3c75 commit 152b913
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pyro/ops/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.fft import irfft, rfft

_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)
CHOLESKY_JITTER = 1.0 # in units of finfo.eps
CHOLESKY_RELATIVE_JITTER = 1.0 # in units of finfo.eps


def as_complex(x):
Expand Down Expand Up @@ -396,15 +396,15 @@ def inverse_haar_transform(x):

def safe_cholesky(x):
if x.size(-1) == 1:
if CHOLESKY_JITTER:
if CHOLESKY_RELATIVE_JITTER:
x = x.clamp(min=torch.finfo(x.dtype).tiny)
return x.sqrt()

if CHOLESKY_JITTER:
if CHOLESKY_RELATIVE_JITTER:
# Add adaptive jitter.
x = x.clone()
x_max = x.data.reshape(*x.shape[:-2], -1).abs().max(-1, True).values
jitter = CHOLESKY_JITTER * torch.finfo(x.dtype).eps * x_max
jitter = CHOLESKY_RELATIVE_JITTER * torch.finfo(x.dtype).eps * x_max
x.data.diagonal(dim1=-1, dim2=-2).add_(jitter)

return torch.linalg.cholesky(x)
Expand Down

0 comments on commit 152b913

Please sign in to comment.