Skip to content

Add support for deepseek_r1_qwen based models #1884

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2448,6 +2448,119 @@ def get_falcon_spec(self, model):
self._num_kv_attr = "num_kv_heads"


@register_loader("Qwen2Config")
class DeepSeekR1DistillQwen2Loader(ModelLoader):
@property
def architecture_name(self):
return "Qwen2ForCausalLM"

def get_model_spec(self, model):
num_layers = model.config.num_hidden_layers

num_heads = model.config.num_attention_heads
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
if num_heads_kv == num_heads:
num_heads_kv = None

sliding_window = getattr(model.config, "sliding_window", 0)
max_window_layers = getattr(model.config, "max_window_layers", None)

rope_scaling = getattr(model.config, "rope_scaling", None)
if rope_scaling:
rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
rotary_scaling_factor = rope_scaling["factor"]

if rotary_scaling_type is None:
raise NotImplementedError(
"RoPE scaling type '%s' is not yet implemented. "
"The following RoPE scaling types are currently supported: %s"
% (rope_type, ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
)
else:
rotary_scaling_type = None
rotary_scaling_factor = 1

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=common_spec.Activation.SWISH,
pre_norm=True,
ffn_glu=True,
rms_norm=True,
rotary_dim=0,
rotary_interleave=False,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
sliding_window=sliding_window if sliding_window > 0 else None,
head_dim=getattr(model.config, "head_dim", None),
)

self.set_decoder(spec.decoder, model.model)
self.set_linear(spec.decoder.projection, model.lm_head)
return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)
return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = (
tokenizer.bos_token
if tokenizer.bos_token is not None
else tokenizer.pad_token
)
config.eos_token = tokenizer.eos_token
config.unk_token = (
tokenizer.unk_token if tokenizer.unk_token is not None else ""
)
config.layer_norm_epsilon = model.config.rms_norm_eps

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight

def set_decoder(self, spec, module):
spec.scale_embeddings = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
self.set_layer_norm(spec.layer_norm, module.norm)

for layer_spec, layer in zip(spec.layer, module.layers):
self.set_layer_norm(
layer_spec.self_attention.layer_norm, layer.input_layernorm
)
self.set_layer_norm(
layer_spec.ffn.layer_norm, layer.post_attention_layernorm
)

split_layers = [common_spec.LinearSpec() for _ in range(3)]
self.set_linear(split_layers[0], layer.self_attn.q_proj)
self.set_linear(split_layers[1], layer.self_attn.k_proj)
self.set_linear(split_layers[2], layer.self_attn.v_proj)

utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
self.set_linear(
layer_spec.self_attention.linear[1],
layer.self_attn.o_proj,
)

self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

delattr(layer, "self_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("DistilBertConfig")
class DistilBertLoader(ModelLoader):
@property
Expand Down
Loading