Add linear biases changes for Qwen2 #160

Merged
merged 4 commits on Feb 26, 2025
23 changes: 11 additions & 12 deletions fast_llm/layers/transformer/attention.py
@@ -10,7 +10,11 @@
from fast_llm.functional.rotary import apply_rotary_embeddings
from fast_llm.functional.triton.rotary import triton_rotary_autograd_
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
from fast_llm.layers.transformer.config import (
TransformerConfig,
TransformerDimNames,
TransformerKwargs,
)
from fast_llm.logging import log_distributed_grad, log_distributed_tensor
from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_
from fast_llm.utils import Assert
@@ -102,7 +106,7 @@ def __init__(
self.query = OutputParallelLinear(
hidden_dim,
self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query),
bias=self._config.add_linear_biases,
bias=self._config.add_attn_qkv_bias,
weight_init_method=init_method_qkv,
bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_,
sequence_parallel=self._sequence_parallel,
@@ -111,7 +115,7 @@
self.key_value = OutputParallelLinear(
hidden_dim,
self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value),
bias=self._config.add_linear_biases,
bias=self._config.add_attn_qkv_bias,
weight_init_method=init_method_qkv,
bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_,
sequence_parallel=self._sequence_parallel,
@@ -123,7 +127,7 @@
self.dense = InputParallelLinear(
self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense),
hidden_dim,
bias=self._config.add_linear_biases,
bias=self._config.add_attn_dense_bias,
weight_init_method=init_method_std_attn_proj,
bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_,
sequence_parallel=self._sequence_parallel,
@@ -274,18 +278,14 @@ def _query_key_value_backward(
input_grad.add_(self.key_value.backward(key_grad, context.pop("key_value")))
return input_grad


def _decide_window_size(self) -> int | None:
# NOTE: This is a temporary solution for Qwen 2.X
# https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71
# TODO: make universal per layer config
window_size = self._config.window_size
if (
self._config.max_window_layers is not None
and self._layer_index < self._config.max_window_layers
):
if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers:
window_size = None

return window_size

def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -337,9 +337,8 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])


window_size = self._decide_window_size()

if self._use_flash_attention:
input_ = flash_attn(
query,
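The new `_decide_window_size` helper mirrors the Qwen2 rule referenced in the NOTE above: layers whose index is below `max_window_layers` fall back to full attention, and only the remaining layers use the sliding window. A minimal standalone sketch of that rule (illustrative only, not part of Fast-LLM; the argument names follow the config fields in this diff):

```python
def decide_window_size(
    layer_index: int, window_size: int | None, max_window_layers: int | None
) -> int | None:
    # Layers below max_window_layers use full attention (no sliding window),
    # matching the Qwen2 reference implementation linked in the NOTE.
    if max_window_layers is not None and layer_index < max_window_layers:
        return None
    return window_size


# With window_size=512 and max_window_layers=2:
assert decide_window_size(1, 512, 2) is None    # early layer: full attention
assert decide_window_size(2, 512, 2) == 512     # later layer: sliding window
assert decide_window_size(1, 512, None) == 512  # no cap: always use window_size
```
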
49 changes: 44 additions & 5 deletions fast_llm/layers/transformer/config.py
@@ -119,12 +119,12 @@ class RotaryArchitectureConfig(BaseModelArchitectureConfig):
hint=FieldHint.feature,
)
beta_fast: float = Field(
default=32.,
default=32.0,
desc="Beta-fast for yarn-type scaling.",
hint=FieldHint.feature,
)
beta_slow: float = Field(
default=1.,
default=1.0,
desc="Beta-slow for yarn-type scaling.",
hint=FieldHint.feature,
)
@@ -149,6 +149,12 @@ class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig):
pass


class AddLinearBiasChoices(str, enum.Enum):
nowhere = "nowhere"
everywhere = "everywhere"
only_attn_qkv = "only_attn_qkv"


@config_class()
class TransformerArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False
@@ -174,7 +180,11 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
add_linear_biases: bool = Field(default=True, desc="Add biases to all dense layers.", hint=FieldHint.core)
add_linear_biases: bool | AddLinearBiasChoices = Field(
default=True,
desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.",
hint=FieldHint.core,
)
ffn_hidden_size: int = Field(
default=None,
desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.",
@@ -243,14 +253,40 @@ def _validate(self) -> None:
self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu
self.projection_size = self.num_attention_heads * self.kv_channels
self.num_unshared_experts = self.num_experts - self.num_shared_experts

super()._validate()

if not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")

Assert.leq(self.num_shared_experts, self.num_experts)
Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts)
Assert.multiple(self.num_attention_heads, self.head_groups)

@property
def add_mlp_bias(self) -> bool:
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.everywhere:
return True
return False

@property
def add_attn_qkv_bias(self) -> bool:
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.nowhere:
return False
return True

@property
def add_attn_dense_bias(self) -> bool:
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.everywhere:
return True
return False

@classmethod
def _from_dict(
cls,
@@ -577,8 +613,11 @@ def _validate(self) -> None:
Assert.geq(scale, 0)

def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)

use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in (
DataType.float16,
DataType.bfloat16,
)

# Config parameter `window_size` can only be used with flash attention
if not use_flash_attention:
Assert.is_(self.window_size, None)
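Taken together, the new `AddLinearBiasChoices` enum and the three properties above give per-group control over biases: `add_attn_qkv_bias` covers the query and key/value projections, `add_attn_dense_bias` the attention output projection, and `add_mlp_bias` both MLP layers. A short sketch of the mapping, mirroring the new tests in `tests/test_config.py` (plain booleans keep the previous all-or-nothing behaviour, while `only_attn_qkv` matches the Qwen2 layout):

```python
from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig

# Qwen2-style: biases only on the attention Q/K/V projections.
qwen_style = TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv)
assert qwen_style.add_attn_qkv_bias is True
assert qwen_style.add_attn_dense_bias is False
assert qwen_style.add_mlp_bias is False

# Booleans behave exactly as before: biases everywhere or nowhere.
legacy = TransformerArchitectureConfig(add_linear_biases=True)
assert legacy.add_attn_qkv_bias and legacy.add_attn_dense_bias and legacy.add_mlp_bias
```
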
4 changes: 2 additions & 2 deletions fast_llm/layers/transformer/mlp.py
@@ -42,15 +42,15 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
self.layer_1 = LinearBase(
hidden_dim,
tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp),
bias=config.add_linear_biases,
bias=config.add_mlp_bias,
weight_init_method=init_method_1,
bias_init_method=init_method_1 if config.random_bias_init else init_zeros_,
lr_scale=tuple(config.mlp_lr_scale),
)
self.layer_2 = LinearBase(
self._intermediate_dim,
hidden_dim,
bias=config.add_linear_biases,
bias=config.add_mlp_bias,
weight_init_method=init_method_2,
bias_init_method=init_method_2 if config.random_bias_init else init_zeros_,
auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1,
16 changes: 16 additions & 0 deletions tests/test_attention.py
@@ -1,6 +1,8 @@
import unittest.mock
from fast_llm.layers.transformer.attention import Attention
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.config_utils.tensor_space import TensorSpace


def test_decide_window_size():
@@ -20,3 +22,17 @@ def test_decide_window_size():
# Arrange - Case 3: max_window_layers is None (always return window_size)
attention._config = TransformerConfig(window_size=512, max_window_layers=None)
assert attention._decide_window_size() == 512


def test_attention_constructor():
transformer_conf = TransformerConfig(
num_layers=2,
num_attention_heads=2,
hidden_size=16,
)
distributed_config = DistributedConfig()
tensor_space = TensorSpace(distributed_config=distributed_config)
transformer_conf.setup_tensor_space(tensor_space)

Attention(transformer_conf, tensor_space, 1)

59 changes: 57 additions & 2 deletions tests/test_config.py
@@ -5,10 +5,14 @@
import yaml


from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.utils import Assert
from fast_llm.layers.transformer.config import (
TransformerConfig,
TransformerArchitectureConfig,
AddLinearBiasChoices,
)
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.config import ValidationError

from fast_llm.models.auto import trainer_registry

@@ -84,3 +88,54 @@ def test_do_use_flash_attention():
mock_distributed_config.training_dtype = DataType.float32
with pytest.raises(AssertionError):
config.do_use_flash_attention(mock_distributed_config)


def test_add_linear_biases_valid_values():
# Valid boolean values
assert TransformerArchitectureConfig(add_linear_biases=True).add_linear_biases is True
assert TransformerArchitectureConfig(add_linear_biases=False).add_linear_biases is False

# Valid enum values
assert TransformerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases == AddLinearBiasChoices.nowhere
assert (
TransformerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases
== AddLinearBiasChoices.everywhere
)
assert (
TransformerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases == AddLinearBiasChoices.only_attn_qkv
)


def test_add_linear_biases_invalid_values():
with pytest.raises(ValidationError):
TransformerArchitectureConfig(add_linear_biases="invalid_value")

with pytest.raises(ValidationError):
TransformerArchitectureConfig(add_linear_biases=123)

with pytest.raises(ValidationError):
TransformerArchitectureConfig(add_linear_biases=None)


def test_add_mlp_bias():
assert TransformerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True
assert TransformerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_mlp_bias is True
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_mlp_bias is False
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_mlp_bias is False


def test_add_attn_qkv_bias():
assert TransformerArchitectureConfig(add_linear_biases=True).add_attn_qkv_bias is True
assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True


def test_add_attn_dense_bias():
assert TransformerArchitectureConfig(add_linear_biases=True).add_attn_dense_bias is True
assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias is True
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False
assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias is False
33 changes: 33 additions & 0 deletions tests/test_mlp.py
@@ -0,0 +1,33 @@
from fast_llm.layers.transformer.mlp import MLP
from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.config_utils.tensor_space import TensorSpace


def test_mlp_constructor():
transformer_conf = TransformerConfig(
num_layers=2,
num_attention_heads=2,
hidden_size=16,
)
distributed_config = DistributedConfig()
tensor_space = TensorSpace(distributed_config=distributed_config)
transformer_conf.setup_tensor_space(tensor_space)

MLP(transformer_conf, tensor_space, "name")


def test_moe_mlp_constructor():
transformer_conf = TransformerConfig(
num_layers=2,
num_attention_heads=2,
hidden_size=16,
num_experts=2,
add_linear_biases=False
)
distributed_config = DistributedConfig()
tensor_space = TensorSpace(distributed_config=distributed_config)
transformer_conf.setup_tensor_space(tensor_space)

MixtureOfExpertMLP(transformer_conf, tensor_space, "name")