@@ -323,9 +323,20 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
             cfg.setdefault(key, copy.deepcopy(default_dict.get(key)))
 
     # 5. collect supported modules
+    embedding_types = (torch.nn.Embedding,)
     gguf_name = get_gguf_scheme(default_scheme)
-    if gguf_name and torch.nn.Embedding not in supported_types:
-        supported_types = (*supported_types, torch.nn.Embedding)
+    if gguf_name:
+        if torch.nn.Embedding not in supported_types:
+            supported_types = (*supported_types, torch.nn.Embedding)
+
+        # handle embedding classes whose type() is not torch.nn.Embedding,
+        # e.g. transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding
+        model_module_name = model.__class__.__module__
+        module_cls = sys.modules[model_module_name]
+        for name in module_cls.__dict__:
+            if name.endswith("Embedding") and not name.endswith("RotaryEmbedding"):
+                embedding_types = (*embedding_types, getattr(module_cls, name))
+        supported_types = (*supported_types, *embedding_types)
 
     all_supported_layer_names, embedding_layer_names = [], []
     all_module_names = []
@@ -338,7 +349,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
         if type(m) not in supported_types and m.__class__.__name__ not in inner_supported_types:
             continue
         all_supported_layer_names.append(n)
-        if isinstance(m, torch.nn.Embedding):
+        if isinstance(m, embedding_types) or m.__class__.__name__.endswith("Embedding"):
             embedding_layer_names.append(n)
 
     # 6. expand regex configs
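The hunk above widens embedding detection when a GGUF scheme is selected: besides `torch.nn.Embedding`, it also picks up embedding classes defined in the model's own module (e.g. `Gemma3TextScaledWordEmbedding`) and then treats a module as an embedding layer if it matches one of those types or its class name ends with `Embedding`. A minimal standalone sketch of that discovery step; the helper name is illustrative and not part of the patch:

```python
import sys

import torch


def collect_embedding_types(model) -> tuple:
    """Gather torch.nn.Embedding plus any *Embedding classes defined in the
    model's own module, skipping rotary-embedding helpers, mirroring the
    logic added in the hunk above."""
    embedding_types = (torch.nn.Embedding,)
    module = sys.modules[model.__class__.__module__]
    for name in module.__dict__:
        if name.endswith("Embedding") and not name.endswith("RotaryEmbedding"):
            embedding_types = (*embedding_types, getattr(module, name))
    return embedding_types


# A module m is then classified as an embedding layer via:
#     isinstance(m, embedding_types) or m.__class__.__name__.endswith("Embedding")
```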
@@ -650,7 +661,7 @@ def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model
 
     import gguf  # pylint: disable=E0401
 
-    from auto_round.utils.common import LazyImport
+    from auto_round.utils.common import MM_KEYS, LazyImport
     from auto_round.utils.model import get_lm_head_name, get_module
 
     # from auto_round.export.export_to_gguf.convert import ModelBase, get_model_architecture
@@ -660,24 +671,41 @@ def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model
         hparams=model.config.to_dict(), model_type=model_type
     )
     try:
-        model_class = convert_hf_to_gguf.ModelBase.from_model_architecture(model_architecture, model_type=model_type)
+        if model_type != ModelType.TEXT:
+            model_class_vision = convert_hf_to_gguf.ModelBase.from_model_architecture(
+                model_architecture, model_type=model_type
+            )
+        model_class = convert_hf_to_gguf.ModelBase.from_model_architecture(
+            model_architecture, model_type=ModelType.TEXT
+        )
+
     except NotImplementedError:
         return layer_config, {}
 
     n_layer = None
-    for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"]:
-        sub_attr = "text_config" if model_type == ModelType.TEXT else "vision_config"
+    if model_type != ModelType.TEXT:
+        n_layer_vision = None
+    for name in ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]:
         if hasattr(model.config, name):
             n_layer = getattr(model.config, name)
-            break
-        if hasattr(model.config, sub_attr):
-            if hasattr(getattr(model.config, sub_attr), name):
-                n_layer = getattr(getattr(model.config, sub_attr), name)
+        if model_type != ModelType.TEXT:
+            if n_layer is not None and hasattr(model.config, "text_config"):
+                if hasattr(getattr(model.config, "text_config"), name):
+                    n_layer = getattr(getattr(model.config, "text_config"), name)
+            for config_name in ["vision_config", "vision_encoder"]:
+                if hasattr(model.config, config_name):
+                    if hasattr(getattr(model.config, config_name), name):
+                        n_layer_vision = getattr(getattr(model.config, config_name), name)
+                        break
+            if n_layer and n_layer_vision:
                 break
+
     if n_layer is None:
         return layer_config, {}
 
     tensor_map = gguf.get_tensor_name_map(model_class.model_arch, n_layer)
+    if model_type != ModelType.TEXT:
+        tensor_map_vision = gguf.get_tensor_name_map(model_class_vision.model_arch, n_layer_vision)
 
     def _set_config(config, target_config):
         for k, v in target_config.items():
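For multimodal models the changes above resolve two block counts: the text count from the top-level config or `config.text_config`, and the vision count from `config.vision_config` / `config.vision_encoder`, then build a separate gguf tensor-name map for each. A simplified, self-contained sketch of that lookup follows; it always prefers a nested sub-config value when one exists, so it is close to, but not byte-identical with, the patch, and the helper name and toy config are illustrative:

```python
from types import SimpleNamespace

LAYER_COUNT_KEYS = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]


def find_block_counts(config, is_text_only: bool):
    """Return (n_layer, n_layer_vision), roughly following the lookup order above."""
    n_layer = n_layer_vision = None
    for name in LAYER_COUNT_KEYS:
        if hasattr(config, name):
            n_layer = getattr(config, name)
        if is_text_only:
            continue
        # multimodal configs usually nest the text block count under text_config ...
        if n_layer is None and hasattr(getattr(config, "text_config", None), name):
            n_layer = getattr(config.text_config, name)
        # ... and the vision block count under vision_config / vision_encoder
        for sub in ("vision_config", "vision_encoder"):
            if hasattr(getattr(config, sub, None), name):
                n_layer_vision = getattr(getattr(config, sub), name)
                break
        if n_layer and n_layer_vision:
            break
    return n_layer, n_layer_vision


# e.g. a Gemma3-style multimodal config
cfg = SimpleNamespace(
    text_config=SimpleNamespace(num_hidden_layers=34),
    vision_config=SimpleNamespace(depth=27),
)
print(find_block_counts(cfg, is_text_only=False))  # (34, 27)
```

Each count then seeds its own name map, as in the diff: one from the text model class and, for non-text models, one from the vision model class.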
@@ -733,7 +761,17 @@ def _set_config(config, target_config):
                 re.search("gguf:q([0-9]{1,})_[01k]", GGUF_CONFIG[target_gguf_format]["embedding"]).group(1)
             )
 
-        gguf_name = tensor_map.get_name(layer_name)
+        if model_type != ModelType.TEXT and any([key in layer_name for key in MM_KEYS]):
+            gguf_name = tensor_map_vision.get_name(layer_name)
+            if gguf_name is None:
+                for key in MM_KEYS:
+                    gguf_name = tensor_map_vision.get_name(layer_name.replace(f".{key}", ""))
+                    if gguf_name is not None:
+                        break
+        else:
+            gguf_name = tensor_map.get_name(layer_name)
+            if gguf_name is None:
+                gguf_name = tensor_map.get_name(layer_name.replace(".language_model", ""))
         bits_index = 6
         if config.get("fixed_by_user", False):
             if "bits" not in config:
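The final hunk routes vision-tower layers to the vision tensor map and adds a retry that strips the multimodal name segment (or `.language_model` on the text path) before giving up. A rough sketch of that fallback; `MM_KEYS` is imported from `auto_round.utils.common` in the patch, so the values shown below are placeholders, and the helper name is illustrative:

```python
# Stand-in for auto_round.utils.common.MM_KEYS; illustrative values only.
MM_KEYS = ["visual", "vision_tower", "multi_modal_projector"]


def resolve_gguf_name(layer_name, tensor_map, tensor_map_vision=None, is_text_only=True):
    """Mirror the lookup-with-fallback above: vision-tower layers are retried with
    the multimodal key segment stripped, text layers with '.language_model'
    stripped. The map objects only need a get_name(str) -> str | None method."""
    if not is_text_only and tensor_map_vision is not None and any(key in layer_name for key in MM_KEYS):
        gguf_name = tensor_map_vision.get_name(layer_name)
        if gguf_name is None:
            for key in MM_KEYS:
                gguf_name = tensor_map_vision.get_name(layer_name.replace(f".{key}", ""))
                if gguf_name is not None:
                    break
        return gguf_name
    gguf_name = tensor_map.get_name(layer_name)
    if gguf_name is None:
        gguf_name = tensor_map.get_name(layer_name.replace(".language_model", ""))
    return gguf_name
```

Stripping the multimodal segment (and `.language_model` on the text side) lets HF-style nested layer names match the flat names the gguf tensor maps were generated for.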