diff --git a/mii/legacy/utils.py b/mii/legacy/utils.py index 062a9524..f1a7cb59 100644 --- a/mii/legacy/utils.py +++ b/mii/legacy/utils.py @@ -42,7 +42,7 @@ def _get_hf_models_by_type(model_type=None, task=None): if (model_data["cache_time"] + cache_expiration_seconds) < current_time: api = HfApi() model_data["model_list"] = [ - SimpleNamespace(modelId=m.modelId, + SimpleNamespace(id=m.id, pipeline_tag=m.pipeline_tag, tags=m.tags) for m in api.list_models() ] @@ -60,7 +60,7 @@ def _get_hf_models_by_type(model_type=None, task=None): models = [m for m in models if m.pipeline_tag == task] # Extract model IDs - model_ids = [m.modelId for m in models] + model_ids = [m.id for m in models] if task == TaskType.TEXT_GENERATION: # TODO: this is a temp solution to get around some HF models not having the correct tags diff --git a/mii/utils.py b/mii/utils.py index 3958b09b..cf030cfe 100644 --- a/mii/utils.py +++ b/mii/utils.py @@ -31,7 +31,7 @@ @dataclass class ModelInfo: - modelId: str + id: str pipeline_tag: str tags: List[str] @@ -53,7 +53,7 @@ def _hf_model_list() -> List[ModelInfo]: if (model_data["cache_time"] + cache_expiration_seconds) < current_time: api = HfApi() model_data["model_list"] = [ - ModelInfo(modelId=m.modelId, + ModelInfo(id=m.id, pipeline_tag=m.pipeline_tag, tags=m.tags) for m in api.list_models() ] @@ -70,7 +70,7 @@ def get_default_task(model_name_or_path: str) -> str: model_name = get_model_name(model_name_or_path) models = _hf_model_list() for m in models: - if m.modelId == model_name: + if m.id == model_name: task = m.pipeline_tag logger.info(f"Detected default task as '{task}' for model '{model_name}'") return task diff --git a/scripts/model_download.py b/scripts/model_download.py index c4a8490d..6c57b911 100755 --- a/scripts/model_download.py +++ b/scripts/model_download.py @@ -27,7 +27,7 @@ def __init__(self, model_str): def hf_model(model_str): api = HfApi() - models = [m.modelId for m in api.list_models()] + models = [m.id for m in api.list_models()] if model_str in models: return model_str else: