Implement download subcommand, optional positional model name argument (#234)

* Implement download option

* Add support for model aliases

* Support model name as a positional parameter

* Merge GenerateArgs changes

* Run lint

* Revert chat subcommand/arg changes

* Add mistral-7b-instruct alias, fix lints

* Add model config for known models

* Move known model config to config/models.json

* Make model names case-insensitive

* Move known model configuration from build/model.py to config/model_config.py

* Fix lints

* Fixing issues after rebasing

* Update README
GregoryComer authored and malfet committed Jul 17, 2024
1 parent 4c97582 commit 3e49b4d
Showing 15 changed files with 327 additions and 75 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/pull.yml
@@ -361,6 +361,11 @@ jobs:
cat ./output_eager2
echo "Tests complete."
- name: Test download
run: |
python torchchat.py generate stories15M
test-tinystories-eager:
strategy:
matrix:
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,3 +5,5 @@ __pycache__/

# C extensions
*.so

.model-artifacts/
22 changes: 13 additions & 9 deletions README.md
@@ -32,14 +32,10 @@ python torchchat.py --help
```

### Download a Model and Tokenizer
### Generating Text

```
#download a model
python torchchat.py download llama2
#generate text using the model
python torchchat.py generate stories15M
```
That’s all there is to it!
Read on to learn how to use the full power of torchchat.
@@ -48,7 +44,15 @@ Read on to learn how to use the full power of torchchat.
For the full details on all commands and parameters run `python torchchat.py --help`

### Download
TODO: Fill this out
For supported models, torchchat can download model weights. Most models use HuggingFace as the distribution channel, so you will need to create a HuggingFace
account and install `huggingface-cli`.

To install `huggingface-cli`, run `pip install huggingface-cli`. After installing, create a user access token [as documented here](https://huggingface.co/docs/hub/en/security-tokens). Run `huggingface-cli login`, which will prompt for the newly created token. Once this is done, torchchat will be able to download model artifacts from
HuggingFace.

```
python torchchat.py download llama2
```

### Chat
Designed for interactive and conversational use.
@@ -69,7 +73,7 @@ For more information run `python torchchat.py generate --help`

**Examples**
```
#Generate for Mac with some parameters
python torchchat.py generate llama2 --device=cpu --dtype=fp16
```

### Export
@@ -80,7 +84,7 @@ For more information run `python torchchat.py export --help`
**Examples**

```
#Export Example
python torchchat.py export stories15M --output-pte-path=stories15m.pte
```

### Browser
31 changes: 25 additions & 6 deletions build/builder.py
@@ -14,7 +14,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config

from config.model_config import resolve_model_config
from quantize import name_to_dtype, quantize_model

from sentencepiece import SentencePieceProcessor
@@ -42,7 +42,7 @@ class BuilderArgs:
def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file())
or (self.checkpoint_dir and self.checkpoint_path.is_dir())
or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
or (self.gguf_path and self.gguf_path.is_file())
or (self.dso_path and Path(self.dso_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
@@ -73,7 +73,17 @@ def from_args(cls, args): # -> BuilderArgs:
# Handle disabled checkpoint_dir option
checkpoint_dir = None
if hasattr(args, "checkpoint_dir"):
checkpoint_dir = args.checkpoint_dir
checkpoint_dir = args.checkpoint_dir

checkpoint_path = args.checkpoint_path
if args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)

checkpoint_path = (
Path(args.model_directory)
/ model_config.name
/ model_config.checkpoint_file
)

is_chat_model = False
if args.is_chat_model:
@@ -94,8 +104,8 @@ def from_args(cls, args): # -> BuilderArgs:
is_chat_model = True

return cls(
checkpoint_path=args.checkpoint_path,
checkpoint_dir=checkpoint_dir,
checkpoint_path=checkpoint_path,
params_path=args.params_path,
params_table=args.params_table,
gguf_path=args.gguf_path,
@@ -134,9 +144,12 @@ def from_args(cls, args): # -> TokenizerArgs:

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)
tokenizer_path = Path(args.model_directory) / model_config.name / "tokenizer.model"
elif args.checkpoint_path:
tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
elif args.checkpoint_dir:
elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError("cannot find tokenizer model")
@@ -356,4 +369,10 @@ def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
is_tiktoken = tokenizer_args.is_tiktoken
if use_tiktoken != is_tiktoken:
raise RuntimeError(f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}")


def resolve_model_name(model: str) -> str:
# If the provided model name is an alias, retrieve the full path.
if model in model_aliases:
return model_aliases[model]
else:
return model
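To make the path resolution above concrete, here is a minimal sketch of what the named-model branches compute; the model name is illustrative and `.model-artifacts` is simply the default of `--model-directory`:

```
from pathlib import Path

from config.model_config import resolve_model_config

# Illustrative values: "llama2" is an alias defined in config/data/models.json,
# and ".model-artifacts" is the default of --model-directory.
model = "llama2"
model_directory = Path(".model-artifacts")

model_config = resolve_model_config(model)
checkpoint_path = model_directory / model_config.name / model_config.checkpoint_file
tokenizer_path = model_directory / model_config.name / "tokenizer.model"

print(checkpoint_path)  # .model-artifacts/meta-llama/Llama-2-7b-chat-hf/model.pth
print(tokenizer_path)   # .model-artifacts/meta-llama/Llama-2-7b-chat-hf/tokenizer.model
```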
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import sys
from pathlib import Path
@@ -22,19 +23,20 @@
@torch.inference_mode()
def convert_hf_checkpoint(
*,
checkpoint_dir: Optional[Path] = None,
model_dir: Optional[Path] = None,
model_name: Optional[str] = None,
remove_bin_files: bool = False,
) -> None:
if checkpoint_dir is None:
checkpoint_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
if model_dir is None:
model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
if model_name is None:
model_name = checkpoint_dir.name
model_name = model_dir.name

config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = model_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

@@ -56,7 +58,7 @@ def convert_hf_checkpoint(
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_heads):
dim = config.dim
@@ -97,8 +99,13 @@ def permute(w, n_heads):
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
torch.save(final_result, model_dir / "model.pth")
print("Done.")

if remove_bin_files:
for file in bin_files:
os.remove(file)


if __name__ == "__main__":
@@ -114,6 +121,6 @@ def permute(w, n_heads):

args = parser.parse_args()
convert_hf_checkpoint(
checkpoint_dir=args.checkpoint_dir,
model_dir=args.checkpoint_dir,
model_name=args.model_name,
)
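For reference, a hedged sketch of calling the renamed entry point directly; the import path and the on-disk location are assumptions, since neither appears in this hunk:

```
from pathlib import Path

# The import path is an assumption; only the function's new signature is shown in the diff.
from convert_hf_checkpoint import convert_hf_checkpoint

# Illustrative location matching the default --model-directory layout.
model_dir = Path(".model-artifacts/meta-llama/Llama-2-7b-chat-hf")

convert_hf_checkpoint(
    model_dir=model_dir,
    model_name=None,        # defaults to model_dir.name ("Llama-2-7b-chat-hf")
    remove_bin_files=True,  # delete the original *.bin shards once model.pth is written
)
```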
30 changes: 30 additions & 0 deletions cli.py
@@ -14,6 +14,11 @@
def check_args(args, name: str) -> None:
pass

def add_arguments_for_download(parser):
# Only download-specific options should be here
_add_arguments_common(parser)


def add_arguments_for_generate(parser):
# Only generate-specific options should be here
_add_arguments_common(parser)
@@ -39,6 +44,19 @@ def add_arguments_for_browser(parser):
)

def _add_arguments_common(parser):
# Model specification. TODO Simplify this.
# A model can be specified using a positional model name or HuggingFace
# path. Alternatively, the model can be specified via --gguf-path or via
# an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.

parser.add_argument(
"model",
type=str,
nargs="?",
default=None,
help="Model name for well-known models.",
)

# TODO: Refactor this so that only common options are here
# and subcommand-specific options are inside individual
# add_arguments_for_generate, add_arguments_for_export etc.
@@ -168,6 +186,18 @@ def _add_arguments_common(parser):
default=None,
help="maximum length sequence to evaluate",
)
parser.add_argument(
"--hf-token",
type=str,
default=None,
help="A HuggingFace API token to use when downloading model artifacts",
)
parser.add_argument(
"--model-directory",
type=Path,
default=".model-artifacts",
help="The directory to store downloaded model artifacts",
)


def arg_init(args):
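As a standalone illustration of the new options (the real parser is assembled by the add_arguments_for_* helpers above and presumably wired to subcommands in torchchat.py, which this sketch does not reproduce):

```
import argparse
from pathlib import Path

parser = argparse.ArgumentParser(prog="torchchat.py download")
parser.add_argument("model", type=str, nargs="?", default=None,
                    help="Model name for well-known models.")
parser.add_argument("--hf-token", type=str, default=None,
                    help="A HuggingFace API token to use when downloading model artifacts")
parser.add_argument("--model-directory", type=Path, default=".model-artifacts",
                    help="The directory to store downloaded model artifacts")

args = parser.parse_args(["llama2"])
print(args.model, args.hf_token, args.model_directory)
# -> llama2 None .model-artifacts
```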
Empty file added config/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions config/data/models.json
@@ -0,0 +1,28 @@
{
"meta-llama/Llama-2-7b-chat-hf": {
"aliases": ["llama2", "llama2-7b"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-2-7b-chat-hf"
},
"mistralai/Mistral-7B-Instruct-v0.2": {
"aliases": ["mistral-7b-instruct"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "mistralai/Mistral-7B-Instruct-v0.2"
},
"stories15M": {
"distribution_channel": "DirectDownload",
"distribution_path": [
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt",
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
],
"checkpoint_file": "stories15M.pt"
},
"stories110M": {
"distribution_channel": "DirectDownload",
"distribution_path": [
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt",
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
],
"checkpoint_file": "stories110M.pt"
}
}
86 changes: 86 additions & 0 deletions config/model_config.py
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, Sequence, Union

"""
Known Model Configs:
For models that are known to work with torchchat, we provide a config under
config/data/models.json to support automatically downloading the model and
converting to the expected format for use with torchchat.
There are two supported distribution channels:
1) HuggingFaceSnapshot: Download a model from HuggingFace.
2) DirectDownload: Download a list of model artifacts from URLs. No conversion
is done.
"""


# Specifies the distribution channel to download model artifacts from. Enum
# variants are specified as strings to simplify JSON (de)serialization.
class ModelDistributionChannel(str, Enum):
# Download a full model snapshot from HuggingFace, such as
# meta-llama/Llama-2-7b-chat-hf and convert to torchchat format.
HuggingFaceSnapshot = "HuggingFaceSnapshot"

# Download one or more files over HTTP(S).
DirectDownload = "DirectDownload"


@dataclass
class ModelConfig:
name: str = field(default="")
aliases: Sequence[str] = field(default_factory=list)
distribution_path: Union[str, Sequence[str]] = field(default="")
distribution_channel: ModelDistributionChannel = field(
default=ModelDistributionChannel.HuggingFaceSnapshot
)
checkpoint_file: str = field(default="model.pth")


# Keys are stored in lowercase.
model_aliases: Dict[str, str] = None
model_configs: Dict[str, ModelConfig] = None


def resolve_model_config(model: str) -> ModelConfig:
global model_aliases
global model_configs

model = model.lower()

# Lazy load model config from JSON.
if not model_configs:
model_aliases = {}
model_configs = {}

with open(
Path(__file__).parent.parent / "config" / "data" / "models.json", "r"
) as f:
model_config_dict = json.load(f)

for key, value in model_config_dict.items():
config = ModelConfig(**value)
config.name = key

key = key.lower()
model_configs[key] = config

for alias in config.aliases:
model_aliases[alias.lower()] = key

if model in model_aliases:
model = model_aliases[model]

if model not in model_configs:
raise ValueError(f"Unknown model '{model}'.")

return model_configs[model]
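A small usage sketch of the helper above; the lookups mirror entries in config/data/models.json, and names are normalized through the alias and lowercasing logic shown:

```
from config.model_config import ModelDistributionChannel, resolve_model_config

# Alias lookup is case-insensitive.
cfg = resolve_model_config("Mistral-7B-Instruct")
assert cfg.name == "mistralai/Mistral-7B-Instruct-v0.2"
assert cfg.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot
assert cfg.checkpoint_file == "model.pth"  # default for HuggingFace snapshots

# Direct-download models list their artifact URLs and checkpoint file explicitly.
cfg = resolve_model_config("stories15M")
assert cfg.distribution_channel == ModelDistributionChannel.DirectDownload
assert cfg.checkpoint_file == "stories15M.pt"

# Unknown names raise a ValueError.
try:
    resolve_model_config("not-a-model")
except ValueError as e:
    print(e)  # Unknown model 'not-a-model'.
```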