
Memory usage issues with TIMM classifier #609

@aweaver1fandm

Description

Search before asking

  • I have searched the Pytorch-Wildlife issues and found no similar bug report.

Bug

I found this when trying to use the DeepFaune New England (DFNE) classifier for species classification. As the code is currently written, running the DFNE classifier (and presumably the DeepFaune classifier as well) causes memory usage to balloon. In a batch classification of approximately 150 images, memory usage (RAM or GPU RAM, it doesn't matter which) kept growing throughout the run. When I tried classifying roughly 3000 images, memory usage grew to over 300 GB before I was even halfway through the set.

I believe the issue is in PytorchWildlife/models/classification/timm_base/base_classifier.py, in the batch_image_classification function, specifically these lines of code:

with tqdm(total=len(dataloader)) as pbar:
    for batch in dataloader:
        imgs, paths = batch
        imgs = imgs.to(self.device)
        total_logits.append(self.predictor(imgs))
        total_paths.append(paths)
        pbar.update(1)

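I believe the issue is that the model is still tracking gradient information: every tensor appended to total_logits keeps a reference to its autograd graph, so the intermediate activations from every batch stay in memory. I modified the code as follows: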
with tqdm(total=len(dataloader)) as pbar:
    # Disable autograd so the stored logits don't keep each batch's graph alive
    with torch.no_grad():
        for batch in dataloader:
            imgs, paths = batch
            imgs = imgs.to(self.device)
            total_logits.append(self.predictor(imgs))
            total_paths.append(paths)
            pbar.update(1)

With this change, memory usage stayed normal and I was able to run the classifier without issue.
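A minimal, self-contained sketch of the same fix, assuming only PyTorch. The function `classify_batches` is hypothetical, and a small `nn.Linear` plus a `TensorDataset` stand in for the DFNE model and image loader; this is not the library's actual code. It shows that an output computed with autograd enabled carries a `grad_fn`, while outputs collected under `torch.no_grad()` do not.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def classify_batches(model, dataloader, device):
    """Hypothetical stand-in for the loop in batch_image_classification."""
    # eval() switches dropout/batchnorm to inference behavior,
    # but it does NOT turn off gradient tracking on its own.
    model.eval()
    total_logits, total_paths = [], []
    # no_grad stops autograd from recording operations, so each stored
    # output does not keep its computation graph (and activations) alive.
    with torch.no_grad():
        for imgs, paths in dataloader:
            imgs = imgs.to(device)
            total_logits.append(model(imgs))
            total_paths.append(paths)
    return torch.cat(total_logits), total_paths


# Stand-ins (assumptions): a tiny linear model and integer "paths".
model = nn.Linear(8, 3)
data = TensorDataset(torch.randn(10, 8), torch.arange(10))
loader = DataLoader(data, batch_size=4)

# Without no_grad, the output references a graph through grad_fn.
leaky = model(torch.randn(4, 8))
print(leaky.grad_fn is not None)  # True

logits, paths = classify_batches(model, loader, torch.device("cpu"))
print(logits.shape, logits.requires_grad)  # torch.Size([10, 3]) False
```

`torch.inference_mode()` is a stricter alternative to `torch.no_grad()` for pure inference; either one prevents the graph from being retained across batches.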

Environment

No response

Minimal Reproducible Example

No response

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    Status

    Backlog

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
