Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion examples/art-e/all_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,26 @@
models["217"].config.include_qwen3_nothink = True

models["218"] = models["206"].model_copy(deep=True)
models["218"].name = "email-agent-218"
models["218"].name = "email-agent-218-5"
models["218"].base_model = "Qwen/Qwen3-32B"
models["218"].config.group_judge_model = "base_model"
models["218"].config.include_qwen3_nothink = True

# Model 219: like 008 but with custom internal config (low max_grad_norm) and high learning rate
models["219"] = models["008"].model_copy(deep=True)
models["219"].name = "email-agent-219"
models["219"].config.learning_rate = 1e-2
models["219"]._internal_config = art.dev.InternalModelConfig(
trainer_args=art.dev.TrainerArgs(
max_grad_norm=1e-7,
)
)

models["220"] = models["217"].model_copy(deep=True)
models["220"].name = "email-agent-220"
models["220"].base_model = "willcb/Qwen3-14B"

models["221"] = models["008"].model_copy(deep=True)
models["221"].name = "email-agent-221"
models["221"].config.include_qwen3_nothink = True
models["221"].base_model = "willcb/Qwen3-32B"
28 changes: 23 additions & 5 deletions examples/art-e/art_e/group_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ class GroupJudge:
def __init__(
self,
project: str,
judge_model: str | art.Model = "openai/o3",
judge_model: str | art.TrainableModel = "openai/o3",
judge_model_use_base_model: bool = False,
rubric: str = DEFAULT_RUBRIC,
):
self.project = project # store for later use
self.judge_model = judge_model
self.judge_model_use_base_model = judge_model_use_base_model
self.rubric = rubric

