Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,30 @@ def add_wandb_arguments(parser):
parser.add_argument("--wandb-run-id", type=str, default=None)
return parser

# mlflow
def add_mlflow_arguments(parser):
    """Register the MLflow experiment-tracking options on *parser*.

    Args:
        parser: An ``argparse``-style parser (anything exposing ``add_argument``).

    Returns:
        The same *parser*, so the call can be chained like the other
        ``add_*_arguments`` helpers.
    """
    # ``store_true`` already defaults to False, so no explicit default is needed.
    parser.add_argument(
        "--use-mlflow",
        action="store_true",
        help="Enable logging to MLflow.",
    )
    parser.add_argument(
        "--mlflow-tracking-uri",
        type=str,
        default=None,
        help="MLflow tracking server URI. Defaults to MLFLOW_TRACKING_URI env var, or local mlruns/ directory.",
    )
    parser.add_argument(
        "--mlflow-experiment-name",
        type=str,
        default="miles",
        help="MLflow experiment name.",
    )
    parser.add_argument(
        "--mlflow-run-name",
        type=str,
        default=None,
        help="MLflow run name. Defaults to --wandb-group if not set.",
    )
    parser.add_argument(
        "--mlflow-run-id",
        type=str,
        default=None,
        help=(
            "ID of an existing MLflow run for secondary processes to attach to. "
            "Overwritten with the new run's ID when the primary process starts a run."
        ),
    )
    return parser

# tensorboard
def add_tensorboard_arguments(parser):
# tb_project_name, tb_experiment_name
Expand Down Expand Up @@ -1397,6 +1421,7 @@ def add_sglang_tp_size():
parser = add_eval_arguments(parser)
parser = add_algo_arguments(parser)
parser = add_wandb_arguments(parser)
parser = add_mlflow_arguments(parser)
parser = add_tensorboard_arguments(parser)
parser = add_router_arguments(parser)
parser = add_debug_arguments(parser)
Expand Down
1 change: 1 addition & 0 deletions miles/utils/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .base import TrackingBackend, TrackingManager, BACKEND_REGISTRY
137 changes: 137 additions & 0 deletions miles/utils/tracking/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Shared tracking interface for experiment logging backends.

Each backend implements ``init / log / finish``, and :class:`TrackingManager` fans out
calls to every active backend.

To add a new backend:
---------------------
1. Subclass :class:`TrackingBackend`.
2. Register it in :data:`BACKEND_REGISTRY`.
3. Add a corresponding ``--use-<name>`` CLI flag in ``arguments.py``.
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any

logger = logging.getLogger(__name__)



class TrackingBackend(ABC):
# Interface every logging backend must satisfy.

@abstractmethod
def init(self, args, *, primary: bool = True, **kwargs) -> None:
...

@abstractmethod
def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
...

@abstractmethod
def finish(self) -> None:
...


# Thin adapters for backwards compatibility to keep wandb_utils and tensorboard_utils untouched.
class WandbBackend(TrackingBackend):
    """Adapter that forwards to the existing ``wandb_utils`` helpers and to ``wandb`` itself."""

    def init(self, args, *, primary: bool = True, **kwargs) -> None:
        """Run the primary or the secondary wandb initialiser from ``wandb_utils``."""
        from . import wandb_utils

        if not primary:
            wandb_utils.init_wandb_secondary(args, **kwargs)
            return
        wandb_utils.init_wandb_primary(args, **kwargs)

    def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
        """Send *metrics* to wandb.

        ``step`` is accepted to satisfy the backend interface but is not
        passed on to ``wandb.log``.
        """
        import wandb as wandb_module

        wandb_module.log(metrics)

    def finish(self) -> None:
        """Call ``wandb.finish()``."""
        import wandb as wandb_module

        wandb_module.finish()


class TensorboardBackend(TrackingBackend):
    """Backend that writes through a ``_TensorboardAdapter`` created in :meth:`init`."""

    # Stays None until init() runs; log() and finish() do nothing while it is None.
    _adapter = None

    def init(self, args, *, primary: bool = True, **kwargs) -> None:
        """Create the adapter from *args*; ``primary`` and extra kwargs are not used."""
        from .tensorboard_utils import _TensorboardAdapter

        self._adapter = _TensorboardAdapter(args)

    def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
        """Forward *metrics* without their ``*/step`` entries, passing *step* separately."""
        adapter = self._adapter
        if adapter is None:
            return

        # Step-key entries (e.g. "train/step", "rollout/step") are dropped:
        # tensorboard receives the step as an explicit argument instead.
        payload = {}
        for name, value in metrics.items():
            if name.endswith("/step"):
                continue
            payload[name] = value
        adapter.log(data=payload, step=step)

    def finish(self) -> None:
        """Finish the adapter if one was created."""
        adapter = self._adapter
        if adapter is None:
            return
        adapter.finish()


