Commit 2448dcf: support for model scope (#957)

Signed-off-by: n1ck-guo <heng.guo@intel.com>

1 parent 3c1a678
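In brief: this commit threads a new platform argument ("hf" or "model_scope") from the CLI through AutoRound and every compressor, adds an AR_USE_MODELSCOPE environment variable plus a set_config helper that force ModelScope loading, and replaces download_hf_model with the platform-aware download_or_get_path at the GGUF export call sites. Ten of the 13 changed files are shown below.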

File tree

13 files changed: +218 -32 lines changed

auto_round/__main__.py

Lines changed: 7 additions & 0 deletions

@@ -45,6 +45,12 @@ def __init__(self, *args, **kwargs):
             help="Path to the pre-trained model or model identifier from huggingface.co/models. "
             "Examples: 'facebook/opt-125m', 'bert-base-uncased', or local path like '/path/to/model'",
         )
+        basic.add_argument(
+            "--platform",
+            default="hf",
+            help="Platform to load the pre-trained model. Options: [hf, model_scope]."
+            " hf stands for huggingface and model_scope stands for model scope.",
+        )
         basic.add_argument(
             "--scheme",
             default="W4A16",
@@ -566,6 +572,7 @@ def tune(args):

     autoround: BaseCompressor = AutoRound(
         model=model_name,
+        platform=args.platform,
         scheme=scheme,
         dataset=args.dataset,
         iters=args.iters,
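With the new flag, a ModelScope-hosted checkpoint can be quantized straight from the command line, e.g. python -m auto_round --model <model_id> --platform model_scope --scheme W4A16 (a sketch: the --model flag name and module entry point are assumed from context, and the model id is left as a placeholder).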

auto_round/autoround.py

Lines changed: 13 additions & 1 deletion

@@ -43,6 +43,7 @@ class AutoRound:
     Attributes:
         model (torch.nn.Module): The loaded PyTorch model in eval mode.
        tokenizer: Tokenizer used to prepare input text for calibration/tuning.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         bits (int): Weight quantization bits.
         group_size (int): Per-group size for weight quantization.
         sym (bool): Whether to use symmetric weight quantization.
@@ -67,6 +68,7 @@ def __new__(
         cls,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
+        platform: str = "hf",
         scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16",
         layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
@@ -146,7 +148,7 @@ def __new__(
         """
         model_cls = []

-        if (extra_config and not extra_config.mllm_config.is_default()) or is_mllm_model(model):
+        if (extra_config and not extra_config.mllm_config.is_default()) or is_mllm_model(model, platform=platform):
             logger.info("using MLLM mode for multimodal model.")
             model_cls.append(MLLMCompressor)
         if extra_config:
@@ -170,6 +172,7 @@ def __new__(
         ar = dynamic_compressor(
             model=model,
             tokenizer=tokenizer,
+            platform=platform,
             scheme=scheme,
             layer_config=layer_config,
             dataset=dataset,
@@ -314,6 +317,7 @@ def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
+        platform: str = "hf",
         scheme: Union[str, dict, QuantizationScheme] = "W4A16",
         layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
@@ -331,6 +335,7 @@ def __init__(
         super().__init__(
             model=model,
             tokenizer=tokenizer,
+            platform=platform,
             scheme=scheme,
             layer_config=layer_config,
             dataset=dataset,
@@ -354,6 +359,7 @@ class AutoRoundAdam(AdamCompressor):
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         scheme (str | dict | QuantizationScheme): A preset scheme that defines the quantization configurations
         bits (int): Number of bits for quantization (default is 4).
         group_size (int): Size of the quantization group (default is 128).
@@ -413,6 +419,7 @@ def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
+        platform: str = "hf",
         scheme: Union[str, dict, QuantizationScheme] = "W4A16",
         layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
@@ -431,6 +438,7 @@ def __init__(
         super().__init__(
             model=model,
             tokenizer=tokenizer,
+            platform=platform,
             scheme=scheme,
             layer_config=layer_config,
             batch_size=batch_size,
@@ -455,6 +463,7 @@ class AutoRoundMLLM(MLLMCompressor):
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         processor: Any multi-modal model will require an object to encode or
             decode the data that groups several modalities (among text, vision and audio).
         image_processor: Image processor for special model like llava.
@@ -513,6 +522,7 @@ def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
+        platform: str = "hf",
         processor=None,
         image_processor=None,
         scheme: Union[str, dict, QuantizationScheme] = "W4A16",
@@ -533,6 +543,7 @@ def __init__(
         super().__init__(
             model=model,
             tokenizer=tokenizer,
+            platform=platform,
             processor=processor,
             image_processor=image_processor,
             scheme=scheme,
@@ -559,6 +570,7 @@ class AutoRoundDiffusion(DiffusionCompressor):
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data, is not used for diffusion models.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         guidance_scale (float): Control how much the image generation process follows the text prompt.
             The more it is, the more closely it follows the prompt (default is 7.5).
         num_inference_steps (int): The reference number of denoising steps (default is 50).
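In practice, selecting ModelScope becomes a one-keyword change. A minimal sketch, assuming a ModelScope-hosted model id (placeholder below) and the quantize_and_save entry point the project already exposes:

    from auto_round import AutoRound

    ar = AutoRound(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder id, resolved on ModelScope rather than the HF Hub
        platform="model_scope",              # new in this commit; the default "hf" keeps the old behavior
        scheme="W4A16",
    )
    ar.quantize_and_save("./qmodel-w4a16")   # existing entry point, not part of this diff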

auto_round/compressors/adam.py

Lines changed: 3 additions & 0 deletions

@@ -27,6 +27,7 @@ class AdamCompressor(BaseCompressor):
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         scheme (str | dict | QuantizationScheme): A preset scheme that defines the quantization configurations
         bits (int): Number of bits for quantization (default is 4).
         group_size (int): Size of the quantization group (default is 128).
@@ -86,6 +87,7 @@ def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
+        platform="hf",
         scheme: Union[str, dict, QuantizationScheme] = "W4A16",
         layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
@@ -104,6 +106,7 @@ def __init__(
         super(AdamCompressor, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            platform=platform,
             scheme=scheme,
             layer_config=layer_config,
             batch_size=batch_size,

auto_round/compressors/base.py

Lines changed: 8 additions & 0 deletions

@@ -30,6 +30,7 @@
 from tqdm import tqdm
 from transformers import set_seed

+from auto_round import envs
 from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
 from auto_round.compressors.utils import (
     block_forward,
@@ -105,6 +106,7 @@ class BaseCompressor(object):
     Attributes:
         model (torch.nn.Module): The loaded PyTorch model in eval mode.
         tokenizer: Tokenizer used to prepare input text for calibration/tuning.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         bits (int): Weight quantization bits.
         group_size (int): Per-group size for weight quantization.
         sym (bool): Whether to use symmetric weight quantization.
@@ -129,6 +131,7 @@ def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
+        platform="hf",
         scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16",
         layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
@@ -228,6 +231,10 @@ def __init__(
         device = kwargs.pop("device", None)
         # Scale factor for RAM usage per parameter.
         mem_per_param_scale = kwargs.pop("mem_per_param_scale", None)
+
+        if envs.AR_USE_MODELSCOPE:
+            platform = "model_scope"
+        self.platform = platform
         self.quant_lm_head = kwargs.pop("quant_lm_head", False)
         self.mllm = kwargs.pop("mllm") if "mllm" in kwargs else False
         self.diffusion = kwargs.pop("diffusion") if "diffusion" in kwargs else False
@@ -259,6 +266,7 @@ def __init__(
         if isinstance(model, str):
             model, tokenizer = llm_load_model(
                 model,
+                platform=platform,
                 device="cpu",  # always load cpu first
             )
         elif tokenizer is None and not self.diffusion and iters > 0:
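Because the AR_USE_MODELSCOPE check runs before self.platform is assigned, the environment variable overrides whatever platform the caller passed. A minimal sketch of that precedence, using only what the hunks above show (the envs module presumably resolves attributes through its lambda table, so values are read from os.environ at access time):

    import os

    # Must be set before the compressor is constructed.
    os.environ["AR_USE_MODELSCOPE"] = "1"

    from auto_round import envs
    assert envs.AR_USE_MODELSCOPE  # "1".lower() in ["1", "true"] -> True

    # Any BaseCompressor built from here on uses platform="model_scope",
    # even if it was constructed with platform="hf".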

auto_round/compressors/diffusion/compressor.py

Lines changed: 4 additions & 1 deletion

@@ -47,6 +47,7 @@ class DiffusionCompressor(BaseCompressor):
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data, is not used for diffusion models.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         guidance_scale (float): Control how much the image generation process follows the text prompt.
             The more it is, the more closely it follows the prompt (default is 7.5).
         num_inference_steps (int): The reference number of denoising steps (default is 50).
@@ -81,6 +82,7 @@ def __init__(
         self,
         model: Union[object, str],
         tokenizer=None,
+        platform: str = "hf",
         guidance_scale: float = 7.5,
         num_inference_steps: int = 50,
         generator_seed: int = None,
@@ -110,7 +112,7 @@ def __init__(
         self._set_device(device_map)

         if isinstance(model, str):
-            pipe, model = diffusion_load_model(model, device=self.device)
+            pipe, model = diffusion_load_model(model, platform=platform, device=self.device)
         elif isinstance(model, pipeline_utils.DiffusionPipeline):
             pipe = model
             model = pipe.transformer
@@ -145,6 +147,7 @@ def __init__(
         super(DiffusionCompressor, self).__init__(
             model=model,
             tokenizer=None,
+            platform=platform,
             scheme=scheme,
             layer_config=layer_config,
             dataset=dataset,
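The same keyword now reaches the diffusion path. A sketch, assuming a ModelScope-hosted pipeline id (the id below is a placeholder; the import path follows auto_round/autoround.py shown above):

    from auto_round.autoround import AutoRoundDiffusion

    # Hypothetical usage; every keyword other than platform keeps its
    # documented default (guidance_scale=7.5, num_inference_steps=50, ...).
    compressor = AutoRoundDiffusion(
        model="stabilityai/stable-diffusion-xl-base-1.0",  # placeholder id
        platform="model_scope",
    )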

auto_round/compressors/mllm/compressor.py

Lines changed: 5 additions & 1 deletion

@@ -87,6 +87,7 @@ class MLLMCompressor(BaseCompressor):
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data.
+        platform (str): The platform to load the pretrained model, options: ["hf", "model_scope"]
         processor: Any multi-modal model will require an object to encode or
             decode the data that groups several modalities (among text, vision and audio).
         image_processor: Image processor for special model like llava.
@@ -145,6 +146,7 @@ def __init__(
         self,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
+        platform: str = "hf",
         processor=None,
         image_processor=None,
         scheme: Union[str, dict, QuantizationScheme] = "W4A16",
@@ -171,7 +173,7 @@ def __init__(
         self._set_device(device_map)

         if isinstance(model, str):
-            model, processor, tokenizer, image_processor = mllm_load_model(model, device=self.device)
+            model, processor, tokenizer, image_processor = mllm_load_model(model, platform=platform, device=self.device)

         self.model = model
         quant_nontext_module = self._check_quant_nontext(layer_config, quant_nontext_module)
@@ -258,6 +260,7 @@ def __init__(
         super(MLLMCompressor, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            platform=platform,
             scheme=scheme,
             layer_config=layer_config,
             dataset=dataset,
@@ -374,6 +377,7 @@ def calib(self, nsamples, bs):
                 continue
             try:
                 if isinstance(data_new, torch.Tensor):
+                    data_new = data_new.to(self.model.device)
                     self.model(data_new)
                 elif isinstance(data_new, tuple) or isinstance(data_new, list):
                     self.model(*data_new)
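Besides threading platform through mllm_load_model, the final hunk also moves tensor-type calibration batches onto the model's device before the forward call, which prevents a device mismatch when the model has been dispatched to an accelerator.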

auto_round/compressors/utils.py

Lines changed: 3 additions & 3 deletions

@@ -480,7 +480,7 @@ def gguf_args_check(args_or_ar, formats: list[str] = None, model_type=ModelType.

     from auto_round.export.export_to_gguf.convert import download_convert_file
     from auto_round.logger import logger
-    from auto_round.utils.model import download_hf_model, get_gguf_architecture
+    from auto_round.utils.model import download_or_get_path, get_gguf_architecture

     formats = sorted(formats, key=lambda x: len(x))
     export_gguf = False
@@ -505,7 +505,7 @@ def gguf_args_check(args_or_ar, formats: list[str] = None, model_type=ModelType.
     else:
         model_path = args_or_ar.model.name_or_path
         if not os.path.isdir(model_path):
-            model_path = download_hf_model(model_path)
+            model_path = download_or_get_path(model_path, args_or_ar.platform)
         model_architecture = get_gguf_architecture(model_path, model_type=ModelType.TEXT)
         if model_architecture not in ModelBase._model_classes[ModelType.TEXT]:
             logger.warning(
@@ -539,7 +539,7 @@ def gguf_args_check(args_or_ar, formats: list[str] = None, model_type=ModelType.
     else:
         model_path = args_or_ar.model.name_or_path
         if not os.path.isdir(model_path):
-            model_path = download_hf_model(model_path)
+            model_path = download_or_get_path(model_path, args_or_ar.platform)
         model_architecture = get_gguf_architecture(model_path, model_type=ModelType.TEXT)
         if model_architecture not in ModelBase._model_classes[ModelType.TEXT]:
             logger.error(f"Model {model_architecture} is not supported to export gguf format.")
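The helper behind the rename lives in auto_round/utils/model.py, which is among the 13 changed files but not shown in this excerpt. A hypothetical sketch of its shape, assuming it dispatches on platform (the body below is inferred from the call sites, not taken from the commit):

    import os

    def download_or_get_path(model_name_or_path: str, platform: str = "hf") -> str:
        """Return a local directory for the model, downloading it when given a hub id."""
        if os.path.isdir(model_name_or_path):
            return model_name_or_path  # already a local checkout
        if platform == "model_scope":
            from modelscope import snapshot_download  # ModelScope SDK
        else:
            from huggingface_hub import snapshot_download
        return snapshot_download(model_name_or_path)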

auto_round/envs.py

Lines changed: 29 additions & 0 deletions

@@ -18,10 +18,12 @@

 if TYPE_CHECKING:
     AR_LOG_LEVEL: str = "INFO"
+    AR_USE_MODELSCOPE: bool = False

 environment_variables: dict[str, Callable[[], Any]] = {
     # this is used for configuring the default logging level
     "AR_LOG_LEVEL": lambda: os.getenv("AR_LOG_LEVEL", "INFO").upper(),
+    "AR_USE_MODELSCOPE": lambda: os.getenv("AR_USE_MODELSCOPE", "False").lower() in ["1", "true"],
 }

@@ -41,3 +43,30 @@ def is_set(name: str):
     if name in environment_variables:
         return name in os.environ
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def set_config(**kwargs):
+    """
+    Set configuration values for environment variables.
+
+    Args:
+        **kwargs: Keyword arguments where keys are environment variable names
+            and values are the desired values to set.
+
+    Example:
+        set_config(AR_LOG_LEVEL="DEBUG", AR_USE_MODELSCOPE=True)
+    """
+    for key, value in kwargs.items():
+        if key in environment_variables:
+            # Convert the value to an appropriate string format
+            if key == "AR_USE_MODELSCOPE":
+                # Handle boolean values for AR_USE_MODELSCOPE
+                str_value = "true" if value in [True, "True", "true", "1", 1] else "false"
+            else:
+                # For other variables, convert to string
+                str_value = str(value)
+
+            # Set the environment variable
+            os.environ[key] = str_value
+        else:
+            raise AttributeError(f"module {__name__!r} has no attribute {key!r}")
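A short usage sketch of the new helper: it normalizes values to strings because os.environ only stores str, and the getter lambdas re-read the variable on each access.

    from auto_round import envs

    envs.set_config(AR_USE_MODELSCOPE=True)  # stores os.environ["AR_USE_MODELSCOPE"] = "true"
    print(envs.AR_USE_MODELSCOPE)            # True

    envs.set_config(AR_LOG_LEVEL="debug")    # stored verbatim as "debug"
    print(envs.AR_LOG_LEVEL)                 # "DEBUG"; the AR_LOG_LEVEL getter upper-cases it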

auto_round/export/export_to_gguf/convert.py

Lines changed: 2 additions & 2 deletions

@@ -167,11 +167,11 @@ def is_extra_tensor(tensor_name):
     from safetensors import safe_open

     from auto_round.export.export_to_gguf.special_handle import get_tensor_from_file
-    from auto_round.utils import download_hf_model
+    from auto_round.utils import download_or_get_path

     dir_path = cls.model.name_or_path
     if not os.path.isdir(dir_path):
-        dir_path = download_hf_model(dir_path)
+        dir_path = download_or_get_path(dir_path)
     INDEX_FILE = "model.safetensors.index.json"
     if INDEX_FILE in os.listdir(dir_path):
         with open(os.path.join(dir_path, INDEX_FILE)) as f:

auto_round/export/export_to_gguf/export.py

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
     LazyImport,
     check_to_quantized,
     clear_memory,
-    download_hf_model,
+    download_or_get_path,
     flatten_list,
     get_block_names,
     get_gguf_architecture,
@@ -77,7 +77,7 @@ def create_model_class(
     tmp_work_dir = model.name_or_path
     os.makedirs(output_dir, exist_ok=True)
     if not os.path.isdir(tmp_work_dir):
-        tmp_work_dir = download_hf_model(tmp_work_dir)
+        tmp_work_dir = download_or_get_path(tmp_work_dir)
     with torch.inference_mode():
         model_architecture = get_gguf_architecture(tmp_work_dir, model_type=model_type)
     try:
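Note that both GGUF export call sites invoke download_or_get_path without an explicit platform argument, so they rely on the helper's default; presumably that default is "hf" unless the AR_USE_MODELSCOPE override applies inside the helper, whose body is not part of this excerpt.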
