From 24d20b244b4c6293b70add05b0752744e30a194d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 3 Sep 2024 14:26:46 -0400 Subject: [PATCH] fix scoping --- src/accelerate/utils/memory.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py index 26117a86075..b01a97390f6 100644 --- a/src/accelerate/utils/memory.py +++ b/src/accelerate/utils/memory.py @@ -132,10 +132,13 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i """ if function is None: return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size) - if reduce_batch_size_fn is None: - reduce_batch_size_fn = lambda: batch_size // 2 batch_size = starting_batch_size + if reduce_batch_size_fn is None: + def reduce_batch_size_fn(): + nonlocal batch_size + batch_size = batch_size // 2 + return batch_size def decorator(*args, **kwargs): nonlocal batch_size