Skip to content

feat(cli): allow users to download models from Kaggle #2002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions docs/source/tune_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,25 @@ to download files using the CLI.
Download a model
----------------

The ``tune download <path>`` command downloads any model from the Hugging Face Hub.
The ``tune download <path>`` command downloads any model from the Hugging Face Hub or the Kaggle Model Hub.

.. list-table::
:widths: 30 60

* - \--output-dir
- Directory in which to save the model.
- Directory in which to save the model. Note: this option is not yet supported when `--source` is set to `kaggle`.
* - \--output-dir-use-symlinks
- To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be either duplicated or symlinked to the local directory depending on its size. If set to `True`, a symlink will be created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if it already exists) or downloaded from the Hub and not cached.
* - \--hf-token
- Hugging Face API token. Needed for gated models like Llama.
* - \--ignore-patterns
- If provided, files matching any of the patterns are not downloaded. Defaults to ignoring safetensors files to avoid downloading duplicate weights.
* - \--source {huggingface,kaggle}
- If provided, downloads model weights from the provided <path> on the designated source hub.
* - \--kaggle-username
- Kaggle username for authentication. Needed for private models or gated models like Llama2.
* - \--kaggle-api-key
- Kaggle API key. Needed for private models or gated models like Llama2. You can find your API key at https://kaggle.com/settings.

.. code-block:: bash

Expand All @@ -62,6 +68,13 @@ The ``tune download <path>`` command downloads any model from the Hugging Face H
./model/model-00001-of-00002.bin
...

.. code-block:: bash

$ tune download metaresearch/llama-3.2/pytorch/1b --source kaggle
Successfully downloaded model repo and wrote to the following locations:
/tmp/llama-3.2/pytorch/1b/tokenizer.model
/tmp/llama-3.2/pytorch/1b/params.json
/tmp/llama-3.2/pytorch/1b/consolidated.00.pth

**Download a gated model**

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ dependencies = [
"huggingface_hub",
"safetensors",

# Kaggle Integrations
"kagglehub",

# Tokenization
"sentencepiece",
"tiktoken",
Expand Down
250 changes: 250 additions & 0 deletions tests/torchtune/_cli/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import runpy
import sys
from unittest import mock

import pytest
from tests.common import TUNE_PATH
Expand Down Expand Up @@ -106,3 +108,251 @@ def test_gated_repo_error_with_token(self, capsys, monkeypatch, snapshot_downloa
"Please ensure you have access to the repository and have provided the proper Hugging Face API token"
not in out_err.err
)

# only valid --source parameters supported (expect prompt for supported values)
def test_source_parameter(self, capsys, monkeypatch):
    """Check that an unsupported ``--source`` value is rejected by the CLI.

    The run is expected to raise ``SystemExit`` matching "2" and to report
    the invalid choice on stderr.
    """
    model = "metaresearch/llama-3.2/pytorch/1b"
    command = f"tune download {model} --source invalid"
    monkeypatch.setattr(sys, "argv", command.split())

    with pytest.raises(SystemExit, match="2"):
        runpy.run_path(TUNE_PATH, run_name="__main__")

    captured = capsys.readouterr()
    assert "argument --source: invalid choice: 'invalid'" in captured.err

def test_download_from_kaggle(self, capsys, monkeypatch, mocker, tmpdir):
    """Happy path: a Kaggle download with CLI credentials reports success on stdout."""
    # Stub kagglehub.model_download so the test gets around key storage.
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)

    model = "metaresearch/llama-3.2/pytorch/1b"
    command = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key"
    monkeypatch.setattr(sys, "argv", command.split())

    runpy.run_path(TUNE_PATH, run_name="__main__")

    captured = capsys.readouterr()
    assert "Successfully downloaded model repo" in captured.out

def test_download_from_kaggle_warn_when_output_dir_provided(
    self, capsys, monkeypatch, mocker, tmpdir
):
    """Expect a ``UserWarning`` when ``--output-dir`` is combined with a Kaggle source.

    The download itself is still expected to report success on stdout.
    """
    # Stub kagglehub.model_download so the test gets around key storage.
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)

    model = "metaresearch/llama-3.2/pytorch/1b"
    command = f"tune download {model} --source kaggle --output-dir /requested/model/path"
    monkeypatch.setattr(sys, "argv", command.split())

    expected_warning = "--output-dir flag is not supported for Kaggle model downloads"
    with pytest.warns(UserWarning, match=expected_warning):
        runpy.run_path(TUNE_PATH, run_name="__main__")

    captured = capsys.readouterr()
    assert "Successfully downloaded model repo" in captured.out