@weave.op()
Expand Down Expand Up @@ -129,13 +131,18 @@ async def judge(
completion_params = {}
if isinstance(self.judge_model, art.Model):
completion_params = self.judge_model.litellm_completion_params()
if self.judge_model_use_base_model:
# When using base_model, we still need the model's inference configuration
# (api_key, base_url) but we override the model name to use the base model
completion_params["model"] = (
f"hosted_vllm/{self.judge_model.base_model}"
)
else:
completion_params["model"] = self.judge_model

print("model is", self.judge_model)
print("judge completion_params", completion_params)
response = await acompletion(
# **completion_params,
model=self.judge_model,
**completion_params,
messages=messages,
response_format=GroupJudgeResponse,
caching=True,
Expand Down Expand Up @@ -218,9 +225,20 @@ async def main():
for m, t in zip(models, rollouts):
print(f" {m.name:10s}: {t.reward:.3f}")

art_model = art.TrainableModel[ProjectPolicyConfig](
name="willcb/Qwen3-32B",
base_model="willcb/Qwen3-32B",
# inference_model_name="hosted_vllm/willcb/Qwen3-32B",
project="email_agent",
config=ProjectPolicyConfig(),
)
art_model.inference_api_key = "default"
art_model.inference_base_url = "http://localhost:8000/v1"

judge = GroupJudge(
project="email_agent",
judge_model="openrouter/qwen/qwen3-32b",
judge_model=art_model,
judge_model_use_base_model=True,
# judge_model="openrouter/qwen/qwen3-14b",
# judge_model="openai/o3",
)
Expand Down
2 changes: 1 addition & 1 deletion examples/art-e/art_e/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ async def rollout(
)

if model.config.include_qwen3_nothink:
system_prompt += "\n/nothink"
system_prompt += "\n/no_think"

async def search_emails(keywords: list[str]) -> list[dict]:
"""
Expand Down
27 changes: 12 additions & 15 deletions examples/art-e/art_e/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@
async def train(model: art.TrainableModel[ProjectPolicyConfig]):
generate_database()

# Training config is now directly on the model config

if model.config.group_judge_model is not None:
judge_model = model.config.group_judge_model

if model.config.group_judge_model == "self":
judge_model = model
elif judge_model == "base_model":
judge_model = model.base_model

group_judge = GroupJudge(
project=model.project,
judge_model=judge_model,
)

with LocalBackend() as backend:
print(f"Pulling from S3 bucket: `{os.environ['BACKUP_BUCKET']}`")
await backend._experimental_pull_from_s3(
Expand Down Expand Up @@ -69,6 +54,18 @@ async def train(model: art.TrainableModel[ProjectPolicyConfig]):
initial_step=await model.get_step(),
)

if model.config.group_judge_model is not None:
judge_model = model.config.group_judge_model

if model.config.group_judge_model in ["self", "base_model"]:
judge_model = model

group_judge = GroupJudge(
project=model.project,
judge_model=judge_model,
judge_model_use_base_model=judge_model == "base_model",
)

for batch in train_iterator:
if batch.step % model.config.eval_steps == 0:
print(f"\n--- Evaluating at Iteration {batch.step} ---")
Expand Down
30 changes: 12 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,11 @@ version = "0.3.13"
description = "The OpenPipe Agent Reinforcement Training (ART) library"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"matplotlib>=3.10.1",
"openai>=1.65.5",
"seaborn>=0.13.2",
"setuptools>=78.1.0",
"wandb>=0.19.8",
"weave>=0.51.51",
"typer>=0.15.2",
"tblib>=3.0.0",
"litellm>=1.63.0",
"polars>=1.26.0",
"awscli>=1.38.1",
"panza",
"semver>=3.0.4",
"setproctitle>=1.3.6",
"accelerate==1.7.0"
]
dependencies = ["openai>=1.65.5", "typer>=0.15.2", "litellm>=1.63.0"]

[project.optional-dependencies]
plotting = ["matplotlib>=3.10.1", "seaborn>=0.13.2", "polars>=1.26.0"]

backend = [
"peft>=0.14.0",
"hf-xet>=1.1.0",
Expand All @@ -34,6 +20,14 @@ backend = [
"trl>=0.19.0",
"torch>=2.7.0",
"torchao>=0.9.0",
"accelerate==1.7.0",
"awscli>=1.38.1",
"setproctitle>=1.3.6",
"weave>=0.51.51",
"tblib>=3.0.0",
"semver>=3.0.4",
"setuptools>=78.1.0",
"wandb>=0.19.8",
]

[project.scripts]
Expand Down Expand Up @@ -62,7 +56,7 @@ dev-dependencies = [
"openpipe>=4.49.0",
"hatch>=1.14.1",
"ruff>=0.12.1",
"skypilot[cudo,do,fluidstack,gcp,lambda,paperspace,runpod]==0.8.0",
"skypilot[cudo,do,fluidstack,gcp,lambda,paperspace,runpod]==0.9.3",
]

[tool.uv.sources]
Expand Down
11 changes: 9 additions & 2 deletions src/art/local/pack.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import matplotlib.pyplot as plt
import os
import random
import seaborn as sns
import torch
from typing_extensions import TypedDict, Unpack

Expand Down Expand Up @@ -167,6 +165,15 @@ def packed_tensors_to_dir(tensors: PackedTensors, dir: str) -> DiskPackedTensors


def plot_packed_tensors(packed_tensors: PackedTensors) -> None:
try:
import matplotlib.pyplot as plt
import seaborn as sns
except ImportError:
raise ImportError(
"Plotting dependencies are not installed. Please install them with: "
"pip install openpipe-art[plotting]"
)

plt.figure(figsize=(15, 24))

for tensor, label, title, subplot_idx in (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import polars as pl
try:
import polars as pl
except ImportError:
raise ImportError(
"Plotting dependencies are not installed. Please install them with: "
"pip install openpipe-art[plotting]"
)

from ..types import BenchmarkModelKey
from ..filter_model_split import filter_rename_model_split

Expand Down Expand Up @@ -46,9 +53,15 @@ def percentage_comparison_bar_chart(
Jupyter notebooks.
"""

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
try:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
except ImportError:
raise ImportError(
"Plotting dependencies are not installed. Please install them with: "
"pip install openpipe-art[plotting]"
)

# create new copy of df
df = df.clone()
Expand Down
12 changes: 9 additions & 3 deletions src/art/utils/benchmarking/charts/training_progress_chart.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
try:
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
except ImportError:
raise ImportError(
"Plotting dependencies are not installed. Please install them with: "
"pip install openpipe-art[plotting]"
)

from art.utils.benchmarking.types import BenchmarkModelKey
from art.utils.benchmarking.filter_model_split import filter_rename_model_split
Expand Down
8 changes: 7 additions & 1 deletion src/art/utils/benchmarking/filter_model_split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import polars as pl
try:
import polars as pl
except ImportError:
raise ImportError(
"Plotting dependencies are not installed. Please install them with: "
"pip install openpipe-art[plotting]"
)

from art.utils.benchmarking.types import BenchmarkModelKey

Expand Down
11 changes: 7 additions & 4 deletions src/art/utils/benchmarking/load_trajectories.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import polars as pl
try:
import polars as pl
except ImportError:
raise ImportError(
"Plotting dependencies are not installed. Please install them with: "
"pip install openpipe-art[plotting]"
)
import yaml
from pathlib import Path
from panza import SQLiteCache
import os

from art.model import Model as ArtModel
Expand All @@ -15,10 +20,8 @@

cache_path = Path(get_repo_root_path()) / "data" / "cache.db"
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache = SQLiteCache(str(cache_path))


@cache.cache()
async def load_trajectories(
project_name: str,
models: list[str] | None = None,
Expand Down
9 changes: 8 additions & 1 deletion src/art/utils/old_benchmarking/generate_line_graphs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import os
from datetime import datetime
import matplotlib.pyplot as plt
from typing import Literal

try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError(
"Plotting dependencies are not installed. Please install them with: "
"pip install openpipe-art[plotting]"
)

from .load_benchmarked_models import load_benchmarked_models
from .types import BenchmarkedModelKey
from ..output_dirs import get_default_art_path
Expand Down
Loading