Skip to content

Commit 0e498cf

Browse files
committed
fix: Formatting/linter issues
1 parent 9016205 commit 0e498cf

File tree

8 files changed

+222
-103
lines changed

8 files changed

+222
-103
lines changed

examples/hn_title_generator/reference_grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Tuple
44

55
import numpy as np
6-
import wandb
76
from datasets import Dataset
87
from dotenv import load_dotenv
98
from transformers import PreTrainedTokenizer
@@ -18,6 +17,7 @@
1817
)
1918
from vllm import SamplingParams
2019

20+
import wandb
2121
from art.utils import limit_concurrency
2222

2323
load_dotenv()

src/art/cli.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,21 @@
2929

3030
@app.command()
3131
def migrate(
32-
path: Path = typer.Argument(..., help="Path to model dir, project dir, or trajectories dir"),
33-
dry_run: bool = typer.Option(False, "--dry-run", "-n", help="Show what would be migrated without making changes"),
34-
keep_jsonl: bool = typer.Option(False, "--keep-jsonl", help="Keep original JSONL files after conversion"),
35-
verbose: bool = typer.Option(False, "--verbose", "-v", help="Print progress for each file"),
32+
path: Path = typer.Argument(
33+
..., help="Path to model dir, project dir, or trajectories dir"
34+
),
35+
dry_run: bool = typer.Option(
36+
False,
37+
"--dry-run",
38+
"-n",
39+
help="Show what would be migrated without making changes",
40+
),
41+
keep_jsonl: bool = typer.Option(
42+
False, "--keep-jsonl", help="Keep original JSONL files after conversion"
43+
),
44+
verbose: bool = typer.Option(
45+
False, "--verbose", "-v", help="Print progress for each file"
46+
),
3647
) -> None:
3748
"""
3849
Migrate trajectory files from JSONL to Parquet format.
@@ -88,24 +99,33 @@ def migrate(
8899
model_dir,
89100
delete_originals=not keep_jsonl,
90101
dry_run=dry_run,
91-
progress_callback=lambda f: typer.echo(f" {f}") if verbose else None,
102+
progress_callback=lambda f: typer.echo(f" {f}")
103+
if verbose
104+
else None,
92105
)
93106
result = result + model_result
94107
else:
95-
typer.echo(f"Error: Could not determine path type. Expected a model, project, or trajectories directory.", err=True)
108+
typer.echo(
109+
f"Error: Could not determine path type. Expected a model, project, or trajectories directory.",
110+
err=True,
111+
)
96112
raise typer.Exit(1)
97113

98114
# Print summary
99115
if dry_run:
100116
typer.echo(f"\n[DRY RUN] Would migrate {result.files_migrated} files")
101117
if result.bytes_before > 0:
102-
typer.echo(f" Estimated space savings: {result.space_saved / 1024 / 1024:.1f} MB")
118+
typer.echo(
119+
f" Estimated space savings: {result.space_saved / 1024 / 1024:.1f} MB"
120+
)
103121
else:
104122
typer.echo(f"\nMigrated {result.files_migrated} files")
105123
if result.files_skipped > 0:
106124
typer.echo(f"Skipped {result.files_skipped} files")
107125
if result.bytes_before > 0 and result.bytes_after > 0:
108-
typer.echo(f"Space saved: {result.space_saved / 1024 / 1024:.1f} MB ({result.compression_ratio:.1f}x compression)")
126+
typer.echo(
127+
f"Space saved: {result.space_saved / 1024 / 1024:.1f} MB ({result.compression_ratio:.1f}x compression)"
128+
)
109129

110130
if result.errors:
111131
typer.echo(f"\nErrors ({len(result.errors)}):", err=True)

src/art/local/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,16 @@
1212
import numpy as np
1313
import polars as pl
1414
import torch
15-
import wandb
1615
import weave
1716
from openai import AsyncOpenAI
1817
from tqdm import auto as tqdm
1918
from transformers import AutoImageProcessor, AutoTokenizer
2019
from transformers.image_processing_utils import BaseImageProcessor
2120
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
2221
from typing_extensions import Self
23-
from wandb.sdk.wandb_run import Run
2422
from weave.trace.weave_client import WeaveClient
2523

24+
import wandb
2625
from art.utils.deployment import (
2726
DeploymentResult,
2827
Provider,
@@ -45,6 +44,7 @@
4544
)
4645
from art.utils.trajectory_logging import write_trajectory_groups_parquet
4746
from mp_actors import close_proxy, move_to_child_process
47+
from wandb.sdk.wandb_run import Run
4848

4949
from .. import dev
5050
from ..backend import Backend

src/art/utils/benchmarking/load_trajectories.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,13 @@ async def load_trajectories(
105105
return pl.DataFrame()
106106

107107
# Collect all parquet files
108-
all_parquet_files: list[tuple[str, str, str, int]] = [] # (path, model, split, step)
108+
all_parquet_files: list[
109+
tuple[str, str, str, int]
110+
] = [] # (path, model, split, step)
109111

110-
for model_dir in tqdm(model_dirs, desc="Scanning models", unit="model", disable=not debug):
112+
for model_dir in tqdm(
113+
model_dirs, desc="Scanning models", unit="model", disable=not debug
114+
):
111115
model_name = model_dir.name
112116
traj_root = Path(get_trajectories_dir(str(model_dir)))
113117

@@ -121,12 +125,14 @@ async def load_trajectories(
121125
for trajectory_path in split_dir.glob("*.parquet"):
122126
try:
123127
step = int(trajectory_path.stem)
124-
all_parquet_files.append((
125-
str(trajectory_path),
126-
model_name,
127-
split_dir.name,
128-
step,
129-
))
128+
all_parquet_files.append(
129+
(
130+
str(trajectory_path),
131+
model_name,
132+
split_dir.name,
133+
step,
134+
)
135+
)
130136
except ValueError:
131137
continue
132138

@@ -220,7 +226,8 @@ async def load_trajectories(
220226
"content": msg[1],
221227
"tool_calls": msg[2],
222228
"tool_call_id": msg[3],
223-
"trainable": msg[4] is not None, # finish_reason present = trainable
229+
"trainable": msg[4]
230+
is not None, # finish_reason present = trainable
224231
}
225232

226233
# Build processed message

src/art/utils/benchmarking/log_constant_metrics_wandb.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Utilities for logging constant baseline metrics to Weights & Biases."""
22

