ZeroShotClassificationPipeline has large memory spikes when using a lot of candidate_labels #24873
Comments
Hi @rsmith49, thank you for opening this issue 🤗. I will take a look!
I have been running in a Jupyter notebook, which I think does save the results from calling the pipeline, since it is the final statement in the cell. Let me try in a regular python process and see if the memory spikes the same. I should note, though, that the "1500 inference calls" I mentioned are only over 20 documents, since there are 130 candidate labels.
Correct, the result from the pipeline does not contain the past_key_values tensors.
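For context (not part of the original thread): the zero-shot pipeline's return value contains only plain Python strings and floats, the per-label scores, so keeping the result around in a notebook cell does not by itself hold on to any model tensors. A minimal sketch of its shape, using a made-up sentence and labels:

from transformers import pipeline

clf = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = clf(
    "The baby monitor stopped charging after a week.",
    candidate_labels=["electronics", "furniture"],
)
# `result` is a plain dict, roughly:
# {"sequence": "The baby monitor stopped charging after a week.",
#  "labels": ["electronics", "furniture"],   # sorted by score
#  "scores": [0.97, 0.03]}                   # illustrative numbers only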
Hi, I am not able to reproduce with the following (slightly modified) script (see at the end), running in python directly:

iteration: 0
RAM: 4318.6015625 MB
timing: 18.248116 sec.
==============
iteration: 156
RAM: 4319.5 MB
timing: 18.464201 sec

It would be great if you could check whether the issue also happens when running a plain python script. However, this is a sequence classification model, and the decoder cache (past_key_values) should not be needed for this task.

from transformers import pipeline
tmp_repro_data = ['I purchased this to replace my 7 yr old video baby monitor that had been dropped too many times.'] * 20

ckpt = 'facebook/bart-large-mnli'
# ckpt = 'facebook/bart-base'

p = pipeline(
    'zero-shot-classification',
    model=ckpt,
    device="cuda",
    batch_size=20,
)

# import pdb; pdb.set_trace()  # uncomment to inspect the pipeline before the loop

def _revised_forward(self, inputs):
    candidate_label = inputs["candidate_label"]
    sequence = inputs["sequence"]
    model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}  # type: ignore
    # The only difference between `_revised_forward` and the `_forward` in the transformers
    # repo is that we pass `use_cache=False` as an extra arg to inference with `self.model`.
    outputs = self.model(**model_inputs, use_cache=False)
    model_outputs = {
        "candidate_label": candidate_label,
        "sequence": sequence,
        "is_last": inputs["is_last"],
        **outputs,
    }
    return model_outputs

import psutil
import os
import datetime

process = psutil.Process(os.getpid())

# With the next line enabled it works as expected; without it, memory spikes.
# p._forward = _revised_forward.__get__(p)

for i in range(1000):
    s = datetime.datetime.now()
    o = p(
        tmp_repro_data,
        multi_label=True,
        candidate_labels=list(range(130)),
    )
    e = datetime.datetime.now()
    d = (e - s).total_seconds()
    mem = process.memory_info()[0] / float(2 ** 20)  # resident set size in MiB
    print(i)
    print(mem)
    print(d)
    print("=" * 80)
Thanks for looking into this! Weirdly, I also did not see memory spikes when using a single text snippet copied 20 times, only when using 20 unique strings (I'm guessing something to do with caching somewhere in python, torch, or transformers makes garbage collection more effective). So if you could try using the example list I posted above, that may do it. I haven't had a chance to run the script in a pure python process yet, but I will let you know when I do!
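The original 20-review example list is not reproduced in this extract; as a hypothetical, untested stand-in, any 20 distinct sentences along these lines could be swapped into tmp_repro_data to exercise the unique-string behaviour described above:

tmp_repro_data = [
    f"Review {i}: I purchased this to replace my {i}-year-old video baby monitor."
    for i in range(20)
]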
That would be nice to know! (I am opening a PR soon anyway :-) )
Ran the script as just a plain python process and still saw the memory spike.
You used the 20 different text sentences in your run? (I am running with the same text repeated 20 times, with the latest version.)
Yes, not sure why repeating the same text doesn't trigger it, but I get the same result as you when using repeated text.
System Info
Who can help?
@Narsil
Information
Tasks
Reproduction
Repro Steps Cell
Expected behavior
Letting this script run as-is causes memory (CPU memory, not GPU memory) to spike to over 10 GiB at around 1500 inference calls. This can break a lot of environments, especially anything involving running jobs on resource-constrained machines.
After some debugging, we traced this to the past_key_values object returned by the Bart model, which was a tuple of very large tensors. We suspect these large tensors keep garbage collection from catching up when all of these model inference outputs are stored in a single list. Passing use_cache=False to model inference (and therefore not returning the past_key_values object) fixes the memory spikes, which makes us think this was indeed the issue.