Skip to content

Commit

Permalink
add support for custom function for reducing the batch size
Browse files — browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 3, 2024
1 parent b5235f2 commit dd1538e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/accelerate/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def should_reduce_batch_size(exception: Exception) -> bool:
return False


def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128, reduce_batch_size_fn: callable = None):
"""
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
CUDNN, the batch size is reduced — cut in half by default, or computed by `reduce_batch_size_fn` when one is provided — and passed to `function`
Expand Down Expand Up @@ -132,6 +132,8 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
"""
if function is None:
return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
if reduce_batch_size_fn is None:
reduce_batch_size_fn = lambda: batch_size // 2

batch_size = starting_batch_size

Expand All @@ -154,7 +156,7 @@ def decorator(*args, **kwargs):
except Exception as e:
if should_reduce_batch_size(e):
clear_device_cache(garbage_collection=True)
batch_size //= 2
batch_size = reduce_batch_size_fn()
else:
raise

Expand Down

0 comments on commit dd1538e

Please sign in to comment.