-
Notifications
You must be signed in to change notification settings - Fork 652
Support ignore layers in quant config for qwen3 models #4293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9d15aa4
4b4857e
86b5d55
6c7325e
c17316b
92a3a1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,9 @@ | |||||||||||||||||||||||||||||||||||
| from lmdeploy.messages import PytorchEngineConfig | ||||||||||||||||||||||||||||||||||||
| from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend | ||||||||||||||||||||||||||||||||||||
| from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value | ||||||||||||||||||||||||||||||||||||
| from lmdeploy.utils import get_logger | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| logger = get_logger('lmdeploy') | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _update_torch_dtype(config: 'ModelConfig', dtype: str): | ||||||||||||||||||||||||||||||||||||
|
|
@@ -18,9 +21,6 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): | |||||||||||||||||||||||||||||||||||
| dtype (str): user specified data type. Refer to | ||||||||||||||||||||||||||||||||||||
| `PyTorchEngineConfig.dtype` for detailed info | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| from lmdeploy.utils import get_logger | ||||||||||||||||||||||||||||||||||||
| logger = get_logger('lmdeploy') | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| quantization_config = getattr(config.hf_config, 'quantization_config', dict()) | ||||||||||||||||||||||||||||||||||||
| quant_method = quantization_config.get('quant_method', None) | ||||||||||||||||||||||||||||||||||||
| if quant_method == 'awq': | ||||||||||||||||||||||||||||||||||||
|
|
@@ -99,8 +99,6 @@ class CacheConfig: | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def __post_init__(self): | ||||||||||||||||||||||||||||||||||||
| """Post init.""" | ||||||||||||||||||||||||||||||||||||
| from lmdeploy.utils import get_logger | ||||||||||||||||||||||||||||||||||||
| logger = get_logger('lmdeploy') | ||||||||||||||||||||||||||||||||||||
| if self.window_size > 1 and self.enable_prefix_caching: | ||||||||||||||||||||||||||||||||||||
| logger.warning('Prefix caching is not available for window attention.') | ||||||||||||||||||||||||||||||||||||
| self.enable_prefix_caching = False | ||||||||||||||||||||||||||||||||||||
|
|
@@ -263,6 +261,32 @@ def _default_check_env(device: str): | |||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _patch_quantization_config(hf_config: Any, model_format: str = None): | ||||||||||||||||||||||||||||||||||||
| """Patch quantization config.""" | ||||||||||||||||||||||||||||||||||||
| if model_format is None: | ||||||||||||||||||||||||||||||||||||
| return hf_config | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if hasattr(hf_config, 'quantization_config'): | ||||||||||||||||||||||||||||||||||||
| logger.warning('Can not perform weight quantization on quantized model.') | ||||||||||||||||||||||||||||||||||||
| return hf_config | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if model_format == 'fp8': | ||||||||||||||||||||||||||||||||||||
| logger.debug('Patch quantization config for fp8.') | ||||||||||||||||||||||||||||||||||||
| from lmdeploy.pytorch.envs import scale_fmt | ||||||||||||||||||||||||||||||||||||
| quantization_config = dict(quant_method='fp8', fmt='e4m3', weight_block_size=[128, 128], scale_fmt=scale_fmt) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Unsupported weight quantization method: {model_format}') | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| hf_config.quantization_config = quantization_config | ||||||||||||||||||||||||||||||||||||
| # for vlm models | ||||||||||||||||||||||||||||||||||||
| if hasattr(hf_config, 'text_config'): | ||||||||||||||||||||||||||||||||||||
| hf_config.text_config.quantization_config = quantization_config | ||||||||||||||||||||||||||||||||||||
| elif hasattr(hf_config, 'llm_config'): | ||||||||||||||||||||||||||||||||||||
| hf_config.llm_config.quantization_config = quantization_config | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| return hf_config | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||
| class ModelConfig: | ||||||||||||||||||||||||||||||||||||
| """Config of model.""" | ||||||||||||||||||||||||||||||||||||
|
|
@@ -304,6 +328,9 @@ class ModelConfig: | |||||||||||||||||||||||||||||||||||
| # check env for model-device combination | ||||||||||||||||||||||||||||||||||||
| check_env_func: Callable = _default_check_env | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # quant config | ||||||||||||||||||||||||||||||||||||
| quant_config: 'QuantizationConfig' = None | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def get_head_size(self): | ||||||||||||||||||||||||||||||||||||
| """Get head size.""" | ||||||||||||||||||||||||||||||||||||
| return self.head_dim | ||||||||||||||||||||||||||||||||||||
|
|
@@ -318,6 +345,7 @@ def from_pretrained( | |||||||||||||||||||||||||||||||||||
| hf_overrides: Dict[str, Any] = None, | ||||||||||||||||||||||||||||||||||||
| is_draft_model: bool = False, | ||||||||||||||||||||||||||||||||||||
| spec_method: str = None, | ||||||||||||||||||||||||||||||||||||
| model_format: str = None, | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| """Instantiate one of the configuration classes of the library from a | ||||||||||||||||||||||||||||||||||||
| pretrained model configuration. | ||||||||||||||||||||||||||||||||||||
|
|
@@ -333,12 +361,14 @@ def from_pretrained( | |||||||||||||||||||||||||||||||||||
| from transformers import AutoConfig | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from lmdeploy.pytorch.transformers import config_from_pretrained | ||||||||||||||||||||||||||||||||||||
| from lmdeploy.utils import get_logger | ||||||||||||||||||||||||||||||||||||
| hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) | ||||||||||||||||||||||||||||||||||||
| if getattr(hf_config, 'model_type', None) in ['phi3']: | ||||||||||||||||||||||||||||||||||||
| # phi3 + trust_remote_code leads to error when tp. | ||||||||||||||||||||||||||||||||||||
| hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # update quantization config | ||||||||||||||||||||||||||||||||||||
| hf_config = _patch_quantization_config(hf_config, model_format=model_format) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| model_config = cls.from_hf_config( | ||||||||||||||||||||||||||||||||||||
| hf_config, | ||||||||||||||||||||||||||||||||||||
| pretrained_model_name_or_path, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -349,13 +379,14 @@ def from_pretrained( | |||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if hf_overrides is not None: | ||||||||||||||||||||||||||||||||||||
| logger = get_logger('lmdeploy') | ||||||||||||||||||||||||||||||||||||
| logger.warning(f'Overriding HF config with {hf_overrides}') | ||||||||||||||||||||||||||||||||||||
| override_hf_config(model_config.hf_config, hf_overrides) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # for serialization of transformers modules | ||||||||||||||||||||||||||||||||||||
| maybe_register_config_serialize_by_value(trust_remote_code) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # add quant_config | ||||||||||||||||||||||||||||||||||||
| model_config.quant_config = QuantizationConfig.from_config(hf_config) | ||||||||||||||||||||||||||||||||||||
| return model_config | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||
|
|
@@ -516,3 +547,81 @@ def from_config( | |||||||||||||||||||||||||||||||||||
| num_speculative_tokens=num_speculative_tokens, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| return obj | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||
| class QuantizationConfig: | ||||||||||||||||||||||||||||||||||||
| quant_method: str = None | ||||||||||||||||||||||||||||||||||||
| quant_dtype: torch.dtype = None | ||||||||||||||||||||||||||||||||||||
| scale_fmt: str = None | ||||||||||||||||||||||||||||||||||||
| bits: int = None | ||||||||||||||||||||||||||||||||||||
| group_size: int = None | ||||||||||||||||||||||||||||||||||||
| weight_block_size: Tuple[int] = None | ||||||||||||||||||||||||||||||||||||
| activation_scheme: str = None | ||||||||||||||||||||||||||||||||||||
| ignored_layers: List[str] = field(default_factory=list) | ||||||||||||||||||||||||||||||||||||
| hf_quant_config: Dict[str, Any] = field(default_factory=dict) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||
| def from_config(cls, hf_config: Any): | ||||||||||||||||||||||||||||||||||||
| quant_config = getattr(hf_config, 'quantization_config', None) | ||||||||||||||||||||||||||||||||||||
| if quant_config is None: | ||||||||||||||||||||||||||||||||||||
| return cls() | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| quant_method = quant_config['quant_method'] | ||||||||||||||||||||||||||||||||||||
| quant_dtype = quant_config.get('quant_dtype', None) | ||||||||||||||||||||||||||||||||||||
| scale_fmt = quant_config.get('scale_fmt', None) | ||||||||||||||||||||||||||||||||||||
| weight_block_size = quant_config.get('weight_block_size', None) | ||||||||||||||||||||||||||||||||||||
| activation_scheme = quant_config.get('activation_scheme', None) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| bits = None | ||||||||||||||||||||||||||||||||||||
| group_size = None | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if quant_method == 'awq': | ||||||||||||||||||||||||||||||||||||
| bits = quant_config.get('bits', 4) | ||||||||||||||||||||||||||||||||||||
| group_size = quant_config.get('group_size', 128) | ||||||||||||||||||||||||||||||||||||
| elif quant_method == 'smooth_quant': | ||||||||||||||||||||||||||||||||||||
| if quant_dtype is None: | ||||||||||||||||||||||||||||||||||||
| quant_dtype = 'int8' | ||||||||||||||||||||||||||||||||||||
| elif quant_method == 'fp8': | ||||||||||||||||||||||||||||||||||||
| fmt = quant_config.get('fmt', 'e4m3') | ||||||||||||||||||||||||||||||||||||
| if fmt == 'e4m3': | ||||||||||||||||||||||||||||||||||||
| quant_dtype = 'float8_e4m3fn' | ||||||||||||||||||||||||||||||||||||
| elif fmt == 'e5m2': | ||||||||||||||||||||||||||||||||||||
| quant_dtype = 'float8_e5m2' | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| raise TypeError(f'Unsupported fp8 fmt: {fmt}') | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| raise TypeError(f'Unsupported quant method: {quant_method}') | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if quant_dtype is not None: | ||||||||||||||||||||||||||||||||||||
| quant_dtype = eval(f'torch.{quant_dtype}') | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| ignored_layers = quant_config.get('ignored_layers', []) | ||||||||||||||||||||||||||||||||||||
| if not ignored_layers: | ||||||||||||||||||||||||||||||||||||
| ignored_layers = quant_config.get('modules_to_not_convert', []) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| return cls( | ||||||||||||||||||||||||||||||||||||
| quant_method=quant_method, | ||||||||||||||||||||||||||||||||||||
| quant_dtype=quant_dtype, | ||||||||||||||||||||||||||||||||||||
| scale_fmt=scale_fmt, | ||||||||||||||||||||||||||||||||||||
| bits=bits, | ||||||||||||||||||||||||||||||||||||
| group_size=group_size, | ||||||||||||||||||||||||||||||||||||
| weight_block_size=weight_block_size, | ||||||||||||||||||||||||||||||||||||
| activation_scheme=activation_scheme, | ||||||||||||||||||||||||||||||||||||
| ignored_layers=ignored_layers, | ||||||||||||||||||||||||||||||||||||
| hf_quant_config=quant_config, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def get_quant_method(self, prefix: str = ''): | ||||||||||||||||||||||||||||||||||||
| """Get quant method for module.""" | ||||||||||||||||||||||||||||||||||||
| if not prefix or not self.ignored_layers: | ||||||||||||||||||||||||||||||||||||
| return self.quant_method | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| is_ignore = any([prefix in layer_name for layer_name in self.ignored_layers]) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| is_ignore = any([prefix in layer_name for layer_name in self.ignored_layers]) | |
| def _matches_prefix(p: str, layer_name: str) -> bool: | |
| """Return True if p and layer_name refer to the same module or | |
| one is a dotted-prefix of the other. | |
| """ | |
| if not p or not layer_name: | |
| return False | |
| if p == layer_name: | |
| return True | |
| if layer_name.startswith(p + '.'): | |
| return True | |
| if p.startswith(layer_name + '.'): | |
| return True | |
| return False | |
| is_ignore = any(_matches_prefix(prefix, layer_name) | |
| for layer_name in self.ignored_layers) |
Uh oh!
There was an error while loading. Please reload this page.