Implement download subcommand, optional positional model name argument #234
Conversation
Force-pushed from 9d595bd to 7a60392
Make sure to run the linter. Setup: `pip install -r requirements-lintrunner.txt`, then `lintrunner -a --all-files`
Force-pushed from 7a60392 to 0dc48b0
Can you add a CI test to exercise the download path?
Force-pushed from 4f56a24 to 29516d5
I'm going to actually defer this because converting some of the larger models takes over an hour. We do need CI coverage, but I might need to experiment with runner size and the choice of model, and I want to land this to unblock others. Tracking via T186104081.
Force-pushed from 88859fd to 3f6eb29
Sounds like you should download a GGUF file that's heavily quantized, and/or stories15M!
GGUF has its own conversion logic. Stories is also a little bit special because it has a unique format, and I'll have to add special logic to handle the download. That being said, it would be nice to have, so I'll probably do that. I want to look more into why it takes upwards of an hour to convert a 7B model on the runner, though. Seems like something is wrong; it shouldn't take that long to shuffle around the weights. Edit:
Force-pushed from 8fbc926 to da23171
Resolve Michael's issue, merge in the changes to support Llama3, and merge.
cli.py (Outdated)
"--checkpoint-dir", | ||
type=Path, | ||
default=None, | ||
help="Model checkpoint directory.", |
not sure what you mean by developer-only option
cli.py (Outdated)
parser.add_argument(
    "--gguf-path",
    type=Path,
    default=None,
    help="GGUF file path.",
)
I see that currently, specifying this with DSO or pte is only a warning; IMO we should hard error because it's easily fixed and a great way to waste a lot of time
Agreed, though we should probably take this as a follow-up.
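For illustration, a minimal sketch of the kind of hard error being suggested, assuming the flag names above; the helper name and where it gets called are assumptions, not part of this PR:

```python
# Hypothetical validation helper (not from this PR): turn the GGUF vs. DSO/PTE
# conflict into a hard error instead of a warning.
def validate_model_source_args(args) -> None:
    if args.gguf_path and (args.dso_path or args.pte_path):
        raise RuntimeError(
            "--gguf-path cannot be combined with --dso-path or --pte-path; "
            "specify exactly one model source."
        )
```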
if model_dir is None:
    model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
why is the default something that's not even in models.json?
@mikekgfb Do we need this default value anymore?
no, but we should put this or a similarly situated chat model into the models.json.
BTW, I really think it's bad to have even the model name default to something (unless we're so excited about llama3 that we make it that.... but that will require users to have obtained a token)
if model in model_aliases:
    model = model_aliases[model]
nit: `model = model_aliases.get(model, model)` is shorter FWIW
print(f"Downloading {url}...") | ||
urllib.request.urlretrieve(url, str(local_path.absolute())) |
would be nice to use progressbar or tqdm to show a progress bar since these downloads can be big; can leave for follow-up
I was thinking about this for checkpoint conversion, as well. My only concern was an additional dependency, but if that's not a worry, I can go ahead and add it.
I think that makes sense. Let's make sure we lazily import it; maybe I don't want to wait if I'm not downloading/converting?
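A rough sketch of how that could look, wrapping the urlretrieve call shown above with a lazily imported tqdm progress bar; the function name and reporthook wiring are illustrative, not this PR's actual implementation:

```python
import urllib.request
from pathlib import Path

def download_with_progress(url: str, local_path: Path) -> None:
    # Lazy import: users who never download or convert don't pay for the dependency.
    from tqdm import tqdm

    print(f"Downloading {url}...")
    with tqdm(unit="B", unit_scale=True, desc=local_path.name) as bar:
        def reporthook(block_count: int, block_size: int, total_size: int) -> None:
            if total_size > 0:
                bar.total = total_size
            # urlretrieve reports cumulative blocks, so update by the delta.
            bar.update(block_count * block_size - bar.n)

        urllib.request.urlretrieve(url, str(local_path.absolute()), reporthook)
```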
Force-pushed from 5ce3954 to 39f81c7
Force-pushed from 39f81c7 to 1c88063
Force-pushed from 1c88063 to 5eab970
Please review and address comments (either why we should do something else, or do as suggested, either works... but we should document why we choose what we choose)
@@ -134,9 +144,12 @@ def from_args(cls, args): # -> TokenizerArgs:

    if args.tokenizer_path:
        tokenizer_path = args.tokenizer_path
    elif args.model:  # Using a named, well-known model
        model_config = resolve_model_config(args.model)
        tokenizer_path = Path(args.model_directory) / model_config.name / "tokenizer.model"
Well-known doesn't mean it's local. How do you know where the tokenizer is?
@@ -546,8 +550,6 @@ def callback(x):


def main(args):
    is_chat = args.subcommand == "chat"

    # If a named model was provided and not downloaded, download it.
    if args.model and not is_model_downloaded(args.model, args.model_directory):
You should intercept this in a central place, like in cli(), because all functions basically need to do the same thing. Otherwise we dupe it in a gazillion places.
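Something like the following sketch would keep the check in one place; `is_model_downloaded` comes from the diff above, while the `download_and_convert` name and the exact hook point in cli.py are assumptions:

```python
def ensure_model_available(args) -> None:
    # Central hook: if a named model was requested but isn't present locally,
    # fetch it once here instead of repeating this check in every subcommand.
    if args.model and not is_model_downloaded(args.model, args.model_directory):
        download_and_convert(args.model, args.model_directory)  # hypothetical helper name
```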
Implement download subcommand, optional positional model name argument (#234)
* Implement download option
* Add support for model aliases
* Support model name as a positional parameter
* Merge GenerateArgs changes
* Run lint
* Revert chat subcommand/arg changes
* Add mistral-7b-instruct alias, fix lints
* Add model config for known models
* Move known model config to config/models.json
* Make model names case-insensitive
* Move known model configuration from build/model.py to config/model_config.py
* Fix lints
* Fixing issues after rebasing
* Update README
Implementing a download subcommand to download and convert a model from HuggingFace. Adds an optional positional argument to other torchchat subcommands to use a downloaded model. The model name can be either a known HF path, such as meta-llama/Llama-2-7b-chat-hf, or an alias, such as llama2. Per-model configuration, including the download channel and model aliases, is under config/models.json.

Example usage:
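A plausible invocation, assuming the torchchat.py entry point and the aliases described above (the exact subcommand flags may differ):

```bash
# Download and convert a known model by alias
python torchchat.py download llama2

# Or pass the model name positionally to another subcommand;
# it will be downloaded on first use.
python torchchat.py generate llama2 --prompt "Hello"
```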
As a follow-up, I intend to refactor the CLI model positional arg handling. It might also be nice to intelligently handle multiple file types with the positional arg, such as a GGUF file.
Test Plan:
CI coverage for the model options is here:
- --gguf-path: torchchat/.github/workflows/pull.yml, line 298 in 5203f0c
- --dso-path: torchchat/.ci/scripts/validate.sh, line 108 in 5203f0c
- --pte-path: torchchat/.ci/scripts/validate.sh, line 172 in 5203f0c
Since there are many ways to load a model, I'm relying on CI to exercise many of the paths.