diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index a0edd15ae..bfb989408 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -361,6 +361,11 @@ jobs:
           cat ./output_eager2
           echo "Tests complete."
+      - name: Test download
+        run: |
+
+          python torchchat.py generate stories15M
+
 
   test-tinystories-eager:
     strategy:
       matrix:
diff --git a/.gitignore b/.gitignore
index 8018116a6..9743b9e3a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,5 @@ __pycache__/
 
 # C extensions
 *.so
+
+.model-artifacts/
diff --git a/README.md b/README.md
index 915fa8d60..d1f44947e 100644
--- a/README.md
+++ b/README.md
@@ -32,14 +32,10 @@ python torchchat.py --help
 ```
 
-### Dowenload a Model and Tokenizer
+### Generating Text
 
 ```
-#download a model
-python torchchat.py download llama2
-
-#generate text using the model
-
+python torchchat.py generate stories15M
 ```
 
 That’s all there is to it!
 Read on to learn how to use the full power of torchchat.
@@ -48,7 +44,15 @@ Read on to learn how to use the full power of torchchat.
 For the full details on all commands and parameters run `python torchchat.py --help`
 
 ### Download
-TODO: Fill this out
+For supported models, torchchat can download model weights. Most models use HuggingFace as the distribution channel, so you will need to create a HuggingFace
+account and install `huggingface-cli`.
+
+To install `huggingface-cli`, run `pip install huggingface_hub`. After installing, create a user access token [as documented here](https://huggingface.co/docs/hub/en/security-tokens). Run `huggingface-cli login`, which will prompt for the newly created token. Once this is done, torchchat will be able to download model artifacts from
+HuggingFace.
+
+```
+python torchchat.py download llama2
+```
 
 ### Chat
 Designed for interactive and conversational use.
@@ -69,7 +73,7 @@ For more information run `python torchchat.py generate --help`
 
 **Examples**
 ```
-#Generate for Mac with some parameters
+python torchchat.py generate llama2 --device=cpu --dtype=fp16
 ```
 
 ### Export
@@ -80,7 +84,7 @@ For more information run `python torchchat.py export --help`
 
 **Examples**
 ```
-#Export Example
+python torchchat.py export stories15M --output-pte-path=stories15m.pte
 ```
 
 ### Browser
diff --git a/build/builder.py b/build/builder.py
index 103e2c9e9..f13294b98 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -14,7 +14,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-
+from config.model_config import resolve_model_config
 from quantize import name_to_dtype, quantize_model
 from sentencepiece import SentencePieceProcessor
 
@@ -42,7 +42,7 @@ class BuilderArgs:
     def __post_init__(self):
         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
-            or (self.checkpoint_dir and self.checkpoint_path.is_dir())
+            or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
             or (self.gguf_path and self.gguf_path.is_file())
             or (self.dso_path and Path(self.dso_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
@@ -73,7 +73,17 @@ def from_args(cls, args):  # -> BuilderArgs:
         # Handle disabled checkpoint_dir option
         checkpoint_dir = None
         if hasattr(args, "checkpoint_dir"):
-            checkpoint_dir = args.checkpoint_dir
+            checkpoint_dir = args.checkpoint_dir
+
+        checkpoint_path = args.checkpoint_path
+        if args.model:  # Using a named, well-known model
+            model_config = resolve_model_config(args.model)
+
+            checkpoint_path = (
+                Path(args.model_directory)
+                / model_config.name
+                / model_config.checkpoint_file
+            )
 
         is_chat_model = False
         if args.is_chat_model:
@@ -94,8 +104,8 @@ def from_args(cls, args):  # -> BuilderArgs:
             is_chat_model = True
 
         return cls(
-            checkpoint_path=args.checkpoint_path,
             checkpoint_dir=checkpoint_dir,
+            checkpoint_path=checkpoint_path,
             params_path=args.params_path,
             params_table=args.params_table,
             gguf_path=args.gguf_path,
@@ -134,9 +144,12 @@ def from_args(cls, args):  # -> TokenizerArgs:
 
         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
+        elif args.model:  # Using a named, well-known model
+            model_config = resolve_model_config(args.model)
+            tokenizer_path = Path(args.model_directory) / model_config.name / "tokenizer.model"
         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
-        elif args.checkpoint_dir:
+        elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
             tokenizer_path = args.checkpoint_dir / "tokenizer.model"
         else:
             raise RuntimeError("cannot find tokenizer model")
@@ -356,4 +369,10 @@ def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
     is_tiktoken = tokenizer_args.is_tiktoken
     if use_tiktoken != is_tiktoken:
         raise RuntimeError(f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}")
-
+
+def resolve_model_name(model: str) -> str:
+    # If the provided model name is an alias, retrieve the full path.
+    if model in model_aliases:
+        return model_aliases[model]
+    else:
+        return model
diff --git a/scripts/convert_hf_checkpoint.py b/build/convert_hf_checkpoint.py
similarity index 85%
rename from scripts/convert_hf_checkpoint.py
rename to build/convert_hf_checkpoint.py
index b5d6d7ba2..b3a221887 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/build/convert_hf_checkpoint.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
+import os
 import re
 import sys
 from pathlib import Path
@@ -22,19 +23,20 @@
 
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,
-    checkpoint_dir: Optional[Path] = None,
+    model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
+    remove_bin_files: bool = False,
 ) -> None:
-    if checkpoint_dir is None:
-        checkpoint_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
+    if model_dir is None:
+        model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
     if model_name is None:
-        model_name = checkpoint_dir.name
+        model_name = model_dir.name
 
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     assert model_map_json.is_file()
@@ -56,7 +58,7 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         dim = config.dim
@@ -97,8 +99,13 @@ def permute(w, n_heads):
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
-    torch.save(final_result, checkpoint_dir / "model.pth")
+    print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
+    torch.save(final_result, model_dir / "model.pth")
+    print("Done.")
+
+    if remove_bin_files:
+        for file in bin_files:
+            os.remove(file)
 
 
 if __name__ == "__main__":
@@ -114,6 +121,6 @@ def permute(w, n_heads):
 
     args = parser.parse_args()
     convert_hf_checkpoint(
-        checkpoint_dir=args.checkpoint_dir,
+        model_dir=args.checkpoint_dir,
         model_name=args.model_name,
     )
diff --git a/cli.py b/cli.py
index 5d11a8c04..c38e7696c 100644
--- a/cli.py
+++ b/cli.py
@@ -14,6 +14,11 @@ def check_args(args, name: str) -> None:
     pass
 
 
+def add_arguments_for_download(parser):
+    # Only download specific options should be here
+    _add_arguments_common(parser)
+
+
 def add_arguments_for_generate(parser):
     # Only generate specific options should be here
     _add_arguments_common(parser)
@@ -39,6 +44,19 @@ def add_arguments_for_browser(parser):
     )
 
 def _add_arguments_common(parser):
+    # Model specification. TODO Simplify this.
+    # A model can be specified using a positional model name or HuggingFace
+    # path. Alternatively, the model can be specified via --gguf-path or via
+    # an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.
+
+    parser.add_argument(
+        "model",
+        type=str,
+        nargs="?",
+        default=None,
+        help="Model name for well-known models.",
+    )
+
     # TODO: Refactor this so that only common options are here
     # and subcommand-specific options are inside individual
     # add_arguments_for_generate, add_arguments_for_export etc.
@@ -168,6 +186,18 @@ def _add_arguments_common(parser):
         default=None,
         help="maximum length sequence to evaluate",
     )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="A HuggingFace API token to use when downloading model artifacts",
+    )
+    parser.add_argument(
+        "--model-directory",
+        type=Path,
+        default=".model-artifacts",
+        help="The directory to store downloaded model artifacts",
+    )
 
 
 def arg_init(args):
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/config/data/models.json b/config/data/models.json
new file mode 100644
index 000000000..a86af4384
--- /dev/null
+++ b/config/data/models.json
@@ -0,0 +1,28 @@
+{
+  "meta-llama/Llama-2-7b-chat-hf": {
+    "aliases": ["llama2", "llama2-7b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-2-7b-chat-hf"
+  },
+  "mistralai/Mistral-7B-Instruct-v0.2": {
+    "aliases": ["mistral-7b-instruct"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "mistralai/Mistral-7B-Instruct-v0.2"
+  },
+  "stories15M": {
+    "distribution_channel": "DirectDownload",
+    "distribution_path": [
+      "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt",
+      "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
+    ],
+    "checkpoint_file": "stories15M.pt"
+  },
+  "stories110M": {
+    "distribution_channel": "DirectDownload",
+    "distribution_path": [
+      "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt",
+      "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
+    ],
+    "checkpoint_file": "stories110M.pt"
+  }
+}
diff --git a/config/model_config.py b/config/model_config.py
new file mode 100644
index 000000000..b2e8c23d9
--- /dev/null
+++ b/config/model_config.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import json
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Dict, Sequence, Union
+
+"""
+Known Model Configs:
+
+For models that are known to work with torchchat, we provide a config under
+config/data/models.json to support automatically downloading the model and
+converting to the expected format for use with torchchat.
+
+There are two supported distribution channels:
+
+1) HuggingFaceSnapshot: Download a model from HuggingFace.
+2) DirectDownload: Download a list of model artifacts from URLs. No conversion
+   is done.
+"""
+
+
+# Specifies the distribution channel to download model artifacts from. Enum
+# variants are specified as strings to simplify JSON (de)serialization.
+class ModelDistributionChannel(str, Enum):
+    # Download a full model snapshot from HuggingFace, such as
+    # meta-llama/Llama-2-7b-chat-hf and convert to torchchat format.
+    HuggingFaceSnapshot = "HuggingFaceSnapshot"
+
+    # Download one or more files over HTTP(S).
+    DirectDownload = "DirectDownload"
+
+
+@dataclass
+class ModelConfig:
+    name: str = field(default="")
+    aliases: Sequence[str] = field(default_factory=list)
+    distribution_path: Union[str, Sequence[str]] = field(default="")
+    distribution_channel: ModelDistributionChannel = field(
+        default=ModelDistributionChannel.HuggingFaceSnapshot
+    )
+    checkpoint_file: str = field(default="model.pth")
+
+
+# Keys are stored in lowercase.
+model_aliases: Dict[str, str] = None
+model_configs: Dict[str, ModelConfig] = None
+
+
+def resolve_model_config(model: str) -> ModelConfig:
+    global model_aliases
+    global model_configs
+
+    model = model.lower()
+
+    # Lazy load model config from JSON.
+    if not model_configs:
+        model_aliases = {}
+        model_configs = {}
+
+        with open(
+            Path(__file__).parent.parent / "config" / "data" / "models.json", "r"
+        ) as f:
+            model_config_dict = json.load(f)
+
+        for key, value in model_config_dict.items():
+            config = ModelConfig(**value)
+            config.name = key
+
+            key = key.lower()
+            model_configs[key] = config
+
+            for alias in config.aliases:
+                model_aliases[alias.lower()] = key
+
+    if model in model_aliases:
+        model = model_aliases[model]
+
+    if model not in model_configs:
+        raise ValueError(f"Unknown model '{model}'.")
+
+    return model_configs[model]
diff --git a/download.py b/download.py
new file mode 100644
index 000000000..031b31f56
--- /dev/null
+++ b/download.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import urllib.request
+from pathlib import Path
+from typing import Optional, Sequence
+
+from build.convert_hf_checkpoint import convert_hf_checkpoint
+from config.model_config import (
+    ModelDistributionChannel,
+    resolve_model_config,
+)
+
+from requests.exceptions import HTTPError
+
+
+def _download_and_convert_hf_snapshot(
+    model: str, models_dir: Path, hf_token: Optional[str]
+):
+    model_dir = models_dir / model
+    os.makedirs(model_dir, exist_ok=True)
+
+    from huggingface_hub import snapshot_download
+
+    # Download and store the HF model artifacts.
+    print(f"Downloading {model} from HuggingFace...")
+    try:
+        snapshot_download(
+            model,
+            local_dir=model_dir,
+            local_dir_use_symlinks=False,
+            token=hf_token,
+            ignore_patterns="*safetensors*",
+        )
+    except HTTPError as e:
+        if e.response.status_code == 401:
+            raise RuntimeError(
+                "Access denied. Run huggingface-cli login to authenticate."
+            )
+            os.rmdir(model_dir)
+        else:
+            raise e
+
+    # Convert the model to the torchchat format.
+    print(f"Converting {model} to torchchat format...")
+    convert_hf_checkpoint(model_dir=model_dir, model_name=model, remove_bin_files=True)
+
+
+def _download_direct(
+    model: str,
+    urls: Sequence[str],
+    models_dir: Path,
+):
+    model_dir = models_dir / model
+    os.makedirs(model_dir, exist_ok=True)
+
+    for url in urls:
+        filename = url.split("/")[-1]
+        local_path = model_dir / filename
+        print(f"Downloading {url}...")
+        urllib.request.urlretrieve(url, str(local_path.absolute()))
+
+
+def download_and_convert(
+    model: str, models_dir: Path, hf_token: Optional[str] = None
+) -> None:
+    model_config = resolve_model_config(model)
+
+    if (
+        model_config.distribution_channel
+        == ModelDistributionChannel.HuggingFaceSnapshot
+    ):
+        _download_and_convert_hf_snapshot(model_config.name, models_dir, hf_token)
+    elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
+        _download_direct(model_config.name, model_config.distribution_path, models_dir)
+    else:
+        raise RuntimeError(
+            f"Unknown distribution channel {model_config.distribution_channel}."
+        )
+
+
+def is_model_downloaded(model: str, models_dir: Path) -> bool:
+    model_config = resolve_model_config(model)
+
+    model_dir = models_dir / model_config.name
+    return os.path.isdir(model_dir)
+
+
+def main(args):
+    download_and_convert(args.model, args.model_directory, args.hf_token)
diff --git a/eval.py b/eval.py
index 24bac880d..f1bbd6e8b 100644
--- a/eval.py
+++ b/eval.py
@@ -21,6 +21,7 @@
 
 from build.model import Transformer
 from cli import add_arguments_for_eval, arg_init
+from download import download_and_convert, is_model_downloaded
 from generate import encode_tokens, model_forward
 from quantize import set_precision
 
@@ -222,6 +223,10 @@ def main(args) -> None:
     """
 
+    # If a named model was provided and not downloaded, download it.
+    if args.model and not is_model_downloaded(args.model, args.model_directory):
+        download_and_convert(args.model, args.model_directory, args.hf_token)
+
     builder_args = BuilderArgs.from_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
     quantize = args.quantize
 
diff --git a/export.py b/export.py
index edcfa61e6..63b990126 100644
--- a/export.py
+++ b/export.py
@@ -16,6 +16,7 @@
     BuilderArgs,
 )
 from cli import add_arguments_for_export, arg_init, check_args
+from download import download_and_convert, is_model_downloaded
 from export_aoti import export_model as export_model_aoti
 from quantize import set_precision
 
@@ -41,6 +42,10 @@ def device_sync(device):
 
 
 def main(args):
+    # If a named model was provided and not downloaded, download it.
+    if args.model and not is_model_downloaded(args.model, args.model_directory):
+        download_and_convert(args.model, args.model_directory, args.hf_token)
+
     builder_args = BuilderArgs.from_args(args)
     quantize = args.quantize
 
diff --git a/generate.py b/generate.py
index 1610e51ce..29dbd954c 100644
--- a/generate.py
+++ b/generate.py
@@ -27,6 +27,7 @@
 )
 from build.model import Transformer
 from cli import add_arguments_for_generate, arg_init, check_args
+from download import download_and_convert, is_model_downloaded
 from quantize import set_precision
 
 logger = logging.getLogger(__name__)
@@ -545,6 +546,10 @@ def callback(x):
 
 
 def main(args):
+    # If a named model was provided and not downloaded, download it.
+    if args.model and not is_model_downloaded(args.model, args.model_directory):
+        download_and_convert(args.model, args.model_directory, args.hf_token)
+
     builder_args = BuilderArgs.from_args(args)
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
diff --git a/scripts/download.py b/scripts/download.py
deleted file mode 100644
index 387c41243..000000000
--- a/scripts/download.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import os
-from typing import Optional
-
-from requests.exceptions import HTTPError
-
-
-def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
-    from huggingface_hub import snapshot_download
-
-    os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
-    try:
-        snapshot_download(
-            repo_id,
-            local_dir=f"checkpoints/{repo_id}",
-            local_dir_use_symlinks=False,
-            token=hf_token,
-            ignore_patterns="*safetensors*",
-        )
-    except HTTPError as e:
-        if e.response.status_code == 401:
-            print(
-                "You need to pass a valid `--hf_token=...` to download private checkpoints."
-            )
-        else:
-            raise e
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.")
-    parser.add_argument(
-        "--repo-id",
-        type=str,
-        default="checkpoints/meta-llama/llama-2-7b-chat-hf",
-        help="Repository ID to download from.",
-    )
-    parser.add_argument(
-        "--hf-token", type=str, default=None, help="HuggingFace API token."
-    )
-
-    args = parser.parse_args()
-    hf_download(args.repo_id, args.hf_token)
diff --git a/torchchat.py b/torchchat.py
index a3e1055bc..b653f9e27 100644
--- a/torchchat.py
+++ b/torchchat.py
@@ -10,6 +10,7 @@
 import sys
 
 from cli import (
+    add_arguments_for_download,
     add_arguments_for_eval,
     add_arguments_for_export,
     add_arguments_for_generate,
@@ -25,9 +26,12 @@
     parser = argparse.ArgumentParser(description="Top-level command")
     subparsers = parser.add_subparsers(
         dest="subcommand",
-        help="Use `generate`, `eval`, `export` or `browser` followed by subcommand specific options.",
+        help="Use `download`, `generate`, `eval`, `export` or `browser` followed by subcommand specific options.",
     )
 
+    parser_download = subparsers.add_parser("download")
+    add_arguments_for_download(parser_download)
+
     parser_generate = subparsers.add_parser("generate")
     add_arguments_for_generate(parser_generate)
 
@@ -46,7 +50,12 @@
         format="%(message)s", level=logging.DEBUG if args.verbose else logging.INFO
     )
 
-    if args.subcommand == "generate":
+    if args.subcommand == "download":
+        check_args(args, "download")
+        from download import main as download_main
+
+        download_main(args)
+    elif args.subcommand == "generate":
         check_args(args, "generate")
         from generate import main as generate_main
 
@@ -85,4 +94,6 @@
         command = ["flask", "--app", "chat_in_browser:create_app(" + formatted_args + ")", "run", "--port", f"{port}"]
         subprocess.run(command)
     else:
-        raise RuntimeError("Must specify valid subcommands: generate, export, eval")
+        raise RuntimeError(
+            "Must specify a valid subcommand: download, generate, export, or eval."
+        )
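
The new download flow is driven by three pieces added in this patch: `resolve_model_config` (config/model_config.py), and `download_and_convert` plus `is_model_downloaded` (download.py). A minimal sketch of the equivalent programmatic flow, assuming it runs from the repo root and uses the `--model-directory` default of `.model-artifacts` added in cli.py:

```
# Sketch only, not part of the patch: drives the new download helpers directly.
from pathlib import Path

from config.model_config import resolve_model_config
from download import download_and_convert, is_model_downloaded

models_dir = Path(".model-artifacts")  # default of --model-directory in cli.py
model = "stories15M"

# Download and convert the named model only if its artifacts are missing,
# mirroring the guard added to generate.py, eval.py, and export.py.
if not is_model_downloaded(model, models_dir):
    download_and_convert(model, models_dir)  # hf_token only needed for gated repos

# BuilderArgs.from_args resolves the checkpoint the same way:
# <model-directory>/<config.name>/<config.checkpoint_file>
config = resolve_model_config(model)
checkpoint = models_dir / config.name / config.checkpoint_file
print(f"Checkpoint ready at {checkpoint}")
```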
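The alias and channel lookup defined by config/data/models.json and config/model_config.py can be sanity-checked without downloading anything, since `resolve_model_config` only reads the JSON. A small sketch, assuming it runs from the repo root:

```
# Sketch only: exercises the lookup added in config/model_config.py.
from config.model_config import ModelDistributionChannel, resolve_model_config

llama = resolve_model_config("llama2")  # alias listed in config/data/models.json
assert llama.name == "meta-llama/Llama-2-7b-chat-hf"
assert llama.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot

stories = resolve_model_config("STORIES15M")  # lookup is case-insensitive
assert stories.checkpoint_file == "stories15M.pt"
assert stories.distribution_channel == ModelDistributionChannel.DirectDownload
```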
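eval.py, export.py, and generate.py each gain the same three-line "download if missing" guard. A possible follow-up, not part of this patch, is to factor that guard into a shared helper in download.py; the helper name below is invented for illustration:

```
# Hypothetical helper (not in the patch) that the three entry points could share.
from pathlib import Path
from typing import Optional

from download import download_and_convert, is_model_downloaded


def ensure_model_downloaded(
    model: Optional[str], models_dir: Path, hf_token: Optional[str] = None
) -> None:
    """Download and convert a named model unless its artifacts already exist."""
    if model and not is_model_downloaded(model, models_dir):
        download_and_convert(model, models_dir, hf_token)
```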