Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ cache
__pycache__
*.pyc
.pytest_cache
*.tgz
*.tgz
.huggingface
7 changes: 7 additions & 0 deletions api/core/v1alpha1/model_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,18 @@ type ModelHub struct {
// the whole repo which includes all kinds of quantized models.
// TODO: this is only supported with Huggingface, add support for ModelScope
// in the near future.
// Note: once filename is set, allowPatterns and ignorePatterns should be left unset.
Filename *string `json:"filename,omitempty"`
// Revision refers to a Git revision id which can be a branch name, a tag, or a commit hash.
// +kubebuilder:default=main
// +optional
Revision *string `json:"revision,omitempty"`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's update the comment of Filename as well,

	// Filename refers to a specified model file rather than the whole repo.
	// This is helpful to download a specified GGUF model rather than downloading
	// the whole repo which includes all kinds of quantized models.
	// TODO: this is only supported with Huggingface, add support for ModelScope
	// in the near future.
        // Note: once filename is set, allowPatterns and ignorePatterns should be left unset.

// AllowPatterns refers to files matched with at least one pattern will be downloaded.
// +optional
AllowPatterns []string `json:"allowPatterns,omitempty"`
// IgnorePatterns refers to files matched with any of the patterns will not be downloaded.
// +optional
IgnorePatterns []string `json:"ignorePatterns,omitempty"`
}

// URIProtocol represents the protocol of the URI.
Expand Down
10 changes: 10 additions & 0 deletions api/core/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 26 additions & 4 deletions client-go/applyconfiguration/core/v1alpha1/modelhub.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions config/crd/bases/llmaz.io_openmodels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,25 @@ spec:
description: ModelHub represents the model registry for model
downloads.
properties:
allowPatterns:
description: AllowPatterns refers to only files matching at
least one pattern are downloaded.
items:
type: string
type: array
filename:
description: |-
Filename refers to a specified model file rather than the whole repo.
This is helpful to download a specified GGUF model rather than downloading
the whole repo which includes all kinds of quantized models.
in the near future.
type: string
ignorePatterns:
description: IgnorePatterns refers to files matching any of
the patterns are not downloaded.
items:
type: string
type: array
modelID:
description: |-
ModelID refers to the model identifier on model hub,
Expand Down
33 changes: 20 additions & 13 deletions llmaz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,39 @@
import os
from datetime import datetime

from llmaz.model_loader.constant import *

from llmaz.model_loader.objstore.objstore import model_download
from llmaz.model_loader.model_hub.hub_factory import HubFactory
from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE
from llmaz.model_loader.model_hub.huggingface import HUB_HUGGING_FACE
from llmaz.util.logger import Logger


if __name__ == "__main__":
model_source_type = os.getenv("MODEL_SOURCE_TYPE")
model_source_type = os.getenv(ENV_HUB_MODEL_SOURCE_TYPE)
start_time = datetime.now()

if model_source_type == "modelhub":
hub_name = os.getenv("MODEL_HUB_NAME", HUGGING_FACE)
revision = os.getenv("REVISION")
model_id = os.getenv("MODEL_ID")
model_file_name = os.getenv("MODEL_FILENAME")
hub_name = os.getenv(ENV_HUB_MODEL_HUB_NAME, HUB_HUGGING_FACE)
revision = os.getenv(ENV_HUB_REVISION)
model_id = os.getenv(ENV_HUB_MODEL_ID)
model_file_name = os.getenv(ENV_HUB_MODEL_FILENAME)
model_allow_patterns = os.getenv(ENV_HUB_MODEL_ALLOW_PATTERNS)
model_ignore_patterns = os.getenv(ENV_HUB_MODEL_IGNORE_PATTERNS)

if not model_id:
raise EnvironmentError(f"Environment variable '{model_id}' not found.")

hub = HubFactory.new(hub_name)
hub.load_model(model_id, model_file_name, revision)
model_allow_patterns_list, model_ignore_patterns_list = [], []
if model_allow_patterns:
model_allow_patterns_list = model_allow_patterns.split(',')
if model_ignore_patterns:
model_ignore_patterns_list = model_ignore_patterns.split(',')
hub.load_model(model_id, model_file_name, revision, model_allow_patterns_list, model_ignore_patterns_list)
elif model_source_type == "objstore":
provider = os.getenv("PROVIDER")
endpoint = os.getenv("ENDPOINT")
bucket = os.getenv("BUCKET")
src = os.getenv("MODEL_PATH")
provider = os.getenv(ENV_OBJ_PROVIDER)
endpoint = os.getenv(ENV_OBJ_ENDPOINT)
bucket = os.getenv(ENV_OBJ_BUCKET)
src = os.getenv(ENV_OBJ_MODEL_PATH)

model_download(provider=provider, endpoint=endpoint, bucket=bucket, src=src)
else:
Expand Down
16 changes: 16 additions & 0 deletions llmaz/model_loader/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
MODEL_LOCAL_DIR = "/workspace/models/"
HUB_HUGGING_FACE = "Huggingface"
HUB_MODEL_SCOPE = "ModelScope"

