-
Notifications
You must be signed in to change notification settings - Fork 221
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement download subcommand, optional positional model name argument (
#234) * Implement download option * Add support for model aliases * Support model name as a positional parameter * Merge GenerateArgs changes * Run lint * Revert chat subcommand/arg changes * Add mistral-7b-instruct alias, fix lints * Add model config for known models * Move known model config to config/models.json * Make model names case-insensitive * Move known model configuration from build/model.py to config/model_config.py * Fix lints * Fixing issues after rebasing * Update README
- Loading branch information
1 parent
4c97582
commit 3e49b4d
Showing 15 changed files with 327 additions and 75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,5 @@ __pycache__/ | |
|
||
# C extensions | ||
*.so | ||
|
||
.model-artifacts/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
{ | ||
"meta-llama/Llama-2-7b-chat-hf": { | ||
"aliases": ["llama2", "llama2-7b"], | ||
"distribution_channel": "HuggingFaceSnapshot", | ||
"distribution_path": "meta-llama/Llama-2-7b-chat-hf" | ||
}, | ||
"mistralai/Mistral-7B-Instruct-v0.2": { | ||
"aliases": ["mistral-7b-instruct"], | ||
"distribution_channel": "HuggingFaceSnapshot", | ||
"distribution_path": "mistralai/Mistral-7B-Instruct-v0.2" | ||
}, | ||
"stories15M": { | ||
"distribution_channel": "DirectDownload", | ||
"distribution_path": [ | ||
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt", | ||
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.model" | ||
], | ||
"checkpoint_file": "stories15M.pt" | ||
}, | ||
"stories110M": { | ||
"distribution_channel": "DirectDownload", | ||
"distribution_path": [ | ||
"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt", | ||
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.model" | ||
], | ||
"checkpoint_file": "stories110M.pt" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
|
||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import json
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
|
||
""" | ||
Known Model Configs: | ||
For models that are known to work with torchchat, we provide a config under | ||
config/data/models.json to support automatically downloading the model and | ||
converting to the expected format for use with torchchat. | ||
There are two supported distribution channels: | ||
1) HuggingFaceSnapshot: Download a model from HuggingFace. | ||
2) DirectDownload: Download a list of model artifacts from URLs. No conversion | ||
is done. | ||
""" | ||
|
||
|
||
# Distribution channel from which model artifacts are fetched. Mixing in str
# keeps variants serializing to/from JSON as plain strings.
class ModelDistributionChannel(str, Enum):
    """Where and how a known model's artifacts are downloaded."""

    # Full snapshot of a HuggingFace repo (e.g. meta-llama/Llama-2-7b-chat-hf),
    # converted to torchchat format after download.
    HuggingFaceSnapshot = "HuggingFaceSnapshot"

    # One or more files fetched directly over HTTP(S); no conversion applied.
    DirectDownload = "DirectDownload"
|
||
|
||
@dataclass
class ModelConfig:
    """Download metadata for one known model (see config/data/models.json)."""

    # Canonical model name, filled in from the JSON key by resolve_model_config.
    name: str = ""
    # Alternate names accepted when looking the model up (case-insensitive).
    aliases: Sequence[str] = field(default_factory=list)
    # HuggingFace repo id, or a list of URLs for DirectDownload models.
    distribution_path: Union[str, Sequence[str]] = ""
    # How the artifacts are obtained; HuggingFace snapshot by default.
    distribution_channel: ModelDistributionChannel = (
        ModelDistributionChannel.HuggingFaceSnapshot
    )
    # Checkpoint filename expected after download/conversion.
    checkpoint_file: str = "model.pth"
|
||
|
||
# Lazily-populated registries, filled on first call to resolve_model_config.
# Keys are stored in lowercase so lookups are case-insensitive. Annotated as
# Optional because None (not an empty dict) marks the "not yet loaded" state.
model_aliases: Optional[Dict[str, str]] = None
model_configs: Optional[Dict[str, "ModelConfig"]] = None
|
||
|
||
def resolve_model_config(model: str) -> ModelConfig:
    """Return the ModelConfig for a model name or alias (case-insensitive).

    On first use, loads config/data/models.json into the module-level
    registries; subsequent calls hit the cached dictionaries.

    Raises:
        ValueError: if the name matches no known model or alias.
    """
    global model_aliases
    global model_configs

    key = model.lower()

    # Populate the registries from JSON on first use.
    if not model_configs:
        model_aliases = {}
        model_configs = {}

        config_path = Path(__file__).parent.parent / "config" / "data" / "models.json"
        with open(config_path, "r") as f:
            raw_configs = json.load(f)

        for name, entry in raw_configs.items():
            cfg = ModelConfig(**entry)
            cfg.name = name

            canonical = name.lower()
            model_configs[canonical] = cfg

            # Map every alias (lowercased) onto the canonical key.
            for alias in cfg.aliases:
                model_aliases[alias.lower()] = canonical

    # Resolve aliases to their canonical name; unknown keys pass through.
    key = model_aliases.get(key, key)

    if key not in model_configs:
        raise ValueError(f"Unknown model '{key}'.")

    return model_configs[key]
Oops, something went wrong.