Skip to content

Commit 287d540

Browse files
authored
Merge pull request #2 from sywangyi/align_hub_change
fix the import error of
2 parents a764f58 + c956841 commit 287d540

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

optimum/intel/neural_compressor/configuration.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from functools import reduce
1818
from typing import Any, Optional, Union
1919

20-
from transformers.file_utils import cached_path, hf_bucket_url
20+
from transformers.utils import cached_file
2121

2222
import yaml
2323
from neural_compressor.conf.config import Conf, Distillation_Conf, Pruning_Conf, Quantization_Conf
@@ -97,16 +97,11 @@ def from_pretrained(cls, config_name_or_path: str, config_file_name: Optional[st
9797
revision = kwargs.get("revision", None)
9898

9999
config_file_name = config_file_name if config_file_name is not None else CONFIG_NAME
100-
if os.path.isdir(config_name_or_path):
101-
config_file = os.path.join(config_name_or_path, config_file_name)
102-
elif os.path.isfile(config_name_or_path):
103-
config_file = config_name_or_path
104-
else:
105-
config_file = hf_bucket_url(config_name_or_path, filename=config_file_name, revision=revision)
106100

107101
try:
108-
resolved_config_file = cached_path(
109-
config_file,
102+
resolved_config_file = cached_file(
103+
config_name_or_path,
104+
config_file_name,
110105
cache_dir=cache_dir,
111106
force_download=force_download,
112107
resume_download=resume_download,

optimum/intel/neural_compressor/quantization.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
AutoModelForTokenClassification,
3434
XLNetLMHeadModel,
3535
)
36-
from transformers.file_utils import cached_path, hf_bucket_url
36+
from transformers.utils import cached_file
3737
from transformers.models.auto.auto_factory import _get_model_class
3838
from transformers.utils.versions import require_version
3939

@@ -271,15 +271,9 @@ def from_pretrained(
271271

272272
q_model_name = q_model_name if q_model_name is not None else WEIGHTS_NAME
273273
revision = download_kwargs.pop("revision", None)
274-
if os.path.isdir(model_name_or_path):
275-
state_dict_path = os.path.join(model_name_or_path, q_model_name)
276-
elif os.path.isfile(model_name_or_path):
277-
state_dict_path = model_name_or_path
278-
else:
279-
state_dict_path = hf_bucket_url(model_name_or_path, filename=q_model_name, revision=revision)
280274

281275
try:
282-
state_dict_path = cached_path(state_dict_path, **download_kwargs)
276+
state_dict_path = cached_file(model_name_or_path, q_model_name, **download_kwargs)
283277
except EnvironmentError as err:
284278
logger.error(err)
285279
msg = (

0 commit comments

Comments
 (0)