Skip to content

Commit

Permalink
Update MII to switch from modelid to id (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Jul 25, 2024
1 parent 1cf5ebe commit ff5e2fc
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions mii/legacy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _get_hf_models_by_type(model_type=None, task=None):
if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
api = HfApi()
model_data["model_list"] = [
SimpleNamespace(modelId=m.modelId,
SimpleNamespace(id=m.id,
pipeline_tag=m.pipeline_tag,
tags=m.tags) for m in api.list_models()
]
Expand All @@ -60,7 +60,7 @@ def _get_hf_models_by_type(model_type=None, task=None):
models = [m for m in models if m.pipeline_tag == task]

# Extract model IDs
model_ids = [m.modelId for m in models]
model_ids = [m.id for m in models]

if task == TaskType.TEXT_GENERATION:
# TODO: this is a temp solution to get around some HF models not having the correct tags
Expand Down
6 changes: 3 additions & 3 deletions mii/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

@dataclass
class ModelInfo:
modelId: str
id: str
pipeline_tag: str
tags: List[str]

Expand All @@ -53,7 +53,7 @@ def _hf_model_list() -> List[ModelInfo]:
if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
api = HfApi()
model_data["model_list"] = [
ModelInfo(modelId=m.modelId,
ModelInfo(id=m.id,
pipeline_tag=m.pipeline_tag,
tags=m.tags) for m in api.list_models()
]
Expand All @@ -70,7 +70,7 @@ def get_default_task(model_name_or_path: str) -> str:
model_name = get_model_name(model_name_or_path)
models = _hf_model_list()
for m in models:
if m.modelId == model_name:
if m.id == model_name:
task = m.pipeline_tag
logger.info(f"Detected default task as '{task}' for model '{model_name}'")
return task
Expand Down
2 changes: 1 addition & 1 deletion scripts/model_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, model_str):

def hf_model(model_str):
api = HfApi()
models = [m.modelId for m in api.list_models()]
models = [m.id for m in api.list_models()]
if model_str in models:
return model_str
else:
Expand Down

0 comments on commit ff5e2fc

Please sign in to comment.