From ff596b34c6f14c2dd060b29065b0ae6bba5a370a Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Fri, 26 May 2023 08:33:17 -0400
Subject: [PATCH] Fix RWKV backward on GPU (#23774)

---
 src/transformers/models/rwkv/modeling_rwkv.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py
index 7a78ec082f45ae..cd577e9c7431af 100644
--- a/src/transformers/models/rwkv/modeling_rwkv.py
+++ b/src/transformers/models/rwkv/modeling_rwkv.py
@@ -159,7 +159,7 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa
 
     @staticmethod
     # g stands for grad
-    def backward(ctx, g_output):
+    def backward(ctx, g_output, g_state=None):
         input_dtype = ctx.input_dtype
 
         time_decay, time_first, key, value, output = ctx.saved_tensors
@@ -188,17 +188,14 @@ def backward(ctx, g_output):
             g_key,
             g_value,
         )
-        g_time_decay = torch.sum(g_time_decay, dim=0)
-        g_time_first = torch.sum(g_time_first, dim=0)
 
         return (
-            None,
-            None,
-            None,
             g_time_decay.to(input_dtype),
             g_time_first.to(input_dtype),
             g_key.to(input_dtype),
             g_value.to(input_dtype),
+            None,
+            None,
         )
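
Note (not part of the patch): the change relies on the torch.autograd.Function contract. backward must accept one incoming gradient per output of forward (None is passed for non-tensor outputs, hence the new g_state=None argument) and must return one value per positional input of forward, in the same order, with None for non-differentiable inputs (hence the two trailing None entries instead of three leading ones). The sketch below is a minimal, hypothetical illustration of that contract; the class ScaleWithState and its arguments are invented for illustration and are not taken from modeling_rwkv.py.

import torch


class ScaleWithState(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, state=None, return_state=False):
        # x is a tensor; factor, state and return_state are auxiliary inputs.
        ctx.factor = factor
        output = x * factor
        new_state = output.detach() if return_state else None
        # Two outputs, so backward receives two incoming gradients
        # (None for non-tensor outputs) -- this mirrors g_state=None above.
        return output, new_state

    @staticmethod
    def backward(ctx, g_output, g_state=None):
        # One return value per forward input, in order: x, factor, state,
        # return_state. Non-differentiable inputs get None, matching the
        # trailing None entries in the fixed RWKV backward.
        return g_output * ctx.factor, None, None, None


x = torch.randn(4, requires_grad=True)
output, _ = ScaleWithState.apply(x, 2.0)
output.sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])

In the patched kernel, forward takes (time_decay, time_first, key, value, state, return_state), which is why the fixed backward returns the four gradients first and None for the last two arguments, and accepts g_state for the second forward output.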