From dd1538e5ee676f9f561d52b15e77bc5f9e68b427 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 3 Sep 2024 14:18:27 -0400
Subject: [PATCH] add support for custom function for reducing the batch size

---
 src/accelerate/utils/memory.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py
index baa5377f6a5..26117a86075 100644
--- a/src/accelerate/utils/memory.py
+++ b/src/accelerate/utils/memory.py
@@ -103,7 +103,7 @@ def should_reduce_batch_size(exception: Exception) -> bool:
     return False
 
 
-def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
+def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128, reduce_batch_size_fn: callable = None):
     """
     A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
     CUDNN, the batch size is cut in half and passed to `function`
@@ -132,6 +132,8 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
     """
     if function is None:
         return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
+    if reduce_batch_size_fn is None:
+        reduce_batch_size_fn = lambda: batch_size // 2
 
     batch_size = starting_batch_size
 
@@ -154,7 +156,7 @@ def decorator(*args, **kwargs):
             except Exception as e:
                 if should_reduce_batch_size(e):
                     clear_device_cache(garbage_collection=True)
-                    batch_size //= 2
+                    batch_size = reduce_batch_size_fn()
                 else:
                     raise
 
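
Notes (not part of the patch): below is a minimal usage sketch for the new keyword, assuming
the patch above is applied. The schedule values and the helper names `step_down` and `train`
are illustrative, not from the patch. Because the `functools.partial` branch is left unchanged,
the decorator-with-arguments form would not forward `reduce_batch_size_fn`, so the sketch wraps
the target function directly.

    from accelerate.utils import find_executable_batch_size

    # Hypothetical schedule: step down 128 -> 96 -> 64 -> 48 -> 32 instead of halving.
    _schedule = iter([96, 64, 48, 32])


    def step_down():
        # Return 0 once the schedule is exhausted so the decorator's
        # "reached zero" guard stops the retry loop instead of looping forever.
        return next(_schedule, 0)


    def train(batch_size):
        # ... run one training attempt; an out-of-memory error here triggers a retry ...
        print(f"trying batch_size={batch_size}")


    wrapped_train = find_executable_batch_size(train, starting_batch_size=128, reduce_batch_size_fn=step_down)
    wrapped_train()

The default behaviour is unchanged: when `reduce_batch_size_fn` is not supplied, the fallback
lambda closes over the enclosing `batch_size` and halves it, exactly as before the patch.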