13 changes: 13 additions & 0 deletions examples/qwen2.5_nemo/README.md
@@ -0,0 +1,13 @@
## Run instructions

### Update run.sh
- Hyperparameters used during training can be adjusted by passing flags to `model/train.py`; see `train.py` for details on the available flags (an illustrative invocation follows this list).
- All data processing in this example lives in `model/data.py`. Replace it, or add data modules that reflect your dataset, and remember to update `train.py` to use the correct data module.
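For example, `run.sh` might launch training like this (the `torchrun` invocation and flag names are illustrative, not taken from this repo; the real flags are defined in `train.py`):

```
torchrun --nproc-per-node 8 model/train.py --max-steps 1000
```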

### Launch run

```
truss train push config.py
```

Upon successful submission, the CLI prints helpful information about your job, including the job ID you can use to track your run.
39 changes: 39 additions & 0 deletions examples/qwen2.5_nemo/training/config.py
@@ -0,0 +1,39 @@
from truss_train import definitions
from truss.base import truss_config

BASE_IMAGE = "nvcr.io/nvidia/nemo:25.07"

training_runtime = definitions.Runtime(
start_commands=[
"/bin/sh -c 'chmod +x ./run.sh && ./run.sh'"
],
environment_variables={
"HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
# "WANDB_API_KEY": definitions.SecretReference(name="wandb_api_key"),
},
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)

training_compute = definitions.Compute(
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=8,
),
node_count=1,
)

training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime
)

training_project = definitions.TrainingProject(
name="Nemo-qwen2.5-nemo 1node",
job=training_job
)
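The compute block above requests a single node with eight H100s. A minimal sketch of a multi-node variant, assuming the platform scales the job out when `node_count` is raised:

```
training_compute = definitions.Compute(
    accelerator=truss_config.AcceleratorSpec(
        accelerator=truss_config.Accelerator.H100,
        count=8,
    ),
    node_count=2,  # assumption: the training platform distributes across nodes
)
```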
Contributor comment:
Can you add a basic linter to this repo and clean up all the line endings and formatting in a separate PR? Check with Nico what's preferred for our repos.

147 changes: 147 additions & 0 deletions examples/qwen2.5_nemo/training/model/data.py
@@ -0,0 +1,147 @@
import json
import shutil
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from datasets import load_dataset

from nemo.collections.llm.gpt.data.core import create_sft_dataset
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class BespokeDataModule(FineTuningDataModule, IOMixin):
"""A data module for fine-tuning on the Bespoke dataset.

This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models on the
"bespokelabs/Bespoke-Stratos-17k" dataset. It handles data download, preprocessing, splitting, and preparing the data
in a format suitable for training, validation, and testing.

Args:
force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally. Defaults to False.
delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing. Defaults to True.
See FineTuningDataModule for the other args
"""

def __init__(
self,
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
rampup_batch_size: Optional[List[int]] = None,
force_redownload: bool = False,
delete_raw: bool = True,
seed: int = 1234,
memmap_workers: int = 1,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
dataset_kwargs: Optional[Dict[str, Any]] = None,
dataset_root: str = "./bespoke",
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw

super().__init__(
dataset_root=dataset_root,
seq_length=seq_length,
tokenizer=tokenizer,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
rampup_batch_size=rampup_batch_size,
seed=seed,
memmap_workers=memmap_workers,
            num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
packed_sequence_specs=packed_sequence_specs,
dataset_kwargs=dataset_kwargs,
)

def prepare_data(self) -> None:
        # skip download/preprocessing when the processed train file already exists
if not self.train_path.exists() or self.force_redownload:
Contributor comment:
This is a bit error-prone for partial download failures. I'd just delegate to the Hugging Face `load_dataset` logic and rely on it for skipping the download.

dset = self._download_data()
self._preprocess_and_split_data(dset)
super().prepare_data()
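    # A sketch of the reviewer's suggestion above (assumption: load_dataset's
    # cache makes repeated calls cheap, so the existence check can go away):
    #
    #     def prepare_data(self) -> None:
    #         dset = self._download_data()  # HF cache skips completed downloads
    #         if not self.train_path.exists() or self.force_redownload:
    #             self._preprocess_and_split_data(dset)
    #         super().prepare_data()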

def _download_data(self):
logging.info(f"Downloading {self.__class__.__name__}...")
return load_dataset(
"bespokelabs/Bespoke-Stratos-17k",
cache_dir=str(self.dataset_root),
download_mode="force_redownload" if self.force_redownload else None,
)

def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):
logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")
test_ratio = 1 - train_ratio - val_ratio
save_splits = {}
dataset = dset.get('train')
split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
split_dataset2 = split_dataset['test'].train_test_split(
test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
)
save_splits['training'] = split_dataset['train']
save_splits['validation'] = split_dataset2['train']
save_splits['test'] = split_dataset2['test']

print("len training: ", len(save_splits['training']))
print("len validation: ", len(save_splits['validation']))
print("len test: ", len(save_splits['test']))

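        # Each record written below follows NeMo's chat SFT JSONL shape as built
        # here: "conversations" entries with "from" in {"User", "Assistant"},
        # "mask": "User" (user turns are excluded from the loss), and
        # "type": "VALUE_TO_TEXT".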
for split_name, dataset in save_splits.items():
output_file = self.dataset_root / f"{split_name}.jsonl"
with output_file.open("w", encoding="utf-8") as f:
for example in dataset:

conversations = example["conversations"]

for conversation in conversations:
if conversation["from"] == "user":
conversation["from"] = "User"
elif conversation["from"] == "assistant":
conversation["from"] = "Assistant"
else:
                        raise ValueError(f"Unknown role: {conversation['from']}")

example["mask"] = "User"
example["type"] = "VALUE_TO_TEXT"

f.write(json.dumps(example) + "\n")

logging.info(f"{split_name} split saved to {output_file}")

if self.delete_raw:
for p in self.dataset_root.iterdir():
if p.is_dir():
shutil.rmtree(p)
elif '.jsonl' not in str(p.name):
p.unlink()

@lru_cache
def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
# pylint: disable=C0115,C0116
return create_sft_dataset(
path,
tokenizer=self.tokenizer,
seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),
memmap_workers=self.memmap_workers,
seed=self.seed,
chat=True,
is_test=is_test,
pack_metadata_file_path=None, # packing is not supported
pad_cu_seqlens=False,
**kwargs,
)
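A minimal sketch of wiring this data module into a training script (the import path, and how the module is handed to the trainer, are assumptions; `train.py` is not part of this diff):

```
from model.data import BespokeDataModule

# Hypothetical: construct the data module with the defaults defined above;
# the tokenizer is resolved by the NeMo model/recipe when left as None.
data = BespokeDataModule(
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=8,
)
# `data` would then be passed to the fine-tuning entry point used by train.py.
```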
102 changes: 102 additions & 0 deletions examples/qwen2.5_nemo/training/model/import_ckpt.py
@@ -0,0 +1,102 @@
import argparse

import nemo_run as run
import torch
from nemo.collections import llm

def parse_args():
parser = argparse.ArgumentParser(description="Convert Hugging Face checkpoint to NeMo format")

# Model ID configuration
    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-7B-Instruct",
        help="Source model: a Hugging Face model ID or a local path"
    )

# Overwrite option
    parser.add_argument(
        "--overwrite",
        # store_true with default=True could never be disabled; this allows --no-overwrite
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to overwrite an existing checkpoint"
    )

# Executor type
parser.add_argument(
"--executor",
type=str,
default="local",
choices=["local"],
help="Executor type to use"
)

# Output directory
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory for converted checkpoint"
)

return parser.parse_args()

def get_model_config(model_name):
"""Get model configuration based on model name"""
model_configs = {
"Qwen/Qwen2.5-7B-Instruct": llm.qwen2_7b.model(),
}

if model_name not in model_configs:
raise ValueError(f"Unsupported model: {model_name}")

return model_configs[model_name]

def configure_checkpoint_conversion(args):
"""Configure checkpoint conversion with command line arguments"""
model_config = get_model_config(args.model_id)

conversion_config = {
"model": model_config,
"source": f"hf://{args.model_id}",
"overwrite": args.overwrite,
}

# Add output directory if specified
if args.output_dir:
conversion_config["output_dir"] = args.output_dir

return run.Partial(llm.import_ckpt, **conversion_config)

def get_executor(executor_type):
"""Get executor based on type"""
if executor_type == "local":
return run.LocalExecutor()
else:
raise ValueError(f"Unsupported executor type: {executor_type}")

def main():
# Parse command line arguments
args = parse_args()

# Print CUDA information
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
print(f"Model/Source: {args.model_id}")
print(f"Overwrite: {args.overwrite}")
print(f"Executor: {args.executor}")

# Configure checkpoint conversion
import_ckpt = configure_checkpoint_conversion(args)

# Define executor
executor = get_executor(args.executor)

# Run experiment
print("Starting checkpoint conversion...")
run.run(import_ckpt, executor=executor)
print("Checkpoint conversion completed!")

if __name__ == "__main__":
main()
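A typical invocation, assuming the script is run inside the NeMo container from the `training/` directory (the output path here is illustrative):

```
python model/import_ckpt.py --model_id Qwen/Qwen2.5-7B-Instruct --output-dir /workspace/qwen2.5-7b-nemo
```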