def test_download_from_kaggle_warn_when_ignore_patterns_provided(
    self, capsys, monkeypatch, mocker, tmpdir
):
    """Expect a ``UserWarning`` when ``--ignore-patterns`` is combined with a Kaggle source.

    The download itself is still expected to report success on stdout.
    """
    model = "metaresearch/llama-3.2/pytorch/1b"
    # Build argv as an explicit list. str.split() performs no shell-style quote
    # removal, so quoting the glob inside a split command string would hand the
    # parser a pattern containing literal double-quote characters, which is not
    # what a real shell invocation delivers.
    testargs = [
        "tune",
        "download",
        model,
        "--source",
        "kaggle",
        "--ignore-patterns",
        "*.glob-pattern",
    ]
    monkeypatch.setattr(sys, "argv", testargs)
    # mock out kagglehub.model_download to get around key storage
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)

    with pytest.warns(
        UserWarning,
        match="--ignore-patterns flag is not supported for Kaggle model downloads",
    ):
        runpy.run_path(TUNE_PATH, run_name="__main__")

    output = capsys.readouterr().out
    assert "Successfully downloaded model repo" in output

# tests when --kaggle-username and --kaggle-api-key are provided as CLI args
def test_download_from_kaggle_when_credentials_provided(
    self, capsys, monkeypatch, mocker, tmpdir
):
    """Both credentials given on the command line are forwarded to ``set_kaggle_credentials``.

    Also checks that stdout carries the tip about avoiding the credential
    flags and the link to the kagglehub authentication docs.
    """
    cli_username = "username"
    cli_api_key = "api_key"
    model = "metaresearch/llama-3.2/pytorch/1b"

    # Stub kagglehub.model_download so the test gets around key storage.
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)
    credentials_mock = mocker.patch("torchtune._cli.download.set_kaggle_credentials")

    command = (
        f"tune download {model} "
        f"--source kaggle "
        f"--kaggle-username {cli_username} "
        f"--kaggle-api-key {cli_api_key}"
    )
    monkeypatch.setattr(sys, "argv", command.split())

    runpy.run_path(TUNE_PATH, run_name="__main__")

    credentials_mock.assert_called_once_with(cli_username, cli_api_key)

    stdout = capsys.readouterr().out
    expected_snippets = (
        "TIP: you can avoid passing in the --kaggle-username and --kaggle-api-key",
        "For more details, see https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate",
    )
    for snippet in expected_snippets:
        assert snippet in stdout

# passes partial credentials with just --kaggle-username (expect fallback to environment variables)
@mock.patch.dict(os.environ, {"KAGGLE_KEY": "env_api_key"})
def test_download_from_kaggle_when_partial_credentials_provided(
    self, capsys, monkeypatch, mocker, tmpdir
):
    """Only ``--kaggle-username`` is passed; the API key comes from ``KAGGLE_KEY``.

    ``set_kaggle_credentials`` is expected to receive the CLI username
    together with the API key patched into the environment, and stdout is
    expected to carry the credential tip and the kagglehub docs link.
    """
    cli_username = "username"
    env_api_key = "env_api_key"
    model = "metaresearch/llama-3.2/pytorch/1b"

    # Stub kagglehub.model_download so the test gets around key storage.
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)
    credentials_mock = mocker.patch("torchtune._cli.download.set_kaggle_credentials")

    command = f"tune download {model} --source kaggle --kaggle-username {cli_username}"
    monkeypatch.setattr(sys, "argv", command.split())

    runpy.run_path(TUNE_PATH, run_name="__main__")

    credentials_mock.assert_called_once_with(cli_username, env_api_key)

    stdout = capsys.readouterr().out
    expected_snippets = (
        "TIP: you can avoid passing in the --kaggle-username and --kaggle-api-key",
        "For more details, see https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate",
    )
    for snippet in expected_snippets:
        assert snippet in stdout

def test_download_from_kaggle_when_set_kaggle_credentials_throws(
    self, monkeypatch, mocker, tmpdir
):
    """A failure inside ``set_kaggle_credentials`` is expected to surface as a ``UserWarning``."""
    # Stub kagglehub.model_download so the test gets around key storage.
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)
    # Make credential setup blow up with a generic exception.
    credentials_mock = mocker.patch("torchtune._cli.download.set_kaggle_credentials")
    credentials_mock.side_effect = Exception("some error")

    model = "metaresearch/llama-3.2/pytorch/1b"
    command = f"tune download {model} --source kaggle --kaggle-username u --kaggle-api-key k"
    monkeypatch.setattr(sys, "argv", command.split())

    expected_warning = "Failed to set Kaggle credentials with error"
    with pytest.warns(UserWarning, match=expected_warning):
        runpy.run_path(TUNE_PATH, run_name="__main__")

