From 802f9667a2ba914a80db9d7f0d0073add263fdf1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 19:48:22 -0500 Subject: [PATCH] improve vram use w gradient checkpointing (#1167) [skip ci] --- src/axolotl/utils/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index d6f144912..e75696918 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -159,6 +159,13 @@ def normalize_config(cfg): if isinstance(cfg.pretraining_dataset, dict): cfg.pretraining_dataset = [cfg.pretraining_dataset] + if ( + cfg.gradient_checkpointing + and cfg.unfrozen_parameters is None + and cfg.gradient_checkpointing_kwargs is None + ): + cfg.gradient_checkpointing_kwargs = {"use_reentrant": True} + log_gpu_memory_usage(LOG, "baseline", cfg.device)