# SFT Nemo - Qwen #17
**Open:** nidhihrmth wants to merge 3 commits into `main` from `nh/nemo_qwen2.5`.
`@@ -0,0 +1,13 @@` (new file)

## Run instructions

### Update run.sh

- Any changes to the hyperparameters used during training can be passed as flags to `model/train.py`. More details about the flags are in `train.py`; a hypothetical sketch follows this section.
- All data processing in this example lives in `model/data.py`. Replace this file, or add data modules that reflect your dataset, and remember to update `train.py` to use the correct data module.

### Launch run

```
truss train push config.py
```

Upon successful submission, the CLI will output helpful information about your job, including the job ID to track your run.
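For illustration, here is a minimal sketch of the flag-passing pattern described above. The flag names and defaults are hypothetical, not taken from the actual `model/train.py`:

```python
# Hypothetical sketch of hyperparameter flags that model/train.py might expose.
# Flag names and defaults are illustrative only, not from the real script.
import argparse

parser = argparse.ArgumentParser(description="SFT training entrypoint (sketch)")
parser.add_argument("--lr", type=float, default=1e-5, help="Peak learning rate")
parser.add_argument("--max-steps", type=int, default=1000, help="Total number of training steps")
parser.add_argument("--seq-length", type=int, default=2048, help="Maximum sequence length")
args = parser.parse_args()
print(f"lr={args.lr} max_steps={args.max_steps} seq_length={args.seq_length}")
```

With flags like these, `run.sh` would invoke the trainer as, e.g., `python model/train.py --lr 2e-5 --max-steps 500`.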
`@@ -0,0 +1,39 @@` (new file)

```python
from truss_train import definitions
from truss.base import truss_config

BASE_IMAGE = "nvcr.io/nvidia/nemo:25.07"

training_runtime = definitions.Runtime(
    start_commands=[
        "/bin/sh -c 'chmod +x ./run.sh && ./run.sh'"
    ],
    environment_variables={
        "HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
        # "WANDB_API_KEY": definitions.SecretReference(name="wandb_api_key"),
    },
    cache_config=definitions.CacheConfig(
        enabled=True,
    ),
    checkpointing_config=definitions.CheckpointingConfig(
        enabled=True,
    ),
)

training_compute = definitions.Compute(
    accelerator=truss_config.AcceleratorSpec(
        accelerator=truss_config.Accelerator.H100,
        count=8,
    ),
    node_count=1,
)

training_job = definitions.TrainingJob(
    image=definitions.Image(base_image=BASE_IMAGE),
    compute=training_compute,
    runtime=training_runtime,
)

training_project = definitions.TrainingProject(
    name="Nemo-qwen2.5-nemo 1node",
    job=training_job,
)
```
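As a usage note, scaling this job out is a one-field change. A minimal sketch, using only the `definitions.Compute` API already shown above; the `node_count` value is illustrative:

```python
# Sketch: the same compute spec scaled out to two H100 nodes.
# Only node_count changes; the value 2 is illustrative.
training_compute = definitions.Compute(
    accelerator=truss_config.AcceleratorSpec(
        accelerator=truss_config.Accelerator.H100,
        count=8,
    ),
    node_count=2,
)
```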
`@@ -0,0 +1,147 @@` (new file)

```python
import json
import shutil
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from datasets import load_dataset

from nemo.collections.llm.gpt.data.core import create_sft_dataset
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec
    from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class BespokeDataModule(FineTuningDataModule, IOMixin):
    """A data module for fine-tuning on the Bespoke dataset.

    This class inherits from `FineTuningDataModule` and is specifically designed for
    fine-tuning models on the "bespokelabs/Bespoke-Stratos-17k" dataset. It handles data
    download, preprocessing, splitting, and preparing the data in a format suitable for
    training, validation, and testing.

    Args:
        force_redownload (bool, optional): Whether to force re-download the dataset even
            if it exists locally. Defaults to False.
        delete_raw (bool, optional): Whether to delete the raw downloaded dataset after
            preprocessing. Defaults to True.
        See FineTuningDataModule for the other args.
    """

    def __init__(
        self,
        seq_length: int = 2048,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        force_redownload: bool = False,
        delete_raw: bool = True,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        dataset_root: str = "./bespoke",
    ):
        self.force_redownload = force_redownload
        self.delete_raw = delete_raw

        super().__init__(
            dataset_root=dataset_root,
            seq_length=seq_length,
            tokenizer=tokenizer,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
            seed=seed,
            memmap_workers=memmap_workers,
            # num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            packed_sequence_specs=packed_sequence_specs,
            dataset_kwargs=dataset_kwargs,
        )

    def prepare_data(self) -> None:
        # If the train file already exists, there is nothing to do.
        if not self.train_path.exists() or self.force_redownload:
            dset = self._download_data()
            self._preprocess_and_split_data(dset)
        super().prepare_data()

    def _download_data(self):
        logging.info(f"Downloading {self.__class__.__name__}...")
        return load_dataset(
            "bespokelabs/Bespoke-Stratos-17k",
            cache_dir=str(self.dataset_root),
            download_mode="force_redownload" if self.force_redownload else None,
        )

    def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):
        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")
        test_ratio = 1 - train_ratio - val_ratio
        save_splits = {}
        dataset = dset.get('train')
        # First split off (val + test), then split that remainder into val and test.
        split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
        split_dataset2 = split_dataset['test'].train_test_split(
            test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
        )
        save_splits['training'] = split_dataset['train']
        save_splits['validation'] = split_dataset2['train']
        save_splits['test'] = split_dataset2['test']

        print("len training: ", len(save_splits['training']))
        print("len validation: ", len(save_splits['validation']))
        print("len test: ", len(save_splits['test']))

        for split_name, dataset in save_splits.items():
            output_file = self.dataset_root / f"{split_name}.jsonl"
            with output_file.open("w", encoding="utf-8") as f:
                for example in dataset:
                    conversations = example["conversations"]

                    # Normalize role names to the capitalized forms the chat
                    # dataset expects.
                    for conversation in conversations:
                        if conversation["from"] == "user":
                            conversation["from"] = "User"
                        elif conversation["from"] == "assistant":
                            conversation["from"] = "Assistant"
                        else:
                            raise ValueError(f"Unknown role: {conversation['from']}")

                    # Mask out (do not train on) the User turns.
                    example["mask"] = "User"
                    example["type"] = "VALUE_TO_TEXT"

                    f.write(json.dumps(example) + "\n")

            logging.info(f"{split_name} split saved to {output_file}")

        if self.delete_raw:
            for p in self.dataset_root.iterdir():
                if p.is_dir():
                    shutil.rmtree(p)
                elif '.jsonl' not in str(p.name):
                    p.unlink()

    @lru_cache
    def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
        # pylint: disable=C0115,C0116
        return create_sft_dataset(
            path,
            tokenizer=self.tokenizer,
            seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),
            memmap_workers=self.memmap_workers,
            seed=self.seed,
            chat=True,
            is_test=is_test,
            pack_metadata_file_path=None,  # packing is not supported
            pad_cu_seqlens=False,
            **kwargs,
        )
```

> **Review comment** on the existence check in `prepare_data`: this is a bit error-prone for partial download failures. I'd just delegate to the Hugging Face `load_dataset` logic and rely on it for skipping the download.
`@@ -0,0 +1,102 @@` (new file)

```python
import argparse

import nemo_run as run
import torch
from nemo.collections import llm


def parse_args():
    parser = argparse.ArgumentParser(description="Convert Hugging Face checkpoint to NeMo format")

    # Model ID configuration
    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-7B-Instruct",
        help="Source path for the model (e.g., model-name or local path)",
    )

    # Overwrite option
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=True,
        help="Whether to overwrite an existing checkpoint",
    )

    # Executor type
    parser.add_argument(
        "--executor",
        type=str,
        default="local",
        choices=["local"],
        help="Executor type to use",
    )

    # Output directory
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for the converted checkpoint",
    )

    return parser.parse_args()


def get_model_config(model_name):
    """Get the model configuration for a given model name."""
    model_configs = {
        "Qwen/Qwen2.5-7B-Instruct": llm.qwen2_7b.model(),
    }

    if model_name not in model_configs:
        raise ValueError(f"Unsupported model: {model_name}")

    return model_configs[model_name]


def configure_checkpoint_conversion(args):
    """Configure checkpoint conversion with command line arguments."""
    model_config = get_model_config(args.model_id)

    conversion_config = {
        "model": model_config,
        "source": f"hf://{args.model_id}",
        "overwrite": args.overwrite,
    }

    # Add output directory if specified
    if args.output_dir:
        conversion_config["output_dir"] = args.output_dir

    return run.Partial(llm.import_ckpt, **conversion_config)


def get_executor(executor_type):
    """Get an executor based on type."""
    if executor_type == "local":
        return run.LocalExecutor()
    else:
        raise ValueError(f"Unsupported executor type: {executor_type}")


def main():
    # Parse command line arguments
    args = parse_args()

    # Print CUDA information
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Model/Source: {args.model_id}")
    print(f"Overwrite: {args.overwrite}")
    print(f"Executor: {args.executor}")

    # Configure checkpoint conversion
    import_ckpt = configure_checkpoint_conversion(args)

    # Define executor
    executor = get_executor(args.executor)

    # Run experiment
    print("Starting checkpoint conversion...")
    run.run(import_ckpt, executor=executor)
    print("Checkpoint conversion completed!")


if __name__ == "__main__":
    main()
```
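Supporting additional checkpoints only requires new entries in the registry inside `get_model_config`. A hedged sketch; the `llm.qwen2_1p5b` attribute is an assumed name following NeMo's recipe naming convention and should be verified against the installed NeMo version:

```python
from nemo.collections import llm

# Hypothetical extension of the registry inside get_model_config.
# llm.qwen2_1p5b is an assumed name following NeMo's recipe convention;
# verify it exists in your NeMo version before relying on it.
model_configs = {
    "Qwen/Qwen2.5-7B-Instruct": llm.qwen2_7b.model(),
    "Qwen/Qwen2-1.5B-Instruct": llm.qwen2_1p5b.model(),  # assumed name
}
```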
> **Review comment:** Can you add a basic linter to this repo and clean up all the newline endings and formatting in a separate PR? Check with Nico what's preferred for our repos.