
RWKV - loss.backward() failed #23653

Closed

LetianLee opened this issue May 22, 2023 · 9 comments

@LetianLee

LetianLee commented May 22, 2023

System Info

  • transformers version: 4.29.2
  • Platform: Linux-5.15.107+-x86_64-with-glibc2.31
  • Python version: 3.10.11
  • Huggingface_hub version: 0.14.1
  • Safetensors version: not installed
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.6.9 (gpu)
  • Jax version: 0.4.8
  • JaxLib version: 0.4.7
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Below is the code from the official example: https://huggingface.co/docs/transformers/main/en/model_doc/rwkv#transformers.RwkvForCausalLM
import torch
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
  2. I only added this one line, loss.backward(), but running it failed:
import torch
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
loss.backward()
  3. Error message:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-ffc1d58be8b3> in <cell line: 10>()
      8 outputs = model(**inputs, labels=inputs["input_ids"])
      9 loss = outputs.loss
---> 10 loss.backward()

/usr/local/lib/python3.10/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    485                 inputs=inputs,
    486             )
--> 487         torch.autograd.backward(
    488             self, gradient, retain_graph, create_graph, inputs=inputs
    489         )

/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    198     # some Python versions print out the first line of a multi-line function
    199     # calls in the traceback and some print out the last line
--> 200     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 768]], which is output 0 of AsStridedBackward0, is at version 12; expected version 11 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
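
As the hint in the error suggests, enabling anomaly detection makes autograd record the forward trace of each operation, so the failing backward also points at the in-place op that broke the graph. A minimal sketch of applying it to the reproduction above:

import torch
from transformers import AutoTokenizer, RwkvForCausalLM

# With anomaly detection on, the RuntimeError above is accompanied by a
# second traceback locating the forward op whose output was later modified.
torch.autograd.set_detect_anomaly(True)

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()  # still fails, but now reports the offending op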

Expected behavior

loss.backward() should work.

@XueFuzhao

+1 on this issue. Any update?

@Blealtan

Check out this reply. I guess it's the same issue, though I haven't looked into it.

@XueFuzhao

> Check out this reply. I guess it's the same issue, though I haven't looked into it.

Thanks for the help. I still have the same bug as before.

  File "modeling_rwkv.py", line 783, in forward
    hidden_states, state, attentions = block(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "modeling_rwkv.py", line 510, in forward
    attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "modeling_rwkv.py", line 436, in forward
    rwkv, layer_state = rwkv_linear_attention(
  File "modeling_rwkv.py", line 377, in rwkv_linear_attention
    return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
  File "modeling_rwkv.py", line 361, in rwkv_linear_attention_cpu
    den_state = e1 * den_state + e2

Any idea about this?
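
For context, this error class is easy to reproduce in isolation: autograd saves tensors needed for the backward pass and checks their version counters, so an in-place update to a saved tensor (such as a recurrent state) invalidates the graph. A standalone sketch of the mechanism, deliberately not the RWKV code itself:

import torch

x = torch.ones(1, 768, requires_grad=True)
state = torch.zeros(1, 768)

y = x * state       # mul saves `state`: the grad of x is grad_y * state
state += 1.0        # in-place update bumps `state`'s version counter
y.sum().backward()  # RuntimeError: one of the variables needed for gradient
                    # computation has been modified by an inplace operation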

@ghost

ghost commented May 25, 2023

I'm experiencing a loss.backward() failure when using the custom CUDA kernel, i.e. whenever the setup branches to the else path below:

if rwkv_cuda_kernel is None or no_cuda or one_token:
    return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
else:
    return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)

loss.backward() throws the error "TypeError: backward() takes 2 positional arguments but 3 were given".
When rwkv_linear_attention_cpu is called instead, things work fine.

Any ideas on what might be causing this?
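
For reference, that TypeError is characteristic of a torch.autograd.Function whose backward accepts fewer gradients than forward returns outputs, since autograd calls backward with one gradient per output. A standalone sketch of the mismatch (hypothetical class, not the actual kernel wrapper):

import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Two outputs, so autograd will call backward(ctx, grad_a, grad_b).
        return x * 2, x * 3

    @staticmethod
    def backward(ctx, grad_out):  # signature only accepts one gradient
        return grad_out * 2

x = torch.ones(3, requires_grad=True)
a, b = TwoOutputs.apply(x)
(a + b).sum().backward()
# TypeError: backward() takes 2 positional arguments but 3 were given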

@ArthurZucker
Collaborator

Pinging both @sgugger and @younesbelkada as they ported the model

@sgugger
Collaborator

sgugger commented May 25, 2023

I can confirm the backward fails both on CPU (first error) and on GPU (last error). Diving into this.

@sgugger
Collaborator

sgugger commented May 25, 2023

On CPU a simple workaround is to set model.train() (which you would need to do for real training anyway 😅); the bug comes from the gradients of the state. I'll try to dig in more, but it doesn't sound super urgent.

For GPU the fix should be in a PR later today/tomorrow morning.
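
A minimal sketch of that CPU workaround applied to the original reproduction:

import torch
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
model.train()  # training mode sidesteps the state-gradient bug on CPU

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()  # succeeds with model.train() set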

@sgugger
Collaborator

sgugger commented May 26, 2023

GPU fix was merged in #23774

@LetianLee
Author

Thanks! I have verified that it works, and fine-tuning also functions properly now that this issue has been fixed.
