Cache HF model list for inference tests (microsoft#4940)
Cache the model list in blob storage so it can be shared across CI
runners. Code borrowed from MII:
https://github.com/microsoft/DeepSpeed-MII/blob/95d1e1c8890a016f2b5788414754abbbfd4540ae/mii/utils.py#L39
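The change below boils down to a time-stamped pickle cache: fetch once, persist to disk, and refetch only after a TTL elapses. A minimal sketch of that pattern follows; the `cached` helper, `fetch_fn`, and the demo path are illustrative and not part of this commit. Unlike the committed code, the sketch also expands "~" and creates the cache directory, which the diff leaves to the environment.

import os
import pickle
import time
from typing import Any, Callable


def cached(fetch_fn: Callable[[], Any], cache_path: str, ttl_seconds: float) -> Any:
    """Return fetch_fn()'s result, reusing an on-disk pickle younger than ttl_seconds."""
    cache_path = os.path.expanduser(cache_path)
    data = {"cache_time": 0, "payload": None}
    if os.path.isfile(cache_path):
        with open(cache_path, "rb") as f:
            data = pickle.load(f)
    # Refetch and persist only when the cached copy has expired
    if (data["cache_time"] + ttl_seconds) < time.time():
        data = {"cache_time": time.time(), "payload": fetch_fn()}
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        with open(cache_path, "wb") as f:
            pickle.dump(data, f)
    return data["payload"]


# Example: any expensive call can be cached for a day.
result = cached(lambda: list(range(10)), "~/.cache/huggingface/demo_cache.pkl", 60 * 60 * 24)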
mrwyattii authored and amaurya committed Feb 17, 2024
1 parent d657da0 commit a7eea69
Showing 1 changed file with 57 additions and 17 deletions.
tests/unit/inference/test_inference.py
@@ -3,25 +3,33 @@
 
 # DeepSpeed Team
 
-import os
-import time
-import pickle
-import torch
 import pytest
+
 import itertools
+import pickle
+import os
+import time
+
+from dataclasses import dataclass
+from typing import List
+
 import deepspeed
-from deepspeed.git_version_info import torch_info
-from unit.common import DistributedTest
+import torch
+
+from huggingface_hub import HfApi
 from packaging import version as pkg_version
-from deepspeed.ops.op_builder import OpBuilder
-from torch import nn
 from transformers import pipeline, AutoTokenizer
 from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.roberta.modeling_roberta import RobertaLayer
-from huggingface_hub import HfApi
-from deepspeed.model_implementations import DeepSpeedTransformerInference
+from torch import nn
+
 from deepspeed.accelerator import get_accelerator
+from deepspeed.git_version_info import torch_info
+from deepspeed.model_implementations import DeepSpeedTransformerInference
 from deepspeed.ops.op_builder import InferenceBuilder
+from deepspeed.ops.op_builder import OpBuilder
+
+from unit.common import DistributedTest
 
 rocm_version = OpBuilder.installed_rocm_version()
 if rocm_version != (0, 0):
@@ -65,14 +73,46 @@
     "text2text-generation", "summarization", "translation"
 ]
 
+
+@dataclass
+class ModelInfo:
+    modelId: str
+    pipeline_tag: str
+    tags: List[str]
+
+
+def _hf_model_list() -> List[ModelInfo]:
+    """ Caches HF model list to avoid repeated API calls """
+
+    cache_dir = os.getenv("TRANSFORMERS_CACHE", "~/.cache/huggingface")
+    cache_file_path = os.path.join(cache_dir, "DS_model_cache.pkl")
+    cache_expiration_seconds = 60 * 60 * 24  # 1 day
+
+    # Load or initialize the cache
+    model_data = {"cache_time": 0, "model_list": []}
+    if os.path.isfile(cache_file_path):
+        with open(cache_file_path, 'rb') as f:
+            model_data = pickle.load(f)
+
+    current_time = time.time()
+
+    # Update the cache if it has expired
+    if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
+        api = HfApi()
+        model_data["model_list"] = [
+            ModelInfo(modelId=m.modelId, pipeline_tag=m.pipeline_tag, tags=m.tags) for m in api.list_models()
+        ]
+        model_data["cache_time"] = current_time
+
+        # Save the updated cache
+        with open(cache_file_path, 'wb') as f:
+            pickle.dump(model_data, f)
+
+    return model_data["model_list"]
+
+
 # Get a list of all models and mapping from task to supported models
-try:
-    with open("hf_models.pkl", "rb") as fp:
-        _hf_models = pickle.load(fp)
-except FileNotFoundError:
-    _hf_models = list(HfApi().list_models())
-    with open("hf_models.pkl", "wb") as fp:
-        pickle.dump(_hf_models, fp)
+_hf_models = _hf_model_list()
 _hf_model_names = [m.modelId for m in _hf_models]
 _hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}
 
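For illustration only (this snippet is not part of the diff): `_hf_model_list` re-reads the pickle on every call, so within the 24-hour window repeat calls skip the Hub API entirely, which is what lets CI runners share a pre-seeded cache file.

models = _hf_model_list()   # first call may hit the HF Hub API and write DS_model_cache.pkl
models = _hf_model_list()   # subsequent calls within 24 hours are served from the pickle
fill_mask_models = [m.modelId for m in models if m.pipeline_tag == "fill-mask"]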
