[feature] sweeps #2171

Merged (8 commits) on Feb 2, 2025
Changes from 1 commit
[feature] sweeps
winglian committed Jan 31, 2025
commit c9014171f9f4e58b5ea4d66b045a3a01d602a2cd
126 changes: 97 additions & 29 deletions src/axolotl/cli/main.py
@@ -1,10 +1,16 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name

import logging
import subprocess # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional

import click
import yaml

import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -20,6 +26,33 @@
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


def generate_sweep_configs(base_config, sweeps_config):
    """
    Generates all possible configurations by applying sweeps to the base config.

    Args:
        base_config (dict): The original configuration dictionary.
        sweeps_config (dict): Dictionary where keys are top-level parameter names
            and values are lists of values to sweep over.

    Returns:
        list: List of all possible configuration dictionaries.
    """
    # Build the Cartesian product of all swept parameter values
    param_names = list(sweeps_config.keys())
    param_values = list(sweeps_config.values())
    all_combinations = list(product(*param_values))

    # Generate a new config for each combination
    result_configs = []
    for combination in all_combinations:
        new_config = deepcopy(base_config)
        for param_name, param_value in zip(param_names, combination):
            new_config[param_name] = param_value
        result_configs.append(new_config)

    return result_configs
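
A minimal sketch of how the expansion behaves, using hypothetical hyperparameter names and values (two learning rates times two batch sizes yields four configs):

    base = {"learning_rate": 1e-4, "micro_batch_size": 2}
    sweeps = {"learning_rate": [1e-4, 5e-5], "micro_batch_size": [2, 4]}
    configs = generate_sweep_configs(base, sweeps)
    assert len(configs) == 4
    assert configs[0] == {"learning_rate": 1e-4, "micro_batch_size": 2}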


@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -60,17 +93,23 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
"--sweep",
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
def train(config: str, accelerate: bool, cloud: Optional[str] = None, sweep: Optional[str] = None, **kwargs) -> None:
"""
Train or fine-tune a model.

Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
@@ -80,35 +119,64 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)

if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False

if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))

base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
    if sweep:
        # load the sweep configuration yaml file
        with open(sweep, "r", encoding="utf-8") as fin:
            sweep_config: dict[str, list] = yaml.safe_load(fin)
        with open(config, "r", encoding="utf-8") as fin:
            base_config: dict = yaml.safe_load(fin)

        # generate all possible configurations
        permutations = generate_sweep_configs(base_config, sweep_config)

        def iter_configs():
            for perm in permutations:
                # write each generated config into a temporary directory; the
                # directory lives until the consumer advances the generator
                with tempfile.TemporaryDirectory() as temp_dir:
                    with open(
                        Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
                    ) as fout:
                        yaml.dump(perm, fout)
                    yield str(Path(temp_dir) / "config.yaml")
    else:
        if cloud:
            do_cli_train(cloud_config=cloud, config=config, accelerate=False)
        else:
            from axolotl.cli.train import do_cli

            do_cli(config=config, **kwargs)

        def iter_configs():
            yield config

    if accelerate:
        # pop launcher-specific args once, before the loop, so they are not
        # lost after the first sweep iteration
        main_process_port = kwargs.pop("main_process_port", None)
        num_processes = kwargs.pop("num_processes", None)

    for cfg_file in iter_configs():
        # handle errors from subprocess so we can continue the rest of the sweeps
        try:
            if accelerate:
                if cloud:
                    do_cli_train(cloud_config=cloud, config=cfg_file, accelerate=True)
                else:
                    accelerate_args = []
                    if main_process_port is not None:
                        accelerate_args.append("--main_process_port")
                        accelerate_args.append(str(main_process_port))
                    if num_processes is not None:
                        accelerate_args.append("--num-processes")
                        accelerate_args.append(str(num_processes))

                    base_cmd = ["accelerate", "launch"]
                    base_cmd.extend(accelerate_args)
                    base_cmd.extend(["-m", "axolotl.cli.train"])
                    if cfg_file:
                        base_cmd.append(cfg_file)
                    cmd = build_command(base_cmd, kwargs)
                    subprocess.run(cmd, check=True)  # nosec B603
            else:
                if cloud:
                    do_cli_train(cloud_config=cloud, config=cfg_file, accelerate=False)
                else:
                    from axolotl.cli.train import do_cli

                    do_cli(config=cfg_file, **kwargs)
        except subprocess.CalledProcessError as exc:
            logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
            if not sweep:
                raise exc
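
End to end: the sweep file is a YAML mapping from top-level config keys to lists of candidate values, and the command runs one training job per combination. A hypothetical example (keys and values invented for illustration, and assuming the `axolotl` console entry point):

    # sweep.yaml
    learning_rate: [0.0001, 0.00005]
    micro_batch_size: [2, 4]

    # launches 4 runs, one per generated config
    axolotl train config.yaml --sweep sweep.yaml

Since subprocess failures are caught per run when sweeping, one failed combination does not abort the remaining runs.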


@cli.command()