ENV_HUB_MODEL_SOURCE_TYPE = "MODEL_SOURCE_TYPE"
ENV_HUB_MODEL_HUB_NAME = "MODEL_HUB_NAME"
ENV_HUB_REVISION = "REVISION"
ENV_HUB_MODEL_ID = "MODEL_ID"
ENV_HUB_MODEL_FILENAME = "MODEL_FILENAME"
ENV_HUB_MODEL_ALLOW_PATTERNS = "MODEL_ALLOW_PATTERNS"
ENV_HUB_MODEL_IGNORE_PATTERNS = "MODEL_IGNORE_PATTERNS"

ENV_OBJ_PROVIDER = "PROVIDER"
ENV_OBJ_ENDPOINT = "ENDPOINT"
ENV_OBJ_BUCKET = "BUCKET"
ENV_OBJ_MODEL_PATH = "MODEL_PATH"
1 change: 0 additions & 1 deletion llmaz/model_loader/defaults.py

This file was deleted.

12 changes: 6 additions & 6 deletions llmaz/model_loader/model_hub/hub_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from llmaz.model_loader.constant import HUB_HUGGING_FACE, HUB_MODEL_SCOPE
from llmaz.model_loader.model_hub.model_hub import ModelHub
from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE, Huggingface
from llmaz.model_loader.model_hub.modelscope import MODEL_SCOPE, ModelScope

from llmaz.model_loader.model_hub.huggingface import Huggingface
from llmaz.model_loader.model_hub.modelscope import ModelScope

SUPPORT_MODEL_HUBS = {
HUGGING_FACE: Huggingface,
MODEL_SCOPE: ModelScope,
HUB_HUGGING_FACE: Huggingface,
HUB_MODEL_SCOPE: ModelScope,
}


class HubFactory:

@classmethod
def new(cls, hub_name: str) -> ModelHub:
if hub_name not in SUPPORT_MODEL_HUBS.keys():
Expand Down
64 changes: 23 additions & 41 deletions llmaz/model_loader/model_hub/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,69 +17,51 @@
import concurrent.futures
import os

from huggingface_hub import hf_hub_download, list_repo_files
from huggingface_hub import snapshot_download

from llmaz.model_loader.defaults import MODEL_LOCAL_DIR
from llmaz.model_loader.constant import MODEL_LOCAL_DIR, HUB_HUGGING_FACE
from llmaz.model_loader.model_hub.model_hub import (
HUGGING_FACE,
MAX_WORKERS,
ModelHub,
)
from llmaz.util.logger import Logger
from llmaz.model_loader.model_hub.util import get_folder_total_size

from typing import Optional
from typing import Optional, List


class Huggingface(ModelHub):
@classmethod
def name(cls) -> str:
return HUGGING_FACE
return HUB_HUGGING_FACE

@classmethod
def load_model(
cls, model_id: str, filename: Optional[str], revision: Optional[str]
cls,
model_id: str,
filename: Optional[str],
revision: Optional[str],
allow_patterns: Optional[List[str]],
ignore_patterns: Optional[List[str]],
) -> None:
Logger.info(
f"Start to download, model_id: {model_id}, filename: {filename}, revision: {revision}"
)

if filename:
hf_hub_download(
repo_id=model_id,
filename=filename,
local_dir=MODEL_LOCAL_DIR,
revision=revision,
)
file_size = os.path.getsize(MODEL_LOCAL_DIR + filename) / (1024**3)
Logger.info(
f"The total size of {MODEL_LOCAL_DIR + filename} is {file_size: .2f} GB"
)
return

local_dir = os.path.join(
MODEL_LOCAL_DIR, f"models--{model_id.replace('/','--')}"
MODEL_LOCAL_DIR, f"models--{model_id.replace('/', '--')}"
)

# # TODO: Should we verify the download is finished?
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
futures = []
for file in list_repo_files(repo_id=model_id):
# TODO: support version management, right now we didn't distinguish with them.
futures.append(
executor.submit(
hf_hub_download,
repo_id=model_id,
filename=file,
local_dir=local_dir,
revision=revision,
).add_done_callback(handle_completion)
)
if filename:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's be more simple here, if filename, the allow_patterns will be the filename. Actually, In OP is not that accurate, we may have pattern like *.json.

And we should add a validation in the webhook as Once filename is set, the both patterns should be nil.

allow_patterns.append(filename)
local_dir = MODEL_LOCAL_DIR

snapshot_download(
repo_id=model_id,
revision=revision,
local_dir=local_dir,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)

total_size = get_folder_total_size(local_dir)
Logger.info(f"The total size of {local_dir} is {total_size: .2f} GB")


def handle_completion(future):
filename = future.result()
Logger.info(f"Download completed for {filename}")
Logger.info(f"The total size of {local_dir} is {total_size: .2f} GB")
13 changes: 7 additions & 6 deletions llmaz/model_loader/model_hub/model_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
"""

from abc import ABC, abstractmethod
from typing import Optional

MAX_WORKERS = 4
HUGGING_FACE = "Huggingface"
MODEL_SCOPE = "ModelScope"
from typing import Optional, List


class ModelHub(ABC):
Expand All @@ -31,6 +27,11 @@ def name(cls) -> str:
@classmethod
@abstractmethod
def load_model(
cls, model_id: str, filename: Optional[str], revision: Optional[str]
cls,
model_id: str,
filename: Optional[str],
revision: Optional[str],
allow_patterns: Optional[List[str]],
ignore_patterns: Optional[List[str]],
) -> None:
pass
Loading
Loading