Commit

Made grad clipping optional
Kushagra Pandey authored and kpandey008 committed Oct 3, 2023
1 parent dc36436 commit aa457e6
Showing 1 changed file with 7 additions and 6 deletions.
main/models/wrapper.py (13 changes: 7 additions & 6 deletions)
@@ -68,15 +68,16 @@ def training_step(self, batch, batch_idx):
         t = t_ * (self.sde.T - self.train_eps) + self.train_eps
         assert t.shape[0] == x_0.shape[0]
 
-        # Compute loss
+        # Compute loss and backward
         loss = self.criterion(x_0, t, self.score_fn)
-
-        # Clip gradients and Optimize
         optim.zero_grad()
         self.manual_backward(loss)
-        torch.nn.utils.clip_grad_norm_(
-            self.score_fn.parameters(), self.config.training.optimizer.grad_clip
-        )
+
+        # Clip gradients (if enabled.)
+        if self.config.training.optimizer.grad_clip != 0:
+            torch.nn.utils.clip_grad_norm_(
+                self.score_fn.parameters(), self.config.training.optimizer.grad_clip
+            )
         optim.step()
 
         # Scheduler step
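For context, here is a minimal sketch of how the full training_step could read after this commit, written for PyTorch Lightning with manual optimization. Only the loss computation, backward pass, and optional clipping come from the diff itself; the class name SDEWrapper, the constructor, the optimizer retrieval via self.optimizers(), the batch unpacking, the time sampling, and the Adam optimizer in configure_optimizers are assumptions filled in for illustration and may differ from the actual main/models/wrapper.py.

import torch
import pytorch_lightning as pl


class SDEWrapper(pl.LightningModule):  # assumed name; the real class lives in main/models/wrapper.py
    def __init__(self, config, sde, score_fn, criterion, train_eps):
        super().__init__()
        self.config = config
        self.sde = sde
        self.score_fn = score_fn
        self.criterion = criterion
        self.train_eps = train_eps
        # Manual optimization is implied by manual_backward()/optim.step() in the diff.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        optim = self.optimizers()  # assumed: fetch the single optimizer
        x_0 = batch                # assumed: the batch is the clean sample x_0

        # Assumed: sample diffusion times uniformly in [train_eps, sde.T]
        t_ = torch.rand(x_0.shape[0], device=x_0.device)
        t = t_ * (self.sde.T - self.train_eps) + self.train_eps
        assert t.shape[0] == x_0.shape[0]

        # From the commit: compute loss and backward
        loss = self.criterion(x_0, t, self.score_fn)
        optim.zero_grad()
        self.manual_backward(loss)

        # From the commit: clip gradients only when grad_clip is non-zero,
        # so grad_clip = 0 in the config disables clipping entirely
        if self.config.training.optimizer.grad_clip != 0:
            torch.nn.utils.clip_grad_norm_(
                self.score_fn.parameters(), self.config.training.optimizer.grad_clip
            )
        optim.step()
        return loss

    def configure_optimizers(self):
        # Assumed: Adam with a learning rate taken from the same optimizer config block
        return torch.optim.Adam(self.score_fn.parameters(), lr=self.config.training.optimizer.lr)

In short, setting training.optimizer.grad_clip to 0 in the config now skips clipping entirely, while any non-zero value is still passed to torch.nn.utils.clip_grad_norm_ as the max norm, exactly as before.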
