Description
Search before asking
- I have searched the Pytorch-Wildlife issues and found no similar bug report.
Bug
I found this when trying to use the DeepFaune New England (DFNE) classifier for species classification. As the code is currently written, running the DFNE classifier (and presumably the DeepFaune classifier as well) causes memory usage to balloon. In a batch classification of approximately 150 images, memory usage kept growing for the whole run, whether the model ran on the CPU (system RAM) or on the GPU (GPU RAM). When I tried classifying roughly 3000 images, memory usage ballooned to over 300 GB before I got halfway through the set.
I believe the issue is in
PytorchWildlife/models/classification/timm_base/base_classifier.py
in the batch_image_classification function, specifically these lines of code:
with tqdm(total=len(dataloader)) as pbar:
    for batch in dataloader:
        imgs, paths = batch
        imgs = imgs.to(self.device)
        total_logits.append(self.predictor(imgs))
        total_paths.append(paths)
        pbar.update(1)
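I believe the issue is that the model is still tracking gradient information. Because the forward pass runs with autograd enabled, each output tensor appended to total_logits keeps a reference to its computation graph, so the memory held by every batch is never released. When I modified the code as follows: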
with tqdm(total=len(dataloader)) as pbar:
    with torch.no_grad():
        for batch in dataloader:
            imgs, paths = batch
            imgs = imgs.to(self.device)
            total_logits.append(self.predictor(imgs))
            total_paths.append(paths)
            pbar.update(1)
memory usage was normal and I was able to run the classifier without issue.
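To illustrate the suspected mechanism outside of PytorchWildlife, here is a minimal sketch. The tiny `predictor` model and the random batches are hypothetical stand-ins, not the library's actual classifier or dataloader. It only shows that outputs collected without `torch.no_grad()` stay attached to an autograd graph, while outputs collected inside it do not:

```python
import torch
from torch import nn

# Hypothetical stand-in for self.predictor; NOT the PytorchWildlife model.
predictor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
predictor.eval()

# Four fake batches of two 3x8x8 "images" each (assumed shapes, for illustration).
batches = [torch.randn(2, 3, 8, 8) for _ in range(4)]

# Without no_grad: every stored output keeps a reference to its autograd graph.
tracked = [predictor(imgs) for imgs in batches]

# With no_grad: outputs are plain tensors with no graph attached.
with torch.no_grad():
    untracked = [predictor(imgs) for imgs in batches]

tracked_has_graph = all(t.grad_fn is not None for t in tracked)
untracked_has_graph = any(t.grad_fn is not None for t in untracked)
print(tracked_has_graph, untracked_has_graph)

# The collected logits can still be concatenated as usual.
all_logits = torch.cat(untracked)
```

In a real model the retained graph includes intermediate activations for every layer, which would explain memory growing with each batch. `torch.inference_mode()` is another option that may work here, though I have only tested `torch.no_grad()` in the classifier itself.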
Environment
No response
Minimal Reproducible Example
No response
Additional
No response
Are you willing to submit a PR?
- Yes I'd like to help by submitting a PR!