Out of memory error when running torchdynamo with model and custom backend #1955
I have the same issue on T5-3B (OOM because of the [...]). @williamwen42 any idea of a (dirty?) workaround if a clean fix takes time to come? It seems to be related to #1950.
If it is really #1950, I can give you a dirty workaround for it.
Correct me if I'm wrong, but I think the dirty fix has been applied in #1950. I tried testing it and am still running out of memory.
Yeah, then this is a different problem; we will need to investigate.
@TheExGenesis are you trying on Whisper? Where does it crash? On the latest nightlies, the issue seems to be elsewhere. @ezyang is it possible that the eager mode of dynamo has a higher memory footprint (even slightly, like from 10.4 GB to 10.6 GB of CUDA memory reserved) than "real" eager mode (i.e. without dynamo)? Also, is it possible that the garbage collector is not called with eager+dynamo as it would be in real eager mode (later, or never)?
Dynamo eager can use more memory, but we found in our benchmark suite that memory usage typically improved, because our min-cut graph partitioner can make better choices about what to save for backward. The other known and obvious culprits for memory usage are CUDA graphs (but these are turned off by default) and fake tensor falling back to real operations as a fallback for meta usage (but this is a very slight amount of extra memory usage, only as much as is necessary to allocate the inputs/outputs for a particular operation). @eellison, do we have an easy log level to test for the latter? I'm going to bump the priority to make sure we have someone look into this.
@ezyang we don't at the moment; I can add one. The one-off ops culprit was actually a red herring for other things (cudagraphs): when I landed the change for running ops in inductor with fake tensors instead of regular tensors, memory compression didn't decrease at all. I think it would be worth adding a debug mode that prints out the additional memory overhead for some of the following when it's significant. I think the remaining sources of memory overhead, in order of likelihood:
The issue happens at inference time with dynamo+eager mode (no CUDA graph, no Triton involved). Code to make it raise OOM is the following:

    import torch
    import torch._dynamo as torchdynamo
    from datasets import load_dataset
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    torch.cuda.memory._record_memory_history(True)
    torchdynamo.config.cache_size_limit = 512
    audio_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large").to("cuda")

    def optimize_model(original_model) -> None:
        original_model.forward2 = original_model.forward

        @torchdynamo.optimize("eager")
        def run(*args, **kwargs):
            return original_model.forward2(*args, **kwargs)

        original_model.forward = run

    optimize_model(model.model.decoder)
    processor = WhisperProcessor.from_pretrained("openai/whisper-large")
    speech_data = audio_dataset[0]["audio"]["array"]
    inputs = processor(speech_data, return_tensors="pt", sampling_rate=16_000).input_features.to("cuda")
    with torch.no_grad(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
        predicted_ids = model.generate(inputs, min_length=25, max_length=25, num_beams=5, do_sample=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
    assert (
        transcription == "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
    ), transcription
    print(transcription)
    print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
    print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
    print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))

1/ It works without dynamo (memory reserved < 12 GB), i.e. if you comment [...] The error is due to this line in the Whisper model: [...] This function is called by the beam decoder. I know it because:
Without dynamo, it prints:
Moreover, the new PyTorch memory profiler ([...])
I am under the impression that with TorchDynamo the garbage collector can't delete these tensors, so the CUDA memory can't be freed. The tensors of this function are output by the model (in the cache of the transformer model) and then reused as input to generate the next token. One possible issue is that, for some reason I don't know, references to those tensors are captured by dynamo and they can no longer be garbage collected. Does that make sense to you?
Not related, but still sharing: the minifier doesn't seem to catch those OOM issues. It's at least the second time it has failed for me (it works for simpler cases).
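The hypothesis in the comment above (a captured reference keeps the tensors alive) can be illustrated in plain Python. This is only a sketch of the general mechanism, not of what dynamo actually does: `DummyTensor`, `hidden_cache`, and `is_freed_after_drop` are hypothetical names, and no real CUDA memory is involved.

```python
import gc
import weakref


class DummyTensor:
    """Hypothetical stand-in for a CUDA tensor; holds no real memory."""


def is_freed_after_drop(keep_hidden_reference):
    """Return True if the object is freed once the visible reference is dropped."""
    hidden_cache = []            # plays the role of a cache we do not control
    tensor = DummyTensor()
    probe = weakref.ref(tensor)  # a weak reference does not keep the object alive
    if keep_hidden_reference:
        hidden_cache.append(tensor)  # a second, "captured" strong reference
    del tensor                   # the caller drops its only visible reference
    gc.collect()
    return probe() is None       # None means the object was really freed


print(is_freed_after_drop(keep_hidden_reference=False))
print(is_freed_after_drop(keep_hidden_reference=True))
```

The same `weakref` probe can be attached to a real tensor to check whether something else still holds it after the user code has released it.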
I'm not getting OOM anymore on an 80GB A100, but I am hitting the cache limit and getting no speed improvement (strictly 0.99x of baseline). Cache limit warnings:
You can increase the cache limit by modifying the Dynamo config, as in the code posted just above: torchdynamo.config.cache_size_limit = 512. Also, can you share your CUDA memory footprint after running the model?
With "eager", I can't raise the cache_size_limit above 64 without getting OOM. With "ofi", even at cache_size_limit=64, I'm getting OOM, plus a bunch of these warnings that I haven't had time to research:
@gaetansnl You are using a no-op compiler which doesn't free the inputs to the backward when they are no longer needed. This will incur significant memory overhead. Could you try the default inductor backend?
@TheExGenesis if you are seeing issues different from this one, please open a new issue, thank you.
Removing high priority because this is using a non-standard backend which doesn't free inputs, so a memory regression is expected.
@eellison do you have more details on what needs to be implemented in the backend? I can't use inductor because I have a custom backend implementation.
The problem is detailed in pytorch/pytorch#83137. To fix it for your backend, you want to return a compiled function that takes in a list of tensors, marking it with _boxed_call = True, and you also want to make sure the list is cleared and the inputs are freed when they are no longer needed. pytorch/pytorch#83137 is a good example of a PR to follow. CC @SherlockNoMad: for custom backends this might be a good thing to document if it's not already.
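A minimal plain-Python sketch of the boxed calling convention described above, under stated assumptions: `Buffer`, `make_boxed`, `compiled_body`, and `run_demo` are hypothetical names, and `Buffer` only stands in for a tensor. Only the `_boxed_call = True` marker comes from the comment; how a real backend wires this up should be taken from pytorch/pytorch#83137, not from this sketch.

```python
import gc
import weakref


class Buffer:
    """Hypothetical stand-in for a large tensor (plain Python object)."""

    def __init__(self, value):
        self.value = value


def make_boxed(fn):
    """Wrap fn so it takes ONE list of inputs and empties the caller's list."""

    def boxed(args):
        inputs = list(args)
        args.clear()  # the caller's list no longer keeps the inputs alive
        return fn(inputs)

    boxed._boxed_call = True  # marker attribute named in the comment above
    return boxed


def compiled_body(inputs):
    # Consume the inputs one at a time, dropping each as soon as it is used.
    total = 0
    while inputs:
        total += inputs.pop().value
    return total


def run_demo():
    a, b = Buffer(1), Buffer(2)
    probe = weakref.ref(a)  # lets us observe whether `a` was freed
    args = [a, b]
    del a, b                # from here on, only the list holds the inputs
    result = make_boxed(compiled_body)(args)
    gc.collect()
    return result, len(args), probe() is None


print(run_demo())
```

With a positional call such as fn(*args), the caller keeps every input alive until the call returns; passing a single list that the callee clears is what lets inputs be released early.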
I also have out of memory with inductor
I just reran the code with the eager compiler + today's nightly... and no more OOM! The inductor compiler raises OOM, but in CUDA graphs. That's not surprising: CUDA graphs copy input tensors, and this model has a huge encoder output (it appears in the cache); if it is duplicated for each sequence length of the decoder, it's not surprising that it OOMs.
@gaetansnl can we close the issue?
thanks a lot everyone!
🐛 Describe the bug
Hello ! I have an out of memory error when I try to run Whisper through torchdynamo.
openai/whisper-medium
num_beams=1
optimize_model(model.model.decoder)
it works
When I set use_cache to false in generate, it segfaults instead of raising OOM. And I don't think the minifier is working for this case.
PyTorch: 1.14.0.dev20221130+cu117 (nightly)
Minimal reproduction
Error logs
Minified repro
No response