class MlflowBackend(TrackingBackend):
    """Backend that hands every call over to the functions in ``mlflow_utils``.

    ``mlflow_utils`` is imported inside each method rather than when this
    module is imported.
    """

    def init(self, args, *, primary: bool = True, **kwargs) -> None:
        """Delegate to ``mlflow_utils.init_mlflow``."""
        from .mlflow_utils import init_mlflow

        init_mlflow(args, primary=primary, **kwargs)

    def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
        """Delegate to ``mlflow_utils.log_metrics``."""
        from .mlflow_utils import log_metrics

        log_metrics(metrics, step=step)

    def finish(self) -> None:
        """Delegate to ``mlflow_utils.finish``."""
        from .mlflow_utils import finish as finish_run

        finish_run()


# Registry that maps backend name → (class, args-flag attribute).
# TrackingManager.init walks this mapping in insertion order and instantiates
# an entry's class only when the named attribute on the parsed args is truthy.
BACKEND_REGISTRY: dict[str, tuple[type[TrackingBackend], str]] = {
    "wandb": (WandbBackend, "use_wandb"),
    "tensorboard": (TensorboardBackend, "use_tensorboard"),
    "mlflow": (MlflowBackend, "use_mlflow"),
}


class TrackingManager:
#Initialises and logs to every enabled backend; used internally by ``tracking_utils``.

def __init__(self) -> None:
self._backends: list[TrackingBackend] = []

def init(self, args, *, primary: bool = True, **kwargs) -> None:
for name, (cls, flag) in BACKEND_REGISTRY.items():
if getattr(args, flag, False):
logger.info("Initialising tracking backend: %s", name)
backend = cls()
backend.init(args, primary=primary, **kwargs)
self._backends.append(backend)

def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
for backend in self._backends:
backend.log(metrics, step=step)

def finish(self) -> None:
for backend in self._backends:
try:
backend.finish()
except Exception:
logger.exception(
"Error finishing tracking backend %s",
type(backend).__name__,
)
self._backends.clear()
135 changes: 135 additions & 0 deletions miles/utils/tracking/mlflow_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
MLflow tracking backend for miles.


