Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions examples/post_training/supervised_fine_tune_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import asyncio
from typing import Optional

import fire
from llama_stack_client import LlamaStackClient

from llama_stack_client.types.post_training_supervised_fine_tune_params import (
AlgorithmConfigLoraFinetuningConfig,
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigEfficiencyConfig,
TrainingConfigOptimizerConfig,
)


async def run_main(
    host: str,
    port: int,
    job_uuid: str,
    model: str,
    use_https: bool = False,
    checkpoint_dir: Optional[str] = None,
    cert_path: Optional[str] = None,
):
    """Kick off a LoRA supervised fine-tune job against a llama-stack server.

    Args:
        host: Server hostname or IP address.
        port: Server port.
        job_uuid: Unique identifier for the training job.
        model: Model ID to fine-tune.
        use_https: Use HTTPS instead of HTTP when building the base URL.
        checkpoint_dir: Optional directory of model checkpoints to resume from.
        cert_path: Optional path to an SSL certificate bundle; only used
            together with ``use_https``.
    """
    # Construct the base URL with the appropriate protocol
    protocol = "https" if use_https else "http"
    base_url = f"{protocol}://{host}:{port}"

    # Configure client with SSL certificate verification if provided
    client_kwargs = {"base_url": base_url}
    if use_https and cert_path:
        client_kwargs["verify"] = cert_path

    client = LlamaStackClient(**client_kwargs)

    algorithm_config = AlgorithmConfigLoraFinetuningConfig(
        type="LoRA",
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    )

    data_config = TrainingConfigDataConfig(
        dataset_id="alpaca",
        validation_dataset_id="alpaca",
        batch_size=1,
        shuffle=False,
    )

    optimizer_config = TrainingConfigOptimizerConfig(
        optimizer_type="adamw",
        lr=3e-4,
        weight_decay=0.1,
        num_warmup_steps=100,
    )

    # NOTE: was previously misspelled "effiency_config"
    efficiency_config = TrainingConfigEfficiencyConfig(
        enable_activation_checkpointing=True,
    )

    training_config = TrainingConfig(
        n_epochs=1,
        data_config=data_config,
        efficiency_config=efficiency_config,
        optimizer_config=optimizer_config,
        max_steps_per_epoch=30,
        gradient_accumulation_steps=1,
    )

    training_job = client.post_training.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model,
        algorithm_config=algorithm_config,
        training_config=training_config,
        checkpoint_dir=checkpoint_dir,
        # logger_config and hyperparam_search_config haven't been used yet
        logger_config={},
        hyperparam_search_config={},
    )

    print(f"finished the training job: {training_job.job_uuid}")


def main(
    host: str,
    port: int,
    job_uuid: str,
    model: str,
    use_https: bool = False,
    checkpoint_dir: Optional[str] = None,
    cert_path: Optional[str] = None,
):
    """Synchronous CLI entry point; see run_main for parameter details.

    NOTE: the default for checkpoint_dir was previously the literal string
    "null", which would have been sent to the server as a real path when the
    flag was omitted; None matches run_main's own default.
    """
    # fire may coerce a numeric-looking UUID to an int; normalize to str.
    job_uuid = str(job_uuid)
    asyncio.run(
        run_main(host, port, job_uuid, model, use_https, checkpoint_dir, cert_path)
    )


if __name__ == "__main__":
    fire.Fire(main)
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/llama_stack_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .inference import inference
from .memory_banks import memory_banks
from .models import models
from .post_training import post_training
from .providers import providers
from .scoring_functions import scoring_functions
from .shields import shields
Expand Down Expand Up @@ -75,6 +76,7 @@ def cli(ctx, endpoint: str, config: str | None):
cli.add_command(scoring_functions, "scoring_functions")
cli.add_command(eval, "eval")
cli.add_command(inference, "inference")
cli.add_command(post_training, "post_training")


def main():
Expand Down
9 changes: 9 additions & 0 deletions src/llama_stack_client/lib/cli/post_training/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .post_training import post_training

__all__ = ["post_training"]
117 changes: 117 additions & 0 deletions src/llama_stack_client/lib/cli/post_training/post_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

import click

from llama_stack_client.types.post_training_supervised_fine_tune_params import (
AlgorithmConfig,
TrainingConfig,
)
from rich.console import Console

from ..common.utils import handle_client_errors


# Top-level `post_training` command group; subcommands are attached via
# add_command at the bottom of this module.
@click.group()
def post_training():
    """Query details about available post_training endpoints on distribution."""
    pass


@click.command("supervised_fine_tune")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.option("--model", required=True, help="Model ID")
@click.option("--algorithm-config", required=True, help="Algorithm Config (JSON)")
@click.option("--training-config", required=True, help="Training Config (JSON)")
@click.option(
    "--checkpoint-dir", required=False, help="Checkpoint Config", default=None
)
@click.pass_context
@handle_client_errors("post_training supervised_fine_tune")
def supervised_fine_tune(
    ctx,
    job_uuid: str,
    model: str,
    # click delivers option values as plain strings; the previous
    # AlgorithmConfig / TrainingConfig annotations were inaccurate and the raw
    # strings were forwarded to the API unparsed.
    algorithm_config: str,
    training_config: str,
    checkpoint_dir: Optional[str],
):
    """Kick off a supervised fine tune job"""
    import json  # local import: only needed to decode the CLI JSON payloads

    client = ctx.obj["client"]
    console = Console()

    post_training_job = client.post_training.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model,
        # Decode the JSON strings into the structured configs the API expects.
        algorithm_config=json.loads(algorithm_config),
        training_config=json.loads(training_config),
        checkpoint_dir=checkpoint_dir,
        # logger_config and hyperparam_search_config haven't been used yet
        logger_config={},
        hyperparam_search_config={},
    )
    # Echo the job id so callers can poll status/artifacts with it.
    console.print(post_training_job.job_uuid)


@click.command("list")
@click.pass_context
@handle_client_errors("post_training get_training_jobs")
def get_training_jobs(ctx):
    """Show the list of available post training jobs"""
    client = ctx.obj["client"]
    console = Console()

    # Print only the job UUIDs, one list entry per job.
    jobs = client.post_training.job.list()
    job_uuids = [job.job_uuid for job in jobs]
    console.print(job_uuids)


@click.command("status")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_status")
def get_training_job_status(ctx, job_uuid: str):
    """Show the status of a specific post training job"""
    client = ctx.obj["client"]
    console = Console()

    # NOTE: local variable was previously misspelled "job_status_reponse"
    job_status_response = client.post_training.job.status(job_uuid=job_uuid)
    console.print(job_status_response)


@click.command("artifacts")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_artifacts")
def get_training_job_artifacts(ctx, job_uuid: str):
    """Get the training artifacts of a specific post training job"""
    client = ctx.obj["client"]
    console = Console()

    # Fetch and display whatever artifact metadata the server returns.
    artifacts_response = client.post_training.job.artifacts(job_uuid=job_uuid)
    console.print(artifacts_response)


@click.command("cancel")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training cancel_training_job")
def cancel_training_job(ctx, job_uuid: str):
    """Cancel the training job"""
    # Fire-and-forget: the cancel endpoint returns nothing to display.
    client = ctx.obj["client"]
    client.post_training.job.cancel(job_uuid=job_uuid)


# Register subcommands on the `post_training` group so they appear as
# `post_training supervised_fine_tune|list|status|artifacts|cancel`.
post_training.add_command(supervised_fine_tune)
post_training.add_command(get_training_jobs)
post_training.add_command(get_training_job_status)
post_training.add_command(get_training_job_artifacts)
post_training.add_command(cancel_training_job)
Loading