# KaggleApiHTTPError::Unauthorized with rejected --kaggle-username and --kaggle-api-key (expect prompt to check credentials)
def test_download_from_kaggle_unauthorized_credentials(
    self, capsys, monkeypatch, mocker
):
    """An Unauthorized error from ``model_download`` produces a credentials hint on stderr.

    ``model_download`` is stubbed to raise ``KaggleApiHTTPError`` whose
    response carries ``HTTPStatus.UNAUTHORIZED``; the CLI is expected to
    raise ``SystemExit`` matching "2".
    """
    from http import HTTPStatus

    from kagglehub.exceptions import KaggleApiHTTPError

    unauthorized = KaggleApiHTTPError(
        "Unauthorized",
        response=mocker.MagicMock(status_code=HTTPStatus.UNAUTHORIZED),
    )
    mocker.patch("torchtune._cli.download.model_download", side_effect=unauthorized)

    model = "metaresearch/llama-3.2/pytorch/1b"
    command = f"tune download {model} --source kaggle --kaggle-username username --kaggle-api-key key"
    monkeypatch.setattr(sys, "argv", command.split())

    with pytest.raises(SystemExit, match="2"):
        runpy.run_path(TUNE_PATH, run_name="__main__")

    stderr = capsys.readouterr().err
    assert (
        "Please ensure you have access to the model and have provided the proper Kaggle credentials"
        in stderr
    )
    assert "You can also set these to environment variables" in stderr

# KaggleApiHTTPError::NotFound
def test_download_from_kaggle_model_not_found(self, capsys, monkeypatch, mocker):
    """A NotFound error from ``model_download`` is reported on stderr with the model handle.

    ``model_download`` is stubbed to raise ``KaggleApiHTTPError`` whose
    response carries ``HTTPStatus.NOT_FOUND``; the CLI is expected to raise
    ``SystemExit`` matching "2".
    """
    from http import HTTPStatus

    from kagglehub.exceptions import KaggleApiHTTPError

    not_found = KaggleApiHTTPError(
        "NotFound", response=mocker.MagicMock(status_code=HTTPStatus.NOT_FOUND)
    )
    mocker.patch("torchtune._cli.download.model_download", side_effect=not_found)

    model = "mockorganizations/mockmodel/pytorch/mockvariation"
    command = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key"
    monkeypatch.setattr(sys, "argv", command.split())

    with pytest.raises(SystemExit, match="2"):
        runpy.run_path(TUNE_PATH, run_name="__main__")

    stderr = capsys.readouterr().err
    assert f"'{model}' not found on the Kaggle Model Hub." in stderr

# KaggleApiHTTPError::InternalServerError
def test_download_from_kaggle_api_error(self, capsys, monkeypatch, mocker):
    """An internal-server error from ``model_download`` yields a generic failure on stderr.

    ``model_download`` is stubbed to raise ``KaggleApiHTTPError`` whose
    response carries ``HTTPStatus.INTERNAL_SERVER_ERROR``; the CLI is
    expected to raise ``SystemExit`` matching "2".
    """
    from http import HTTPStatus

    from kagglehub.exceptions import KaggleApiHTTPError

    server_error = KaggleApiHTTPError(
        "InternalError",
        response=mocker.MagicMock(status_code=HTTPStatus.INTERNAL_SERVER_ERROR),
    )
    mocker.patch("torchtune._cli.download.model_download", side_effect=server_error)

    model = "metaresearch/llama-3.2/pytorch/1b"
    command = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key"
    monkeypatch.setattr(sys, "argv", command.split())

    with pytest.raises(SystemExit, match="2"):
        runpy.run_path(TUNE_PATH, run_name="__main__")

    stderr = capsys.readouterr().err
    assert "Failed to download" in stderr

def test_download_from_kaggle_warn_on_nonmeta_pytorch_models(
    self, monkeypatch, mocker, tmpdir
):
    """A PyTorch model handle owned by ``kaggle`` is expected to trigger a compatibility warning."""
    # Stub model_download so the download step is guaranteed to succeed.
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)

    model = "kaggle/kaggle-model-name/pytorch/1b"
    command = f"tune download {model} --source kaggle"
    monkeypatch.setattr(sys, "argv", command.split())

    expected_warning = "may not be compatible with torchtune"
    with pytest.warns(UserWarning, match=expected_warning):
        runpy.run_path(TUNE_PATH, run_name="__main__")

def test_download_from_kaggle_warn_on_nonpytorch_nontransformers_model(
    self, monkeypatch, mocker, tmpdir
):
    """A handle whose framework segment is a made-up name is expected to trigger a compatibility warning."""
    # Stub model_download so the download step is guaranteed to succeed.
    mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)

    model = "metaresearch/some-model/some-madeup-framework/1b"
    command = f"tune download {model} --source kaggle"
    monkeypatch.setattr(sys, "argv", command.split())

    expected_warning = "may not be compatible with torchtune"
    with pytest.warns(UserWarning, match=expected_warning):
        runpy.run_path(TUNE_PATH, run_name="__main__")
Loading
Loading