Add RoPE and ALiBi to the encoder #1687

Open · wants to merge 3 commits into base: master
28 changes: 24 additions & 4 deletions python/ctranslate2/converters/opennmt_py.py
@@ -24,6 +24,8 @@ def check_opt(opt, num_source_embeddings):
activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
feat_merge = getattr(opt, "feat_merge", "concat")
self_attn_type = getattr(opt, "self_attn_type", "scaled-dot")
if self_attn_type == "scaled-dot-flash":
self_attn_type = "scaled-dot"

check = utils.ConfigurationChecker()
check(
@@ -60,8 +62,20 @@ def _get_model_spec_seq2seq(
):
"""Creates a model specification from the model options."""
with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
with_rotary = getattr(opt, "max_relative_positions", 0) == -1
with_alibi = getattr(opt, "max_relative_positions", 0) == -2
activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
num_heads = getattr(opt, "heads", 8)
num_kv = getattr(opt, "num_kv", 0)
if num_kv == num_heads or num_kv == 0:
num_kv = None
rotary_dim = 0 if with_rotary else None
rotary_interleave = getattr(opt, "rotary_interleave", True)
ffn_glu = (activation_fn == "silu") or (activation_fn == "gated-gelu")
sliding_window = getattr(opt, "sliding_window", 0)

feat_merge = getattr(opt, "feat_merge", "concat")
layer_norm = getattr(opt, "layer_norm", "standard")

# Return the first head of the last layer unless the model was trained with alignments.
if getattr(opt, "lambda_align", 0) == 0:
@@ -71,20 +85,26 @@ def _get_model_spec_seq2seq(
alignment_layer = opt.alignment_layer
alignment_heads = opt.alignment_heads

num_heads = getattr(opt, "heads", 8)

model_spec = transformer_spec.TransformerSpec.from_config(
(opt.enc_layers, opt.dec_layers),
num_heads,
with_relative_position=with_relative_position,
activation=_SUPPORTED_ACTIVATIONS[activation_fn],
ffn_glu=ffn_glu,
with_relative_position=with_relative_position,
alibi=with_alibi,
rms_norm=layer_norm == "rms",
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
multi_query_attention=getattr(opt, "multiquery", False),
num_heads_kv=num_kv,
sliding_window=sliding_window,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
num_source_embeddings=num_source_embeddings,
embeddings_merge=_SUPPORTED_FEATURES_MERGE[feat_merge],
multi_query_attention=getattr(opt, "multiquery", False),
)

model_spec.config.layer_norm_epsilon = getattr(opt, "norm_eps", 1e-6)
model_spec.config.decoder_start_token = getattr(opt, "decoder_start_token", "<s>")

set_transformer_spec(model_spec, variables)
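
The converter change above is the key mapping: OpenNMT-py encodes the position-encoding mode in max_relative_positions (a positive value for relative positions, -1 for rotary embeddings, -2 for ALiBi). A minimal standalone sketch of that mapping, for reference only (the position_mode helper and the SimpleNamespace opt are illustrative, not part of this patch):

from types import SimpleNamespace

def position_mode(opt):
    """Map OpenNMT-py's max_relative_positions option to a position-encoding mode.

    Mirrors the converter logic above: > 0 -> relative positions,
    -1 -> rotary embeddings (RoPE), -2 -> ALiBi, 0 -> sinusoidal encodings.
    """
    value = getattr(opt, "max_relative_positions", 0)
    if value > 0:
        return "relative"
    if value == -1:
        return "rotary"
    if value == -2:
        return "alibi"
    return "sinusoidal"

# Example: a model trained with ALiBi position biases.
opt = SimpleNamespace(max_relative_positions=-2)
assert position_mode(opt) == "alibi"
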
117 changes: 113 additions & 4 deletions python/ctranslate2/specs/transformer_spec.py
@@ -22,7 +22,20 @@ def __init__(
relative_attention_bias: bool = False,
ffn_glu: bool = False,
rms_norm: bool = False,
alibi: bool = False,
alibi_use_positive_positions: bool = False,
scale_alibi: bool = False,
rotary_dim: Optional[int] = None,
rotary_interleave: bool = True,
rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
rotary_scaling_factor: float = 1,
rotary_base: float = 10000,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
):
"""Initializes a Transformer encoder specification.

@@ -43,9 +56,30 @@ def __init__(
ffn_glu: Use gated linear units in the FFN layers as described in
https://arxiv.org/abs/2002.05202.
rms_norm: Use the root mean square layer normalization.
multi_query_attention: Use multi-query attention.
alibi: Use attention with linear biases.
alibi_use_positive_positions: Use positive positions in the ALiBi definition.
scale_alibi: Apply the dot product scale factor to ALiBi.
rotary_dim: Apply rotary embeddings to these first N dimensions. If 0, rotary
embeddings are applied to all dimensions.
rotary_interleave: Interleave the head dimensions when rotary embeddings are applied.
Otherwise the head dimensions are sliced in half.
rotary_scaling_type: Type of RoPE scaling.
rotary_scaling_factor: Factor used in the RoPE scaling.
rotary_base: The base period of the rotary embeddings.
parallel_residual: Use parallel residual connections in each layer block, as used
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
attention layer norms.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
sliding_window: Max sequence length to retain in KV Cache.
"""
self.multi_query_attention = multi_query_attention
if multi_query_attention:
if num_heads_kv is not None and num_heads_kv != 1:
raise ValueError(
"Enabling multi_query_attention implies num_heads_kv=1"
)
num_heads_kv = 1
self.num_heads = np.dtype("int16").type(num_heads)
self.pre_norm = pre_norm
self.activation = np.dtype("int8").type(activation)
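
Background for the alibi and scale_alibi options documented in the docstring above: ALiBi drops positional embeddings and instead adds a fixed, per-head linear bias to the attention logits. A minimal sketch of the standard formulation from the ALiBi paper (illustrative only, not the CTranslate2 kernel; the symmetric bias -|i - j| is an assumption for the bidirectional encoder case, and the slope formula assumes a power-of-two head count):

import numpy as np

def alibi_biases(num_heads: int, length: int) -> np.ndarray:
    # Per-head slopes form a geometric sequence that starts at 2^(-8 / num_heads)
    # and uses that same value as its ratio (ALiBi paper, power-of-two head counts).
    ratio = 2.0 ** (-8.0 / num_heads)
    slopes = np.array([ratio ** (h + 1) for h in range(num_heads)])

    # Bidirectional (encoder) bias: slope * -|i - j| for query position i, key position j.
    positions = np.arange(length)
    distances = -np.abs(positions[None, :] - positions[:, None])  # (length, length)
    return slopes[:, None, None] * distances[None, :, :]          # (num_heads, length, length)

# The bias is added to the attention logits before the softmax, e.g. for head h:
#   scores[h] = (Q[h] @ K[h].T) / sqrt(head_dim) + alibi_biases(num_heads, length)[h]
# Per the docstring above, scale_alibi controls whether the dot-product scale factor
# is also applied to the bias term.
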
@@ -54,7 +88,17 @@ def __init__(
common_spec.EmbeddingsSpec() for _ in range(num_source_embeddings)
]
self.scale_embeddings = True
if not relative_position and not relative_attention_bias:
self.alibi = alibi
self.alibi_use_positive_positions = alibi_use_positive_positions
self.scale_alibi = scale_alibi
if sliding_window is not None:
self.sliding_window = np.dtype("int32").type(sliding_window)
if (
not relative_position
and not relative_attention_bias
and not alibi
and rotary_dim is None
):
self.position_encodings = PositionEncoderSpec()
if pre_norm and not no_final_norm:
self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
@@ -66,10 +110,22 @@ def __init__(
relative_attention_bias=relative_attention_bias,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
num_heads_kv=1 if multi_query_attention else None,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
)
for _ in range(num_layers)
]
self.multi_query_attention = multi_query_attention or (
num_heads_kv != num_heads
)
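
To illustrate the branch above: once alibi is set or rotary_dim is given, the encoder spec no longer attaches a sinusoidal position encoder, since positions are handled inside the attention layers. A rough usage sketch of the expected behaviour (not a test from this PR; it assumes the upstream num_layers/num_heads leading arguments that sit outside this hunk):

from ctranslate2.specs import transformer_spec

# Default encoder: sinusoidal position encodings are created.
base = transformer_spec.TransformerEncoderSpec(6, 8)
assert hasattr(base, "position_encodings")

# ALiBi encoder: the bias lives in the attention layers, so no position encoder.
with_alibi = transformer_spec.TransformerEncoderSpec(6, 8, alibi=True)
assert not hasattr(with_alibi, "position_encodings")

# Rotary encoder: rotary_dim=0 means "rotate all head dimensions"; also no position encoder.
with_rope = transformer_spec.TransformerEncoderSpec(6, 8, rotary_dim=0)
assert not hasattr(with_rope, "position_encodings")
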


class TransformerDecoderSpec(model_spec.LayerSpec):
@@ -224,15 +280,29 @@ def __init__(
relative_attention_bias=False,
ffn_glu=False,
rms_norm=False,
rotary_dim=None,
rotary_interleave=True,
rotary_scaling_type=None,
rotary_scaling_factor=1,
rotary_base=10000,
parallel_residual=False,
shared_layer_norm=False,
num_heads_kv=None,
head_dim=None,
sliding_window=None,
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
relative_position=relative_position,
relative_attention_bias=relative_attention_bias,
rms_norm=rms_norm,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
)
self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
@@ -364,7 +434,20 @@ def from_config(
relative_attention_bias: bool = False,
ffn_glu: bool = False,
rms_norm: bool = False,
alibi: bool = False,
alibi_use_positive_positions: bool = False,
scale_alibi: bool = False,
rotary_dim: Optional[int] = None,
rotary_interleave: bool = True,
rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
rotary_scaling_factor: float = 1,
rotary_base: float = 10000,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
):
"""Creates a Transformer model specification.

@@ -408,7 +491,20 @@ def from_config(
relative_attention_bias=relative_attention_bias,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
alibi=alibi,
alibi_use_positive_positions=alibi_use_positive_positions,
scale_alibi=scale_alibi,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
multi_query_attention=multi_query_attention,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
)

decoder = TransformerDecoderSpec(
@@ -424,7 +520,20 @@ def from_config(
alignment_heads=alignment_heads,
ffn_glu=ffn_glu,
rms_norm=rms_norm,
alibi=alibi,
alibi_use_positive_positions=alibi_use_positive_positions,
scale_alibi=scale_alibi,
rotary_dim=rotary_dim,
rotary_interleave=rotary_interleave,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
multi_query_attention=multi_query_attention,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
)

return cls(encoder, decoder)
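
End-to-end, the converter only has to forward the new flags: a checkpoint trained with max_relative_positions = -2 maps to alibi=True, and -1 maps to rotary embeddings, on both the encoder and the decoder. A hypothetical from_config call sketching that path (layer and head counts are illustrative, not taken from a real checkpoint):

from ctranslate2.specs import transformer_spec

# ALiBi-trained model (max_relative_positions == -2 in OpenNMT-py):
alibi_spec = transformer_spec.TransformerSpec.from_config(
    (6, 6),      # (encoder layers, decoder layers)
    8,           # attention heads
    alibi=True,  # linear attention biases in both encoder and decoder
)

# RoPE-trained model (max_relative_positions == -1) would instead pass:
rope_spec = transformer_spec.TransformerSpec.from_config(
    (6, 6),
    8,
    rotary_dim=0,            # 0 = rotate all head dimensions
    rotary_interleave=True,  # OpenNMT-py default read by the converter above
)
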