Skip to content

Commit

Permalink
Fix RWKV backward on GPU (huggingface#23774)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored and sheonhan committed Jun 1, 2023
1 parent 814dbf2 commit ff596b3
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/transformers/models/rwkv/modeling_rwkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa

@staticmethod
# g stands for grad
def backward(ctx, g_output):
def backward(ctx, g_output, g_state=None):
input_dtype = ctx.input_dtype

time_decay, time_first, key, value, output = ctx.saved_tensors
Expand Down Expand Up @@ -188,17 +188,14 @@ def backward(ctx, g_output):
g_key,
g_value,
)
g_time_decay = torch.sum(g_time_decay, dim=0)
g_time_first = torch.sum(g_time_first, dim=0)

return (
None,
None,
None,
g_time_decay.to(input_dtype),
g_time_first.to(input_dtype),
g_key.to(input_dtype),
g_value.to(input_dtype),
None,
None,
)


Expand Down

0 comments on commit ff596b3

Please sign in to comment.