123 changes: 116 additions & 7 deletions lmdeploy/pytorch/config.py
@@ -8,6 +8,9 @@
from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


def _update_torch_dtype(config: 'ModelConfig', dtype: str):
@@ -18,9 +21,6 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
dtype (str): user specified data type. Refer to
`PyTorchEngineConfig.dtype` for detailed info
"""
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')

quantization_config = getattr(config.hf_config, 'quantization_config', dict())
quant_method = quantization_config.get('quant_method', None)
if quant_method == 'awq':
@@ -99,8 +99,6 @@ class CacheConfig:

def __post_init__(self):
"""Post init."""
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
if self.window_size > 1 and self.enable_prefix_caching:
logger.warning('Prefix caching is not available for window attention.')
self.enable_prefix_caching = False
@@ -263,6 +261,32 @@ def _default_check_env(device: str):
pass


def _patch_quantization_config(hf_config: Any, model_format: str = None):
"""Patch quantization config."""
if model_format is None:
return hf_config

if hasattr(hf_config, 'quantization_config'):
logger.warning('Can not perform weight quantization on quantized model.')
return hf_config

if model_format == 'fp8':
logger.debug('Patch quantization config for fp8.')
from lmdeploy.pytorch.envs import scale_fmt
quantization_config = dict(quant_method='fp8', fmt='e4m3', weight_block_size=[128, 128], scale_fmt=scale_fmt)
else:
raise RuntimeError(f'Unsupported weight quantization method: {model_format}')

hf_config.quantization_config = quantization_config
# for vlm models
if hasattr(hf_config, 'text_config'):
hf_config.text_config.quantization_config = quantization_config
elif hasattr(hf_config, 'llm_config'):
hf_config.llm_config.quantization_config = quantization_config

return hf_config


@dataclass
class ModelConfig:
"""Config of model."""
@@ -304,6 +328,9 @@ class ModelConfig:
# check env for model-device combination
check_env_func: Callable = _default_check_env

# quant config
quant_config: 'QuantizationConfig' = None

def get_head_size(self):
"""Get head size."""
return self.head_dim
@@ -318,6 +345,7 @@ def from_pretrained(
hf_overrides: Dict[str, Any] = None,
is_draft_model: bool = False,
spec_method: str = None,
model_format: str = None,
):
"""Instantiate one of the configuration classes of the library from a
pretrained model configuration.
@@ -333,12 +361,14 @@ def from_pretrained(
from transformers import AutoConfig

from lmdeploy.pytorch.transformers import config_from_pretrained
from lmdeploy.utils import get_logger
hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if getattr(hf_config, 'model_type', None) in ['phi3']:
# phi3 + trust_remote_code leads to error when tp.
hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

# update quantization config
hf_config = _patch_quantization_config(hf_config, model_format=model_format)

model_config = cls.from_hf_config(
hf_config,
pretrained_model_name_or_path,
@@ -349,13 +379,14 @@
)

if hf_overrides is not None:
logger = get_logger('lmdeploy')
logger.warning(f'Overriding HF config with {hf_overrides}')
override_hf_config(model_config.hf_config, hf_overrides)

# for serialization of transformers modules
maybe_register_config_serialize_by_value(trust_remote_code)

# add quant_config
model_config.quant_config = QuantizationConfig.from_config(hf_config)
return model_config

@classmethod
@@ -516,3 +547,81 @@ def from_config(
num_speculative_tokens=num_speculative_tokens,
)
return obj


@dataclass
class QuantizationConfig:
quant_method: str = None
quant_dtype: torch.dtype = None
scale_fmt: str = None
bits: int = None
group_size: int = None
weight_block_size: Tuple[int] = None
activation_scheme: str = None
ignored_layers: List[str] = field(default_factory=list)
hf_quant_config: Dict[str, Any] = field(default_factory=dict)

@classmethod
def from_config(cls, hf_config: Any):
quant_config = getattr(hf_config, 'quantization_config', None)
if quant_config is None:
return cls()

quant_method = quant_config['quant_method']
quant_dtype = quant_config.get('quant_dtype', None)
scale_fmt = quant_config.get('scale_fmt', None)
weight_block_size = quant_config.get('weight_block_size', None)
activation_scheme = quant_config.get('activation_scheme', None)

bits = None
group_size = None

if quant_method == 'awq':
bits = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
elif quant_method == 'smooth_quant':
if quant_dtype is None:
quant_dtype = 'int8'
elif quant_method == 'fp8':
fmt = quant_config.get('fmt', 'e4m3')
if fmt == 'e4m3':
quant_dtype = 'float8_e4m3fn'
elif fmt == 'e5m2':
quant_dtype = 'float8_e5m2'
else:
raise TypeError(f'Unsupported fp8 fmt: {fmt}')
else:
raise TypeError(f'Unsupported quant method: {quant_method}')

if quant_dtype is not None:
quant_dtype = getattr(torch, quant_dtype)

ignored_layers = quant_config.get('ignored_layers', [])
if not ignored_layers:
ignored_layers = quant_config.get('modules_to_not_convert', [])

return cls(
quant_method=quant_method,
quant_dtype=quant_dtype,
scale_fmt=scale_fmt,
bits=bits,
group_size=group_size,
weight_block_size=weight_block_size,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
hf_quant_config=quant_config,
)

def get_quant_method(self, prefix: str = ''):
"""Get quant method for module."""
if not prefix or not self.ignored_layers:
return self.quant_method

is_ignore = any([prefix in layer_name for layer_name in self.ignored_layers])
Copilot AI Jan 29, 2026

The method check on line 600 uses substring matching with prefix in layer_name, which could lead to false positives. For example, if prefix is "layer" and ignored_layers contains "my_layer_norm", it would incorrectly match.

A more robust approach would be to check if the prefix exactly matches the beginning of the layer name or matches a complete segment. Consider using layer_name.startswith(prefix) or checking for exact component matches with proper delimiter handling.

Suggested change
is_ignore = any([prefix in layer_name for layer_name in self.ignored_layers])
def _matches_prefix(p: str, layer_name: str) -> bool:
    """Return True if p and layer_name refer to the same module or one is a
    dotted-prefix of the other."""
    if not p or not layer_name:
        return False
    if p == layer_name:
        return True
    if layer_name.startswith(p + '.'):
        return True
    if p.startswith(layer_name + '.'):
        return True
    return False

is_ignore = any(_matches_prefix(prefix, layer_name)
                for layer_name in self.ignored_layers)

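A quick standalone illustration of the false positive described above, using the hypothetical names from the comment (not part of this PR):

```python
ignored_layers = ['my_layer_norm']   # hypothetical ignored layer
prefix = 'layer'                     # hypothetical module prefix being queried

# Current substring check: 'layer' is found inside 'my_layer_norm',
# so the layer is wrongly treated as ignored.
print(any(prefix in name for name in ignored_layers))   # True

# Dotted-prefix check along the lines of the suggestion above: no match,
# so the layer keeps its quantization method.
print(any(name == prefix or name.startswith(prefix + '.') or prefix.startswith(name + '.')
          for name in ignored_layers))                  # False
```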
quant_method = None if is_ignore else self.quant_method
logger.debug(f'Layer {prefix}: is_ignore={is_ignore}, quant_method={quant_method}')
return quant_method

def get(self, key, default=None):
"""Get extra key from hf quant config."""
return self.hf_quant_config.get(key, default)
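For reference, a minimal usage sketch of the new QuantizationConfig, assuming the diff above is applied as shown. The config dict and layer names are illustrative only, loosely modelled on what _patch_quantization_config injects for fp8:

```python
from types import SimpleNamespace

from lmdeploy.pytorch.config import QuantizationConfig

# Stand-in for a HF config object carrying a `quantization_config` attribute.
hf_config = SimpleNamespace(quantization_config=dict(
    quant_method='fp8',
    fmt='e4m3',
    weight_block_size=[128, 128],
    modules_to_not_convert=['lm_head'],
))

quant_cfg = QuantizationConfig.from_config(hf_config)
print(quant_cfg.quant_dtype)        # torch.float8_e4m3fn
print(quant_cfg.weight_block_size)  # [128, 128]

# get_quant_method consults ignored_layers (filled from modules_to_not_convert here).
print(quant_cfg.get_quant_method('lm_head'))                     # None -> left unquantized
print(quant_cfg.get_quant_method('model.layers.0.mlp.up_proj'))  # 'fp8'
```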
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/executor/__init__.py
@@ -78,6 +78,7 @@ def build_executor(
dist_config=dist_config,
is_draft_model=False,
spec_method=None if specdecode_config is None else specdecode_config.method,
model_format=misc_config.model_format,
)

if distributed_executor_backend is None:
7 changes: 2 additions & 5 deletions lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1066,11 +1066,9 @@ def _build_model(self):
dllm_config=self.misc_config.dllm_config,
strategy_factory=self.strategy_factory,
enable_return_routed_experts=enable_return_routed_experts,
quant_config=self.model_config.quant_config,
)
patched_model = build_patched_model(self.model_config,
device=device,
model_format=self.misc_config.model_format,
build_model_ctx=build_model_ctx)
patched_model = build_patched_model(self.model_config, device=device, build_model_ctx=build_model_ctx)
logger.debug(msg_with_rank(rank, 'loading weights.'))
if not self.misc_config.empty_init:
load_model_weights(patched_model, model_path, device=device)
@@ -1086,7 +1084,6 @@ def build_model(self):
self._build_model()
self.spec_agent.build_model(self.misc_config.empty_init,
self.patched_model,
model_format=self.misc_config.model_format,
build_model_ctx=self.build_model_ctx)

def build_graph_runner(self):
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/model_inputs.py
@@ -10,7 +10,7 @@
# from torch import distributed as dist
import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.backends import get_backend
from lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig
from lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig, QuantizationConfig
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
from lmdeploy.pytorch.utils import CtxMgrBase, singleton

@@ -390,6 +390,7 @@ class BuildModelContext:
dllm_config: DLLMConfig = None
strategy_factory: 'StrategyFactoryBase' = None
enable_return_routed_experts: bool = False
quant_config: QuantizationConfig = field(default_factory=QuantizationConfig)


class StepContextManager(CtxMgrBase[StepContext]):
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/glm4_1v.py
@@ -661,7 +661,8 @@ def prepare_inputs_for_generation(
image_mask=image_mask,
)

def rename_weight(self, name: str) -> str:
@classmethod
def rename_weight(cls, name: str) -> str:
"""Rename weight."""
if name.startswith('model.language_model.'):
return 'language_model.' + name[len('model.language_model.'):]
14 changes: 10 additions & 4 deletions lmdeploy/pytorch/models/interns1_pro.py
@@ -12,7 +12,7 @@
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .interns1_pro_ts import InternS1ProTimeSeriesModel
from .patch import get_build_model_context
from .patch import add_prefix, get_build_model_context
from .qwen3_moe import Qwen3MoeModel
from .qwen3_vl import Qwen3VLVisionModel
from .utils.cudagraph import CudaGraphMixin
@@ -38,7 +38,8 @@ def __init__(self,
config: PretrainedConfig,
ctx_mgr: StepContextManager,
dtype: torch.dtype = None,
device: torch.device = None):
device: torch.device = None,
prefix: str = ''):
super().__init__()

self.config = config
@@ -52,10 +53,14 @@
config.vision_config,
dtype=dtype,
device=device,
prefix=add_prefix('visual', prefix=prefix),
)

# build text model
self.language_model = Qwen3MoeModel(config.text_config, dtype=dtype, device=device)
self.language_model = Qwen3MoeModel(config.text_config,
dtype=dtype,
device=device,
prefix=add_prefix('language_model', prefix=prefix))

# build lm_head
self.lm_head = build_rowwise_linear(config.text_config.hidden_size,
@@ -233,7 +238,8 @@ def prepare_inputs_for_generation(
ts_mask=ts_mask,
)

def rename_weight(self, name: str) -> str:
@classmethod
def rename_weight(cls, name: str) -> str:
"""Rename weight."""
if name.startswith('model.language_model.'):
return 'language_model.' + name[len('model.language_model.'):]
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/models/internvl3_hf.py
@@ -666,7 +666,8 @@ def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter

return load_lora_weights(weights, adapter_id)

def rename_weight(self, name: str) -> str:
@classmethod
def rename_weight(cls, name: str) -> str:
"""Rename weight."""
if name == 'lm_head.weight':
return 'language_model.lm_head.weight'
34 changes: 10 additions & 24 deletions lmdeploy/pytorch/models/patch.py
@@ -197,38 +197,19 @@ def build_model_from_hf_config(model_config: PretrainedConfig,
if device is None:
device = torch.device('cuda')
model_cls = _get_model_class(model_config, module_map)
# update quant config
if build_model_ctx is not None and hasattr(model_cls, 'update_quant_config'):
build_model_ctx.quant_config = model_cls.update_quant_config(build_model_ctx.quant_config)

with build_model_context(build_model_ctx):
model = model_cls(model_config, ctx_mgr, dtype=dtype, device=device)
return model.eval()


def _patch_quantization_config(model_config: PretrainedConfig, model_format: str):
"""Patch quantization config."""
if model_format is None:
return

if hasattr(model_config, 'quantization_config'):
logger.warning('Can not perform weight quantization on quantized model.')
return

if model_format == 'fp8':
logger.debug('Patch quantization config for fp8.')
from lmdeploy.pytorch.envs import scale_fmt
quantization_config = dict(quant_method='fp8', fmt='e4m3', weight_block_size=[128, 128], scale_fmt=scale_fmt)
else:
raise RuntimeError(f'Unsupported weight quantization method: {model_format}')
model_config.quantization_config = quantization_config


@torch.inference_mode()
def build_patched_model(config: ModelConfig,
device: torch.device = None,
model_format: str = None,
build_model_ctx: 'BuildModelContext' = None):
def build_patched_model(config: ModelConfig, device: torch.device = None, build_model_ctx: 'BuildModelContext' = None):
"""Build patched model."""
model_config = config.hf_config
llm_config = config.llm_config
_patch_quantization_config(llm_config, model_format)
dtype = config.dtype
return build_model_from_hf_config(model_config, dtype=dtype, device=device, build_model_ctx=build_model_ctx)

@@ -353,3 +334,8 @@ def get_build_model_context() -> BuildModelContext:
"""Get build model context."""
global BUILD_MODEL_CTX
return BUILD_MODEL_CTX


def add_prefix(name: str, prefix: str) -> str:
"""Add prefix to module name."""
return name if not prefix else f'{prefix}.{name}'
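A tiny usage note for the new add_prefix helper (the module names below are illustrative):

```python
from lmdeploy.pytorch.models.patch import add_prefix

print(add_prefix('visual', ''))               # 'visual' -- empty prefix is dropped
print(add_prefix('language_model', 'model'))  # 'model.language_model'
```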