Skip to content

Add exaone support #1881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2448,6 +2448,112 @@ def get_falcon_spec(self, model):
self._num_kv_attr = "num_kv_heads"


@register_loader("ExaoneConfig")
class ExaoneLoader(ModelLoader):
@property
def architecture_name(self):
return "ExaoneForCausalLM"

def get_model_spec(self, model):
num_layers = model.config.num_layers
num_heads = model.config.num_attention_heads
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
if num_heads_kv == num_heads:
num_heads_kv = None

rope_scaling = getattr(model.config, "rope_scaling", None)
if rope_scaling:
rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
rotary_scaling_factor = rope_scaling.get("factor", 1)

if rotary_scaling_type is None:
raise NotImplementedError(
f"RoPE scaling type '{rope_type}' is not yet implemented. "
f"These RoPE types are supported: {', '.join(_SUPPORTED_ROPE_SCALING.keys())}"
)
else:
rotary_scaling_type = None
rotary_scaling_factor = 1

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=common_spec.Activation.SWISH,
pre_norm=True,
ffn_glu=True,
rms_norm=True,
rotary_dim=0,
rotary_interleave=False,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
head_dim=getattr(model.config, "head_dim", None),
)

self.set_decoder(spec.decoder, model.transformer)
self.set_linear(spec.decoder.projection, model.lm_head)

if rotary_scaling_type == attention_spec.RotaryScalingType.Llama3:
for layer in spec.decoder.layer:
layer.self_attention.rotary_low_freq_factor = rope_scaling["low_freq_factor"]
layer.self_attention.rotary_high_freq_factor = rope_scaling["high_freq_factor"]

return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append(f"<extra_id_{i}>")

if model.config.vocab_size < len(tokens):
tokens = tokens[:model.config.vocab_size]

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.eos_token = tokenizer.eos_token
config.unk_token = tokenizer.unk_token
config.layer_norm_epsilon = model.config.layer_norm_epsilon

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight

def set_decoder(self, spec, transformer):
spec.scale_embeddings = False
self.set_embeddings(spec.embeddings, transformer.wte)
self.set_layer_norm(spec.layer_norm, transformer.ln_f)

for layer_spec, layer in zip(spec.layer, transformer.h):
self.set_layer_norm(
layer_spec.self_attention.layer_norm, layer.ln_1
)
self.set_layer_norm(
layer_spec.ffn.layer_norm, layer.ln_2
)

split_layers = [common_spec.LinearSpec() for _ in range(3)]
self.set_linear(split_layers[0], layer.attn.attention.q_proj)
self.set_linear(split_layers[1], layer.attn.attention.k_proj)
self.set_linear(split_layers[2], layer.attn.attention.v_proj)

utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
self.set_linear(
layer_spec.self_attention.linear[1], layer.attn.attention.out_proj
)

self.set_linear(layer_spec.ffn.linear_0, layer.mlp.c_fc_0)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.c_fc_1)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.c_proj)


@register_loader("DistilBertConfig")
class DistilBertLoader(ModelLoader):
@property
Expand Down
Loading