Skip to content

Commit

Permalink
[+] Automatically download models
Browse files Browse the repository at this point in the history
  • Loading branch information
hykilpikonna committed Oct 14, 2023
1 parent 3f6ef73 commit ac2f4ba
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions msclap/CLAPWrapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
Expand All @@ -16,35 +18,45 @@
import argparse
import yaml
import sys
from huggingface_hub.file_download import hf_hub_download
logging.set_verbosity_error()


class CLAPWrapper():
"""
A class for interfacing CLAP model.
"""
model_repo = "microsoft/msclap"
model_name = {
'2022': 'CLAP_weights_2022.pth',
'2023': 'CLAP_weights_2023.pth',
'clapcap': 'clapcap_weights_2023.pth'
}

def __init__(self, model_fp: Path | str | None = None, version: str = '2023', use_cuda=False):
# Check if version is supported
self.supported_versions = self.model_name.keys()
if version not in self.supported_versions:
raise ValueError(f"The version {version} is not supported. The supported versions are {str(self.supported_versions)}")

def __init__(self, model_fp, version, use_cuda=False):
self.supported_versions = ['2022', '2023', 'clapcap']
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
self.file_path = os.path.realpath(__file__)
self.default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
self.config_as_str = self.get_config_path(version)
self.config_as_str = (Path(__file__).parent / f"configs/config_{version}.yml").read_text()

# Automatically download model if not provided
if not model_fp:
model_fp = hf_hub_download(self.model_repo, self.model_name[version])

self.model_fp = model_fp
self.use_cuda = use_cuda
if 'clapcap' in version:
self.clapcap, self.tokenizer, self.args = self.load_clapcap()
else:
self.clap, self.tokenizer, self.args = self.load_clap()

def get_config_path(self, version):
if version in self.supported_versions:
return (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
else:
raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")

def read_config_as_args(self,config_path,args=None,is_config_str=False):
return_dict = {}

Expand Down

0 comments on commit ac2f4ba

Please sign in to comment.