
Nouamane/lighteval #356


Merged (33 commits, Apr 24, 2025)
Commits (33)
d4e9daf  InitScalingMethod (NouamaneTazi, Apr 14, 2025)
6e7f0fa  InitScalingMethod (NouamaneTazi, Apr 14, 2025)
24d07e5  eval (NouamaneTazi, Apr 16, 2025)
438257a  try adding lightevalrunner to trainer (NouamaneTazi, Apr 16, 2025)
4f8a350  amend (NouamaneTazi, Apr 16, 2025)
c9c479d  amend (NouamaneTazi, Apr 16, 2025)
190a6b9  amend (NouamaneTazi, Apr 17, 2025)
004a89c  amend (NouamaneTazi, Apr 17, 2025)
b4cbb55  amend (NouamaneTazi, Apr 17, 2025)
d39872b  amend (NouamaneTazi, Apr 17, 2025)
feb818a  . (NouamaneTazi, Apr 17, 2025)
025f314  amend (NouamaneTazi, Apr 17, 2025)
abe75af  amend (NouamaneTazi, Apr 17, 2025)
bd50c66  . (NouamaneTazi, Apr 17, 2025)
2227432  qos to low (eliebak, Apr 17, 2025)
b62cacd  add nanotron_path (eliebak, Apr 17, 2025)
802fad6  some fix: logs, and config (eliebak, Apr 17, 2025)
895354a  cp instead of sync (eliebak, Apr 17, 2025)
55a5d3e  eval_interval (NouamaneTazi, Apr 17, 2025)
298492e  serialize sanity checks (NouamaneTazi, Apr 17, 2025)
4219ec8  add output dir and s3_save path in the config (eliebak, Apr 17, 2025)
f1780ec  add output dir and s3_save path in the config (eliebak, Apr 17, 2025)
016760e  fix s3 only if define (eliebak, Apr 17, 2025)
85138ca  fixes (NouamaneTazi, Apr 17, 2025)
0390de2  Merge branch 'nouamane/lighteval' of https://github.com/huggingface/n… (NouamaneTazi, Apr 17, 2025)
fefb560  add requeue (eliebak, Apr 17, 2025)
4558036  add wandb with lighteval and fix eval interval (eliebak, Apr 18, 2025)
17b5284  Merge branch 'nouamane/lighteval' of github.com:huggingface/nanotron … (eliebak, Apr 18, 2025)
b5ea942  fix this little space :( (eliebak, Apr 20, 2025)
561ca6b  folder_path should always have s3 when using s3 (fix consumed tokens … (NouamaneTazi, Apr 23, 2025)
dc6edaa  Merge branch 'dev' of https://github.com/huggingface/nanotron into no… (NouamaneTazi, Apr 24, 2025)
7724cf1  config qwen (NouamaneTazi, Apr 24, 2025)
46949b6  . (NouamaneTazi, Apr 24, 2025)
22 changes: 15 additions & 7 deletions examples/config_qwen.py
@@ -30,7 +30,7 @@
"410m": (24, 1024, 16, 16, 4096), # ~410M params
# Small to medium models
"1b": (16, 2048, 16, 16, 5632), # ~1B params
"3b": (28, 2048, 16, 2, 11008), # ~3B params
"3b": (36, 2048, 16, 4, 11008), # ~3B params
# Standard sizes
"7b": (32, 4096, 32, 32, 11008), # ~7B params
"13b": (40, 5120, 40, 40, 13824), # ~13B params
@@ -47,7 +47,7 @@ def get_args():
parser.add_argument(
"--model",
choices=MODEL_SIZES.keys(),
default="custom",
default="3b",
help="Model size to generate config for (e.g., 7b, 13b)",
)
parser.add_argument(
@@ -76,6 +76,10 @@ def get_args():
tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size")
tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica")

# checkpoints
checkpoints_group = parser.add_argument_group("checkpoints")
checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval")

args = parser.parse_args()
return args

@@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config:
is_qwen2_config=True,
pad_token_id=None,
_attn_implementation="flash_attention_2",
# sliding_window_size=20,
_use_doc_masking=True,
)


@@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str:

def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config:
learning_rate = LRSchedulerArgs(
learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0
)
parallelism = ParallelismArgs(
dp=args.dp,
@@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
)
optimizer = OptimizerArgs(
zero_stage=args.zero,
weight_decay=0.01,
weight_decay=0.1,
clip_grad=1.0,
accumulate_grad_in_fp32=True,
learning_rate_scheduler=learning_rate,
@@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config

return Config(
general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save),
parallelism=parallelism,
model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
# tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"),
@@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
world_size = args.dp * args.tp * args.pp * args.cp
if world_size <= 8:
print(
f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
)
print("You can also use environment variables for more debugging:")
print(" - ENABLE_TIMERS=1: Enable detailed timing information")
print(" - DEBUG_CPU=1: Log CPU and memory usage statistics")
print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection")
else:
print("Checkout slurm_launcher.py to launch a multi-node job")
24 changes: 12 additions & 12 deletions examples/config_qwen.yaml
@@ -1,5 +1,5 @@
checkpoints:
checkpoint_interval: 10
checkpoint_interval: 100000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
load_lr_scheduler: true
@@ -30,9 +30,9 @@ data_stages:
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: false
ignore_sanity_checks: true
project: debug
run: qwen_20250423_201000_16423158
run: qwen_20250424_120835_16423158
seed: 42
step: null
lighteval: null
@@ -50,24 +50,24 @@ model:
make_vocab_size_divisible_by: 1
model_config:
_attn_implementation: flash_attention_2
_fused_rms_norm: false
_fused_rotary_emb: false
_use_doc_masking: false
_use_qkv_packed: false
_fused_rms_norm: true
_fused_rotary_emb: true
_use_doc_masking: true
_use_qkv_packed: true
attention_bias: false
bos_token_id: 1
eos_token_id: 2
flex_attention_mask: null
hidden_act: silu
hidden_size: 256
hidden_size: 2048
initializer_range: 0.02
intermediate_size: 768
intermediate_size: 11008
is_qwen2_config: true
max_position_embeddings: 4096
moe_config: null
no_rope_layer: null
num_attention_heads: 4
num_hidden_layers: 12
num_attention_heads: 16
num_hidden_layers: 36
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
@@ -108,7 +108,7 @@ parallelism:
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
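
As a quick cross-check, the regenerated YAML matches the updated "3b" preset. A hedged sanity-check sketch, assuming the file sits at examples/config_qwen.yaml and keeps the key layout shown above:

# Sanity-check sketch: compare the regenerated YAML against the "3b" preset values.
# The (layers, hidden, heads, kv_heads, ffn) ordering is the same assumption as above.
import yaml

with open("examples/config_qwen.yaml") as f:
    model_cfg = yaml.safe_load(f)["model"]["model_config"]

expected = {
    "num_hidden_layers": 36,
    "hidden_size": 2048,
    "num_attention_heads": 16,
    "num_key_value_heads": 4,
    "intermediate_size": 11008,
}
for key, value in expected.items():
    assert model_cfg[key] == value, f"{key}: expected {value}, got {model_cfg[key]}"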
15 changes: 13 additions & 2 deletions src/nanotron/config/lighteval_config.py
@@ -109,8 +109,13 @@ class LightEvalConfig:
logging: Optional[LightEvalLoggingArgs] = None
wandb: Optional[LightEvalWandbLoggerConfig] = None
slurm: Optional[LightEvalSlurm] = None
s3_save_path: Optional[str] = None # should not be dependent of the run_name
output_dir: Optional[str] = None # we should sanity check that it's the same as the one in the eval_config_override
s3_save_path: Optional[str] = None # should not be dependent of the run_name
upload_to_wandb: Optional[bool] = False
wandb_project: Optional[str] = None
wandb_entity: Optional[str] = None
output_dir: Optional[
str
] = None # we should sanity check that it's the same as the one in the eval_config_override
nanotron_path: Optional[str] = "./"
eval_config_override: str = None
eval_config_override: Path = None # Previously hardcoded in run_slurm_one_job
@@ -127,6 +132,12 @@ def __post_init__(self):
if self.slurm is None:
self.slurm = LightEvalSlurm()
self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser())
if self.upload_to_wandb:
assert (
self.s3_save_path is not None
), " We should have a s3_save_path if we want to upload to wandb" # todo: add the option to read from local folder i guess
assert self.wandb_project is not None, "wandb_project must be specified if upload_to_wandb is True"
assert self.wandb_entity is not None, "wandb_entity must be specified if upload_to_wandb is True"
if self.eval_interval_file is not None and Path(self.eval_interval_file).exists():
logger.warning(
f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want."
7 changes: 4 additions & 3 deletions src/nanotron/data/tokenized_bytes.py
@@ -369,12 +369,13 @@ def __init__(
)
from datatrove.utils.dataset import url_to_fs

fs_folder, folder_path = url_to_fs(folder_path)
fs_folder, stripped_folder_path = url_to_fs(folder_path)
matched_files = (
fs_folder.find(folder_path, detail=False, maxdepth=1 if not recursive else None)
fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None)
if not filename_pattern
else fs_folder.glob(
os.path.join(folder_path, filename_pattern), maxdepth=1 if not recursive else None
os.path.join(stripped_folder_path, filename_pattern),
maxdepth=1 if not recursive else None,
)
)
matched_files = sorted(matched_files)
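
For context on the fix: url_to_fs splits a URL into a filesystem object and a protocol-stripped path, and that stripped path, not the original s3:// URL, is what find/glob expect. A small sketch using fsspec's helper of the same name (the datatrove import above is assumed to behave the same way; the bucket is hypothetical):

# Sketch with a hypothetical bucket; datatrove's url_to_fs is assumed to
# match fsspec's helper of the same name.
from fsspec.core import url_to_fs

fs, stripped_folder_path = url_to_fs("s3://my-bucket/tokenized")
# fs is an S3FileSystem, stripped_folder_path is "my-bucket/tokenized"
matched_files = fs.find(stripped_folder_path, detail=False, maxdepth=1)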
40 changes: 31 additions & 9 deletions src/nanotron/eval/one_job_runner.py
@@ -60,13 +60,18 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]:
logger.warning(
f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path."
)

slurm_job_id, slurm_log = run_slurm_one_job(
config=self.config,
lighteval_config=self.lighteval_config,
model_checkpoint_path=checkpoint_path,
current_step=self.config.general.step,
)
if self.config.general.step % self.lighteval_config.eval_interval == 0:
slurm_job_id, slurm_log = run_slurm_one_job(
config=self.config,
lighteval_config=self.lighteval_config,
model_checkpoint_path=checkpoint_path,
current_step=self.config.general.step,
)
else:
logger.warning(
f"Skipping evaluation at step {self.config.general.step} because it's not a multiple of {self.lighteval_config.eval_interval}"
)
return None, None

return slurm_job_id, slurm_log

@@ -130,7 +135,8 @@ def run_slurm_one_job(
#SBATCH --exclusive
#SBATCH --qos={slurm_config.qos}
#SBATCH --time={slurm_config.time}
#SBATCH --output={eval_logs_path}/%j-{timestamp}.out"""
#SBATCH --output={eval_logs_path}/%j-{timestamp}.out
#SBATCH --requeue"""

if slurm_config.reservation:
slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}"
@@ -250,7 +256,23 @@ def run_slurm_one_job(
--cache-dir {slurm_config.hf_cache}"""
if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None:
slurm_script += f"""
s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}
s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/
"""
if lighteval_config.upload_to_wandb:
gbs_tok = (
config.parallelism.dp
* config.tokens.micro_batch_size
* config.tokens.sequence_length
* config.tokens.batch_accumulation_per_replica
)
slurm_script += f"""
python {nanotron_path}/src/nanotron/eval/upload_to_wandb.py \\
--wandb_project {lighteval_config.wandb_project} \\
--wandb_entity {lighteval_config.wandb_entity} \\
--model_name {general_run_name} \\
--results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\
--train_step {current_step} \\
--consumed_tokens {current_step*gbs_tok}
"""
slurm_script += """
echo "Cleaning up downloaded checkpoints..."
87 changes: 87 additions & 0 deletions src/nanotron/eval/upload_to_wandb.py
@@ -0,0 +1,87 @@
import json
import s3fs
import wandb
import re
import argparse
from wandb.sdk.lib.runid import generate_id


def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens):
s3 = s3fs.S3FileSystem(anon=False)
all_metrics = {
# basic X axis replacements for all metrics
"consumed_tokens": consumed_tokens,
"train_step": train_step,
}

for result_file in sorted(s3.ls(results_path)):
if not result_file.endswith(".json"):
continue

with s3.open(result_file, "r") as f:
results = json.loads(f.read())["results"]

for benchmark, metrics in results.items():
if benchmark == "all":
continue

# extract dataset and config name
match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark)
if match:
dataset, subtask = match.groups()

for metric_name, metric_value in metrics.items():
if "_stderr" in metric_name:
continue
# wandb-friendly metric name
wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}"
all_metrics[wandb_metric] = metric_value

run_id = f"{model_name}-{generate_id()}"

# try to find the run in wandb and resume it
api = wandb.Api()
runs = api.runs(f"{wandb_entity}/{wandb_project}")
for run in runs:
if run.name == model_name:
run_id = run.id
break

wandb.init(
project=wandb_project,
entity=wandb_entity,
name=model_name,
id=run_id,
config={
"model_name": model_name,
},
resume="allow",
)

# log all metrics for this checkpoint
wandb.log(all_metrics)

wandb.finish()

if __name__ == "__main__":
# Setup argument parser
parser = argparse.ArgumentParser(description="Upload evaluation results to Weights & Biases.")
parser.add_argument("--wandb_project", type=str, required=True, help="WandB project name.")
parser.add_argument("--wandb_entity", type=str, required=True, help="WandB entity name.")
parser.add_argument("--model_name", type=str, required=True, help="Name of the model.")
parser.add_argument("--results_path", type=str, required=True, help="S3 path to the results directory.")
parser.add_argument("--train_step", type=int, required=True, help="Training step corresponding to the checkpoint.")
parser.add_argument("--consumed_tokens", type=int, required=True, help="Total consumed tokens up to this checkpoint.")

# Parse arguments
args = parser.parse_args()

# Call the main function with parsed arguments
push_to_wandb(
wandb_project=args.wandb_project,
wandb_entity=args.wandb_entity,
model_name=args.model_name,
results_path=args.results_path,
train_step=args.train_step,
consumed_tokens=args.consumed_tokens
)
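
The benchmark-name parsing expects lighteval-style keys with pipe separators, where the subtask after the colon is optional. A small sketch with hypothetical result keys (the exact key format depends on how lighteval names its tasks):

import re

# Hypothetical result keys; the subtask part after ":" is optional.
for benchmark in ["lighteval|mmlu:abstract_algebra|5|0", "lighteval|hellaswag|10|0"]:
    match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark)
    if match:
        dataset, subtask = match.groups()
        print(dataset, subtask)  # mmlu abstract_algebra, then hellaswag None

The SLURM script generated in one_job_runner.py above is what invokes this file, passing the S3 results path plus the step and consumed-token counters.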