3-
import wandb
4-
53
import art
4+
import wandb
65

76

87
async def log_constant_metrics_wandb(

src/art/utils/trajectory_logging.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -82,38 +82,50 @@ def write_trajectory_groups_parquet(
8282
msg = {
8383
"finish_reason": msg.finish_reason,
8484
"index": msg.index,
85-
"message": msg.message.to_dict() if hasattr(msg.message, "to_dict") else msg.message,
85+
"message": msg.message.to_dict()
86+
if hasattr(msg.message, "to_dict")
87+
else msg.message,
8688
}
8789
messages.append(_flatten_message(msg))
8890

89-
rows.append({
90-
"group_index": group_index,
91-
"reward": trajectory.reward,
92-
"metrics": json.dumps(trajectory.metrics) if trajectory.metrics else None,
93-
"metadata": json.dumps(trajectory.metadata) if trajectory.metadata else None,
94-
"tools": json.dumps(trajectory.tools) if trajectory.tools else None,
95-
"logs": trajectory.logs if trajectory.logs else None,
96-
"messages": messages,
97-
})
91+
rows.append(
92+
{
93+
"group_index": group_index,
94+
"reward": trajectory.reward,
95+
"metrics": json.dumps(trajectory.metrics)
96+
if trajectory.metrics
97+
else None,
98+
"metadata": json.dumps(trajectory.metadata)
99+
if trajectory.metadata
100+
else None,
101+
"tools": json.dumps(trajectory.tools) if trajectory.tools else None,
102+
"logs": trajectory.logs if trajectory.logs else None,
103+
"messages": messages,
104+
}
105+
)
98106

99107
# Define schema
100-
message_type = pa.struct([
101-
("role", pa.string()),
102-
("content", pa.string()),
103-
("tool_calls", pa.string()),
104-
("tool_call_id", pa.string()),
105-
("trainable", pa.bool_()),
106-
])
107-
108-
schema = pa.schema([
109-
("group_index", pa.int64()),
110-
("reward", pa.float64()),
111-
("metrics", pa.string()),
112-
("metadata", pa.string()),
113-
("tools", pa.string()),
114-
("logs", pa.list_(pa.string())),
115-
("messages", pa.list_(message_type)),
116-
])
108+
message_type = pa.struct(
109+
[
110+
("role", pa.string()),
111+
("content", pa.string()),
112+
("tool_calls", pa.string()),
113+
("tool_call_id", pa.string()),
114+
("trainable", pa.bool_()),
115+
]
116+
)
117+
118+
schema = pa.schema(
119+
[
120+
("group_index", pa.int64()),
121+
("reward", pa.float64()),
122+
("metrics", pa.string()),
123+
("metadata", pa.string()),
124+
("tools", pa.string()),
125+
("logs", pa.list_(pa.string())),
126+
("messages", pa.list_(message_type)),
127+
]
128+
)
117129

118130
if not rows:
119131
table = pa.table({name: [] for name in schema.names}, schema=schema)
@@ -168,7 +180,9 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
168180
messages_and_choices=messages_and_choices,
169181
reward=row_dict["reward"],
170182
metrics=json.loads(row_dict["metrics"]) if row_dict.get("metrics") else {},
171-
metadata=json.loads(row_dict["metadata"]) if row_dict.get("metadata") else {},
183+
metadata=json.loads(row_dict["metadata"])
184+
if row_dict.get("metadata")
185+
else {},
172186
logs=row_dict.get("logs") or [],
173187
)
174188

src/art/utils/trajectory_migration.py

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -235,53 +235,77 @@ def migrate_jsonl_to_parquet(
235235
if "finish_reason" in msg:
236236
# Choice format - extract inner message, mark as trainable
237237
inner = msg.get("message", {})
238-
messages.append({
239-
"role": inner.get("role"),
240-
"content": inner.get("content"),
241-
"tool_calls": json.dumps(inner.get("tool_calls")) if inner.get("tool_calls") else None,
242-
"tool_call_id": None,
243-
"trainable": True,
244-
})
238+
messages.append(
239+
{
240+
"role": inner.get("role"),
241+
"content": inner.get("content"),
242+
"tool_calls": json.dumps(inner.get("tool_calls"))
243+
if inner.get("tool_calls")
244+
else None,
245+
"tool_call_id": None,
246+
"trainable": True,
247+
}
248+
)
245249
else:
246250
# Regular message
247-
messages.append({
248-
"role": msg.get("role"),
249-
"content": msg.get("content"),
250-
"tool_calls": json.dumps(msg.get("tool_calls")) if msg.get("tool_calls") else None,
251-
"tool_call_id": msg.get("tool_call_id"),
252-
"trainable": False,
253-
})
254-
255-
rows.append({
256-
"group_index": group_index,
257-
"reward": traj.get("reward"),
258-
"metrics": json.dumps(traj.get("metrics")) if traj.get("metrics") else None,
259-
"metadata": json.dumps(traj.get("metadata")) if traj.get("metadata") else None,
260-
"tools": json.dumps(traj.get("tools")) if traj.get("tools") else None,
261-
"logs": traj.get("logs"),
262-
"additional_histories": json.dumps(traj.get("additional_histories")) if traj.get("additional_histories") else None,
263-
"messages": messages,
264-
})
251+
messages.append(
252+
{
253+
"role": msg.get("role"),
254+
"content": msg.get("content"),
255+
"tool_calls": json.dumps(msg.get("tool_calls"))
256+
if msg.get("tool_calls")
257+
else None,
258+
"tool_call_id": msg.get("tool_call_id"),
259+
"trainable": False,
260+
}
261+
)
262+
263+
rows.append(
264+
{
265+
"group_index": group_index,
266+
"reward": traj.get("reward"),
267+
"metrics": json.dumps(traj.get("metrics"))
268+
if traj.get("metrics")
269+
else None,
270+
"metadata": json.dumps(traj.get("metadata"))
271+
if traj.get("metadata")
272+
else None,
273+
"tools": json.dumps(traj.get("tools"))
274+
if traj.get("tools")
275+
else None,
276+
"logs": traj.get("logs"),
277+
"additional_histories": json.dumps(
278+
traj.get("additional_histories")
279+
)
280+
if traj.get("additional_histories")
281+
else None,
282+
"messages": messages,
283+
}
284+
)
265285

266286
# Define the message struct schema
267-
message_type = pa.struct([
268-
("role", pa.string()),
269-
("content", pa.string()),
270-
("tool_calls", pa.string()),
271-
("tool_call_id", pa.string()),
272-
("trainable", pa.bool_()),
273-
])
274-
275-
schema = pa.schema([
276-
("group_index", pa.int64()),
277-
("reward", pa.float64()),
278-
("metrics", pa.string()),
279-
("metadata", pa.string()),
280-
("tools", pa.string()),
281-
("logs", pa.list_(pa.string())),
282-
("additional_histories", pa.string()),
283-
("messages", pa.list_(message_type)),
284-
])
287+
message_type = pa.struct(
288+
[
289+
("role", pa.string()),
290+
("content", pa.string()),
291+
("tool_calls", pa.string()),
292+
("tool_call_id", pa.string()),
293+
("trainable", pa.bool_()),
294+
]
295+
)
296+
297+
schema = pa.schema(
298+
[
299+
("group_index", pa.int64()),
300+
("reward", pa.float64()),
301+
("metrics", pa.string()),
302+
("metadata", pa.string()),
303+
("tools", pa.string()),
304+
("logs", pa.list_(pa.string())),
305+
("additional_histories", pa.string()),
306+
("messages", pa.list_(message_type)),
307+
]
308+
)
285309

286310
# Handle empty case
287311
if not rows:

0 commit comments

Comments
 (0)