diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index ca7d037dd..c88d952c8 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -160,6 +160,13 @@ def normalize_config(cfg): if isinstance(cfg.pretraining_dataset, dict): cfg.pretraining_dataset = [cfg.pretraining_dataset] + if ( + cfg.gradient_checkpointing + and cfg.unfrozen_parameters is None + and cfg.gradient_checkpointing_kwargs is None + ): + cfg.gradient_checkpointing_kwargs = {"use_reentrant": True} + log_gpu_memory_usage(LOG, "baseline", cfg.device)