Skip to content

Commit

Permalink
fix scoping
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 3, 2024
1 parent dd1538e commit 24d20b2
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/accelerate/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,13 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
"""
if function is None:
return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
if reduce_batch_size_fn is None:
reduce_batch_size_fn = lambda: batch_size // 2

batch_size = starting_batch_size
if reduce_batch_size_fn is None:
def reduce_batch_size_fn():
nonlocal batch_size
batch_size = batch_size // 2
return batch_size

def decorator(*args, **kwargs):
nonlocal batch_size
Expand Down

0 comments on commit 24d20b2

Please sign in to comment.