Merge pull request #63 from RWKV/ds3-fix
Ds3 fix
PicoCreator authored Jan 20, 2024
2 parents 9426405 + e0aad53 commit 3186764
Showing 2 changed files with 1,726 additions and 7 deletions.
19 changes: 12 additions & 7 deletions RWKV-v5/src/model.py
@@ -1139,13 +1139,18 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
                     # https://lightning.ai/docs/pytorch/2.0.4/common/lightning_module.html#backward
                     learning_loss = segment_train_loss / gradient_accumulation_steps
 
-                    # Undocumented multiple backward pass support
-                    # https://github.com/Lightning-AI/lightning/blob/678f642808c54e4c490caee4df5d357301c976bb/tests/trainer/optimization/test_manual_optimization.py#L251
-                    self.manual_backward(learning_loss, optimizer, retain_graph=True)
-
-                    # Accumulate without gradient, as we already did the backward pass
-                    # This does mean, that a single backward pass is "wasted" at the end
-                    training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False)
+                    # Perform the backward pass accordingly, for valid segments (besides the last segment)
+                    if i == start_learning_segment + backward_segment_count - 1:
+                        # This is the last backward pass, we let the default pytorch lightning handle the backward pass
+                        # and return the segment loss as part of the total loss
+                        training_loss = training_loss + segment_train_loss
+                    else:
+                        # Undocumented multiple backward pass support
+                        # https://github.com/Lightning-AI/lightning/blob/678f642808c54e4c490caee4df5d357301c976bb/tests/trainer/optimization/test_manual_optimization.py#L251
+                        self.manual_backward(learning_loss, optimizer, retain_graph=True)
+
+                        # Accumulate without gradient, as we already did the backward pass
+                        training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False)
                 else:
                     # Even if its not the segments we use for backward pass, we still need to accumulate the loss
                     training_loss = training_loss + segment_train_loss.clone().detach().requires_grad_(False)
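For context, a minimal pure-PyTorch sketch of the segmented-backward pattern this diff adjusts: every segment except the last gets an explicit backward pass (the real code uses Lightning's self.manual_backward with retain_graph=True inside checkpointed_step), while the last segment's loss stays attached so the framework's default backward can handle it. The toy model, data, and the final scaling by backward_segment_count are illustrative stand-ins only; the repository scales by gradient_accumulation_steps and carries hidden state between segments.

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Stand-in "segments" of a longer sequence (toy data, not RWKV inputs)
segments = [torch.randn(4, 8) for _ in range(3)]
targets = [torch.randn(4, 1) for _ in range(3)]
backward_segment_count = len(segments)

training_loss = torch.tensor(0.0)  # running total, returned/logged at the end
optimizer.zero_grad()

for i, (x, y) in enumerate(zip(segments, targets)):
    segment_train_loss = nn.functional.mse_loss(model(x), y)
    learning_loss = segment_train_loss / backward_segment_count

    if i == backward_segment_count - 1:
        # Last segment: keep the loss attached; the final backward below
        # stands in for the one PyTorch Lightning performs on the returned loss.
        training_loss = training_loss + segment_train_loss
    else:
        # Earlier segments: backward immediately. retain_graph mirrors the
        # diff; it matters when segments share carried state, harmless here.
        learning_loss.backward(retain_graph=True)
        training_loss = training_loss + segment_train_loss.detach()

# The backward pass the framework would normally run on the returned loss.
(training_loss / backward_segment_count).backward()
optimizer.step()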