MLflow docs for future reference:
- Tracking overview : https://mlflow.org/docs/latest/ml/tracking/
- Python API : https://mlflow.org/docs/latest/python_api/mlflow.html
- Remote tracking : https://mlflow.org/docs/latest/tracking/server.html
"""

from __future__ import annotations

import logging
import os
import re
from copy import deepcopy
from typing import Any

logger = logging.getLogger(__name__)


# Helpers/utils
def _sanitize_key(key: str) -> str:
return re.sub(r"[^a-zA-Z0-9_\-./\s]", "_", key)


def _compute_config_for_logging(args) -> dict[str, str]:
    """Build the flat parameter dict that is logged for a run.

    Deep-copies ``args.__dict__``, adds an ``env_vars`` entry holding those
    whitelisted environment variables that are currently set, and flattens
    the result with ``_flatten_dict``. Intended to mirror
    ``wandb_utils._compute_config_for_logging``.
    """
    allowed_env_names = ["SLURM_JOB_ID"]

    config = deepcopy(args.__dict__)
    config["env_vars"] = {
        name: value for name, value in os.environ.items() if name in allowed_env_names
    }
    return _flatten_dict(config)


def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict[str, str]:
# Recursively flatten nested dicts into ``dotted.key`` → ``str(value)`` pairs.
items: list[tuple[str, str]] = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(_flatten_dict(v, new_key, sep).items())
else:
items.append((new_key, str(v)))
return dict(items)


def init_mlflow(args, *, primary: bool = True, **kwargs) -> None:
    """Configure MLflow and start (primary) or join (secondary) a run.

    When ``args.use_mlflow`` is falsy, ``args.mlflow_run_id`` is reset to None
    and nothing else happens; mlflow is not even imported. Otherwise the
    tracking URI is taken from ``args.mlflow_tracking_uri`` or, failing that,
    the ``MLFLOW_TRACKING_URI`` environment variable, and is only set
    explicitly when one of them is non-empty. Extra ``kwargs`` are accepted
    and ignored.
    """
    if not args.use_mlflow:
        args.mlflow_run_id = None
        return

    import mlflow

    uri = args.mlflow_tracking_uri
    if not uri:
        uri = os.environ.get("MLFLOW_TRACKING_URI")
    if uri:
        mlflow.set_tracking_uri(uri)
        logger.info("MLflow tracking URI: %s", uri)

    experiment = args.mlflow_experiment_name
    mlflow.set_experiment(experiment)

    if not primary:
        _init_mlflow_secondary(args)
        return
    _init_mlflow_primary(args, experiment)


def _init_mlflow_primary(args, experiment_name: str) -> None:
    """Start a new MLflow run, log the flattened config, and store its ID on *args*.

    The run name is ``args.mlflow_run_name``, falling back to
    ``args.wandb_group``. The run is tagged with ``args.rank`` and, when the
    ``SLURM_JOB_ID`` environment variable is non-empty, with that job ID.
    ``experiment_name`` is only used in the log message.
    """
    import mlflow

    name = args.mlflow_run_name or args.wandb_group

    job_id = os.environ.get("SLURM_JOB_ID")
    run_tags = {"slurm_job_id": job_id} if job_id else {}
    run_tags["rank"] = str(args.rank)

    started = mlflow.start_run(run_name=name, tags=run_tags)
    params = _compute_config_for_logging(args)
    mlflow.log_params(params)

    # Stored on args so that secondary processes can attach to the same run.
    args.mlflow_run_id = started.info.run_id
    logger.info(
        "MLflow run started: %s (experiment=%s, name=%s)",
        started.info.run_id,
        experiment_name,
        name,
    )


def _init_mlflow_secondary(args) -> None:
"""Attach to an existing MLflow run created by the primary rank."""
import mlflow

run_id = args.mlflow_run_id or os.environ.get("MLFLOW_RUN_ID")
if run_id is None:
return

mlflow.start_run(run_id=run_id)
logger.info("MLflow secondary attached to run: %s", run_id)


# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------

def log_metrics(metrics: dict[str, Any], step: int | None = None) -> None:
    """Log the numeric entries of *metrics* to the active MLflow run.

    Does nothing when no run is active. Keys ending in ``/step`` are skipped,
    as are entries that raise ``TypeError`` or ``ValueError`` on conversion.
    Remaining keys go through ``_sanitize_key`` and values through ``float``.
    ``step`` is converted with ``int`` unless it is None.
    """
    import mlflow

    if mlflow.active_run() is None:
        return

    numeric: dict[str, float] = {}
    for name, raw_value in metrics.items():
        if name.endswith("/step"):
            continue
        try:
            numeric[_sanitize_key(name)] = float(raw_value)
        except (TypeError, ValueError):
            # Not convertible to a float: leave this entry out.
            continue

    if not numeric:
        return

    mlflow_step = None if step is None else int(step)
    mlflow.log_metrics(numeric, step=mlflow_step)


# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------

def finish() -> None:
    """End the active MLflow run, if there is one, and log its ID."""
    import mlflow

    active = mlflow.active_run()
    if active is None:
        return

    # Read the ID before ending the run so it can still be reported.
    ended_run_id = active.info.run_id
    mlflow.end_run()
    logger.info("MLflow run ended: %s", ended_run_id)
File renamed without changes.
21 changes: 8 additions & 13 deletions miles/utils/tracking_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import wandb
from miles.utils.tensorboard_utils import _TensorboardAdapter
from .tracking import TrackingManager

from . import wandb_utils
_manager = TrackingManager()


def init_tracking(args, primary: bool = True, **kwargs):
    """Initialise every tracking backend enabled on *args*.

    All initialisation goes through the module-level ``TrackingManager``.
    wandb is no longer initialised directly here, which would have
    initialised it a second time on top of the manager's own wandb backend.

    Args:
        args: Parsed arguments carrying the ``use_*`` backend flags.
        primary: Forwarded to the backends to pick primary or secondary setup.
        **kwargs: Forwarded unchanged to each backend's ``init``.
    """
    _manager.init(args, primary=primary, **kwargs)


# TODO further refactor, e.g. put TensorBoard init to the "init" part
def log(args, metrics, step_key: str):
    """Log *metrics* to every initialised tracking backend.

    The step is read from ``metrics[step_key]`` (None when that key is absent)
    and handed to the module-level manager together with the full metrics
    dict. Backends are no longer called directly from here, so each metric is
    logged once per backend and no adapter is built on every call.

    Args:
        args: Unused; kept so existing callers keep working.
        metrics: Mapping of metric name to value, normally including *step_key*.
        step_key: Name of the entry in *metrics* that holds the step.
    """
    step = metrics.get(step_key)
    _manager.log(metrics, step=step)

def finish_tracking():
    """Finish every tracking backend held by the module-level manager."""
    _manager.finish()
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def get_tag(self):
extras_require={
"fsdp": [
"torch>=2.0",
]
],
"mlflow": [
"mlflow",
],
},
python_requires=">=3.10",
classifiers=[
Expand Down