Skip to content

Commit

Permalink
fix: REBAR sometimes crashes when no gradient is available
Browse files Browse the repository at this point in the history
  • Loading branch information
HEmile committed May 3, 2022
1 parent 62d6499 commit 2a3f1bc
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions storch/method/relax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from functools import reduce
from operator import mul
from typing import Optional, Callable, Tuple
Expand Down Expand Up @@ -344,6 +346,9 @@ def update_parameters(
for tensor in tensors:
# TODO: We have to select the probs of the distribution here as that's what it flows to. Is this always correct?
d_param = tensor.grad['probs']
if not d_param.requires_grad:
warnings.warn("Gradient of input tensor does not require grad which is needed to train the REBAR variance parameter.".format(tensor))
continue
variance = (d_param ** 2).sum(d_param.event_dim_indices)
var_loss = storch.reduce_plates(variance)

Expand Down

0 comments on commit 2a3f1bc

Please sign in to comment.