llama3 rope #55

Merged: 22 commits, Dec 9, 2024
29 changes: 15 additions & 14 deletions fast_llm/config.py
@@ -9,7 +9,7 @@

import yaml

from fast_llm.utils import Assert, Tag, get_type_name, header, log
from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value

logger = logging.getLogger(__name__)

@@ -663,15 +663,7 @@ def from_dict(
if isinstance(update, Config):
update = update._to_dict(format_=_ConfigDictFormat.tuple)
for keys, value in update.items():
if isinstance(keys, str):
default[keys] = value
else:
dict_to_update = default
for key in keys[:-1]:
if key not in dict_to_update:
dict_to_update[key] = {}
dict_to_update = dict_to_update[key]
dict_to_update[keys[-1]] = value
set_nested_dict_value(default, keys, value)

return cls._from_dict(default, strict)

@@ -802,12 +794,21 @@ def _from_dict_dict(cls, value, type_, strict: bool):
return {key: cls._from_dict_nested(value_, args[1], strict) for key, value_ in value.items()}

@classmethod
def _handle_renamed_field(cls, default: dict[str, typing.Any], old_name: str, new_name: str):
def _handle_renamed_field(
cls,
default: dict[str, typing.Any],
old_name: str | tuple[str, ...],
new_name: str | tuple[str, ...],
fn: typing.Callable | None = None,
):
if old_name in default:
warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.")
default[new_name] = default.pop(old_name)
value = pop_nested_dict_value(default, old_name)
if fn is not None:
value = fn(value)
set_nested_dict_value(default, new_name, value)

def compare(self, other: "Config", log_fn: typing.Union[BaseException, typing.Callable] = ValueError):
def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError):
# TODO: Check classes?
self_dict = self._to_dict(format_=_ConfigDictFormat.tuple, serializable=True)
other_dict = other._to_dict(format_=_ConfigDictFormat.tuple, serializable=True)
@@ -824,7 +825,7 @@ def compare(self, other: "Config", log_fn: typing.Union[BaseException, typing.Ca
log(
f"Config diff:\n "
+ "\n ".join(
f"{''.join(key)}`: `{self_value}` != `{other_value}`"
f"{'.'.join(key)}`: `{self_value}` != `{other_value}`"
for key, (self_value, other_value) in diff.items()
),
log_fn=log_fn,
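The helpers used above, `set_nested_dict_value` and `pop_nested_dict_value` (plus `get_nested_dict_value`, used further down in `external.py`), are imported from `fast_llm.utils`, whose implementation is not part of this diff. A minimal sketch consistent with how they are called here, with keys given either as a plain string or as a tuple of nested keys, might look as follows; the real utilities may differ in details such as error handling.

```python
# Sketch of the nested-dict helpers from fast_llm.utils, inferred from how they
# are called in this diff; the actual implementations may differ.
import typing


def set_nested_dict_value(d: dict, keys: str | tuple[str, ...], value: typing.Any) -> None:
    # A plain string behaves like a one-element key tuple.
    if isinstance(keys, str):
        keys = (keys,)
    for key in keys[:-1]:
        d = d.setdefault(key, {})  # Create intermediate dicts as needed.
    d[keys[-1]] = value


def get_nested_dict_value(d: dict, keys: str | tuple[str, ...]) -> typing.Any:
    if isinstance(keys, str):
        keys = (keys,)
    for key in keys:
        d = d[key]  # Raises KeyError if any level is missing.
    return d


def pop_nested_dict_value(d: dict, keys: str | tuple[str, ...]) -> typing.Any:
    if isinstance(keys, str):
        keys = (keys,)
    for key in keys[:-1]:
        d = d[key]
    return d.pop(keys[-1])
```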
2 changes: 1 addition & 1 deletion fast_llm/engine/base_model/config.py
@@ -24,7 +24,7 @@ def get_architecture(self):
def compare_architecture(
self,
model_config: "BaseModelArchitectureConfig",
log_fn: typing.Union[BaseException, typing.Callable] = ValueError,
log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError,
):
return self.get_architecture().compare(model_config.get_architecture(), log_fn)

16 changes: 8 additions & 8 deletions fast_llm/engine/checkpoint/external.py
@@ -19,15 +19,15 @@
from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert
from fast_llm.utils import Assert, get_nested_dict_value, set_nested_dict_value

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ParamConverter:
fast_llm_name: tuple[str, ...] | None
export_name: str | None
export_name: tuple[str, ...] | str | None

def export_param(self, fast_llm_value):
return fast_llm_value
@@ -203,7 +203,7 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
else cls._get_fast_llm_attribute(config, converter.fast_llm_name) # Noqa
)
if converter.export_name is not None:
exported_config[converter.export_name] = value
set_nested_dict_value(exported_config, converter.export_name, value)

return exported_config # Noqa

Expand All @@ -213,11 +213,11 @@ def _import_config(
) -> BaseModelArchitectureConfig: # noqa
kwargs = {}
for converter in cls._get_config_converters():
value = converter.import_param(
None
if converter.export_name is None or converter.export_name not in config
else config[converter.export_name]
)
try:
value = None if converter.export_name is None else get_nested_dict_value(config, converter.export_name)
except KeyError:
value = None
value = converter.import_param(value)
if converter.fast_llm_name is not None:
kwargs[converter.fast_llm_name] = value

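To illustrate the widened `export_name: tuple[str, ...] | str | None`: a tuple now addresses a nested entry of the exported Hugging Face-style config rather than a top-level key, which is useful for llama3 rope scaling, whose Hugging Face parameters live in a nested `rope_scaling` entry. The converter below is purely illustrative and does not appear in this diff.

```python
# Illustrative only: a hypothetical converter with a nested export_name.
from fast_llm.engine.checkpoint.external import ParamConverter
from fast_llm.utils import get_nested_dict_value, set_nested_dict_value

converter = ParamConverter(
    fast_llm_name=("transformer", "rotary", "scale_factor"),  # Hypothetical Fast-LLM path.
    export_name=("rope_scaling", "factor"),  # Nested key in the exported config.
)

exported_config: dict = {}
set_nested_dict_value(exported_config, converter.export_name, 8.0)
assert exported_config == {"rope_scaling": {"factor": 8.0}}

# On import, a missing nested key raises KeyError, which _import_config maps to None.
assert get_nested_dict_value(exported_config, converter.export_name) == 8.0
```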
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/run.py
@@ -248,7 +248,7 @@ def is_main_rank():
return DistributedConfig.default_rank == _MAIN_RANK


def log_main_rank(*message, log_fn: typing.Union[BaseException, typing.Callable] = logger.info, join: str = ", "):
def log_main_rank(*message, log_fn: typing.Union[type[BaseException], typing.Callable] = logger.info, join: str = ", "):
if is_main_rank():
log(*message, log_fn=log_fn, join=join)

27 changes: 0 additions & 27 deletions fast_llm/functional/rotary.py
@@ -1,5 +1,3 @@
import math

import torch

from fast_llm.utils import div
@@ -13,31 +11,6 @@ def convert_rotary_real_to_complex(tensor: torch.Tensor, kv_channels: int, dim:
return tensor.unflatten(dim, (-1, 2, div(kv_channels, 2))).movedim(dim + 1, dim + 2).flatten(dim, dim + 2)


def get_rotary_frequencies(
sequence_length,
kv_channels,
scale=-math.log(10000),
*,
complex_format: bool = True,
device="cuda",
):
# Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/)
# `exp(i * n * a) = cos(n * a) + i sin(n * a)`,
# `a = theta ** - (2 * (channel // 2) / kv_channels)`,
# where n is the position in the sequence.
# We preform the calculation in high precision because it matters for rotary embeddings.
angles = torch.outer(
torch.arange(sequence_length, device=device, dtype=torch.float64),
torch.exp(scale * torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64)),
)
frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64)
if not complex_format:
frequencies = convert_rotary_complex_to_real(
torch.view_as_real(frequencies).flatten(-2), kv_channels, 3
).contiguous()
return frequencies


def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor:
"""
Apply rotary embeddings to a tensor:
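The removed `get_rotary_frequencies` took `scale=-math.log(10000)`; the replacement code path (not shown in this diff) is driven by the new rotary config, which exposes `theta` instead, with `scale = -log(theta)`. A sketch of the equivalent computation expressed in terms of `theta`, mirroring the removed code:

```python
# Sketch only: the removed frequency computation re-expressed with `theta`
# (the new config field) instead of `scale = -log(theta)`. Where this logic now
# lives in the code base is not shown in this diff.
import torch


def get_rotary_frequencies_sketch(
    sequence_length: int,
    kv_channels: int,
    theta: float = 10000.0,
    *,
    device="cuda",
) -> torch.Tensor:
    # exp(i * n * a) = cos(n * a) + i * sin(n * a), with
    # a = theta ** -(2 * (channel // 2) / kv_channels) and n the position.
    # Computed in float64 because precision matters for rotary embeddings.
    angles = torch.outer(
        torch.arange(sequence_length, device=device, dtype=torch.float64),
        theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64),
    )
    return torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64)
```

With the default `theta=10000` this reproduces the old behaviour exactly, since `exp(-log(theta) * x) = theta ** -x`.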
2 changes: 1 addition & 1 deletion fast_llm/layers/language_model/config.py
@@ -60,7 +60,7 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):

def _validate(self):
if self.use_position_embeddings is None:
self.use_position_embeddings = not self.transformer.use_rotary_embeddings
self.use_position_embeddings = not self.transformer.rotary.enabled
super()._validate()

def setup_tensor_space(self, tensor_space: TensorSpace):
6 changes: 2 additions & 4 deletions fast_llm/layers/transformer/attention.py
@@ -76,8 +76,6 @@ def __init__(
self._debug_transformer = self._config.debug_transformer
self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config)

self._triton_rotary = self._config.triton_rotary

init_method_qkv = init_normal_(
std=self._config.init_method_std_qkv,
min_val=self._config.init_method_min_qkv,
@@ -300,7 +298,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict):
key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels)
value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels)

if self._config.use_rotary_position_embeddings:
if self._config.rotary.enabled:
if self._debug_transformer:
self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs)
self._debug_log(
@@ -309,7 +307,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict):
self._KV_DIMS,
kwargs,
)
rotary_fn = triton_rotary_autograd_ if self._triton_rotary else apply_rotary_embeddings
rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings
query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q])
key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k])

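For context, `apply_rotary_embeddings` appears to operate on the complex-format layout (`rotary.complex_format`, i.e. rotary enabled with `triton` off), while `triton_rotary_autograd_` works on the real layout, which is why `rotary.triton` is documented as affecting the model layout. A sketch of the complex-format application, not necessarily identical to the implementation in `fast_llm.functional.rotary`:

```python
# Sketch only: complex-format rotary application for a tensor of shape
# (batch, sequence, heads, kv_channels) and unit-modulus complex frequencies.
# The actual apply_rotary_embeddings in fast_llm.functional.rotary may differ.
import torch


def apply_rotary_embeddings_sketch(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor:
    # Interpret channel pairs as complex numbers, rotate them by the precomputed
    # frequencies, then restore the original real layout and dtype.
    complex_tensor = torch.view_as_complex(tensor.float().view(*tensor.shape[:-1], -1, 2))
    return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor)
```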
116 changes: 89 additions & 27 deletions fast_llm/layers/transformer/config.py
@@ -1,6 +1,7 @@
import enum
import logging
import math
import typing
import warnings

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
@@ -75,6 +76,72 @@ class TransformerLossNames:
router_z_loss = "router_z_loss"


class RotaryEmbeddingType(str, enum.Enum):
none = "none"
default = "default"
llama3 = "llama3"


@config_class()
class RotaryArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False
type: RotaryEmbeddingType = Field(
default=RotaryEmbeddingType.none,
desc="The type of rotary embedding to use. Choices: none, default, llama3.",
hint=FieldHint.feature,
)
theta: float = Field(
default=10000,
desc="Scale for the rotary positional embeddings",
hint=FieldHint.feature,
)
# TODO: Make a backup implementation that doesn't affect the layout.
triton: bool = Field(
default=True,
desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.",
hint=FieldHint.deprecated,
)
# TODO: These are not really architecture parameters, but we want to import them from huggingface.
scale_factor: float = Field(default=8.0, desc="Scaling factor for llama3-type scaling.", hint=FieldHint.feature)
low_frequency_factor: float = Field(
default=1.0, desc="Low frequency factor for llama3-type scaling.", hint=FieldHint.feature
)
high_frequency_factor: float = Field(
default=4.0, desc="High frequency factor for llama3-type scaling.", hint=FieldHint.feature
)
original_context_length: int = Field(
default=8192, desc="Original context length for llama3-type scaling.", hint=FieldHint.feature
)

@property
def enabled(self):
return self.type != RotaryEmbeddingType.none

@property
def complex_format(self):
# TODO: Make a backup implementation that doesn't affect the layout.
return self.enabled and not self.triton

def _validate(self):
# These happen during conversion.
if self.scale_factor is None:
self.scale_factor = 8.0
if self.low_frequency_factor is None:
self.low_frequency_factor = 1.0
if self.high_frequency_factor is None:
self.high_frequency_factor = 4.0
if self.original_context_length is None:
self.original_context_length = 8192
super()._validate()
if self.triton and not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.")


@config_class()
class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig):
pass


@config_class()
class TransformerArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False
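The llama3-specific fields above (`scale_factor`, `low_frequency_factor`, `high_frequency_factor`, `original_context_length`) parameterize the Llama 3 long-context frequency correction. Below is a sketch of that correction as it is commonly implemented (for example in Meta's reference code and in Hugging Face Transformers); how Fast-LLM applies it to the base frequencies is not shown in this section.

```python
# Sketch of llama3-type rotary frequency scaling, using the config fields above.
# Not taken from this diff; shown only to explain what the parameters control.
import math

import torch


def llama3_scale_frequencies_sketch(
    frequencies: torch.Tensor,  # Base frequencies, e.g. theta ** -(2 * (c // 2) / kv_channels).
    scale_factor: float = 8.0,
    low_frequency_factor: float = 1.0,
    high_frequency_factor: float = 4.0,
    original_context_length: int = 8192,
) -> torch.Tensor:
    low_frequency_wavelength = original_context_length / low_frequency_factor
    high_frequency_wavelength = original_context_length / high_frequency_factor
    wavelength = 2 * math.pi / frequencies
    # Smooth interpolation between the fully rescaled and unscaled frequencies.
    smooth = (original_context_length / wavelength - low_frequency_factor) / (
        high_frequency_factor - low_frequency_factor
    )
    scaled = torch.where(
        wavelength > low_frequency_wavelength,
        frequencies / scale_factor,  # Low-frequency components are fully rescaled.
        (1 - smooth) * frequencies / scale_factor + smooth * frequencies,
    )
    # High-frequency (short-wavelength) components are left untouched.
    return torch.where(wavelength < high_frequency_wavelength, frequencies, scaled)
```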
@@ -119,12 +186,9 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
use_rotary_embeddings: bool = Field(
default=False, desc="Enable rotary positional embeddings.", hint=FieldHint.feature
)
rotary_embedding_scale: float = Field(
default=-math.log(10000),
desc="Scale for the rotary positional embeddings. Default: -math.log(10000) = -9.210",
rotary: RotaryArchitectureConfig = Field(
default_factory=RotaryArchitectureConfig,
desc="Configuration for the rotary positional embeddings.",
hint=FieldHint.feature,
)
gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.feature)
@@ -133,11 +197,6 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
hint=FieldHint.core,
)
triton_rotary: bool = Field(
default=True,
desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.",
hint=FieldHint.deprecated,
)
num_experts: int = Field(
default=1,
desc="Number of MLP experts in a Mixture of Expert (MoE) model",
@@ -185,6 +244,24 @@ def _validate(self):
Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts)
Assert.multiple(self.num_attention_heads, self.head_groups)

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
):
# TODO v0.x: Remove backward compatibility.
cls._handle_renamed_field(
default,
"use_rotary_embeddings",
("rotary", "type"),
lambda x: RotaryEmbeddingType.default if x else RotaryEmbeddingType.none,
)
cls._handle_renamed_field(default, "rotary_embedding_scale", ("rotary", "theta"), lambda x: math.exp(-x))
cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton"))
return super()._from_dict(default, strict, flat)

def setup_tensor_space(self, tensor_space: TensorSpace):
tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor)
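To make the renamed-field handling in `_from_dict` above concrete, an old-style flat transformer config maps onto the new nested `rotary` section roughly as follows (illustrative values); note that `theta = exp(-rotary_embedding_scale)`, so the old default `-math.log(10000)` becomes `theta = 10000`.

```python
# Illustration of the backward-compatibility mapping in _from_dict above.
import math

old_style = {
    "use_rotary_embeddings": True,
    "rotary_embedding_scale": -math.log(10000),
    "triton_rotary": True,
}
# After the deprecated fields are popped and rewritten, the equivalent nested
# section looks like this ("default" stands for RotaryEmbeddingType.default;
# False would map to RotaryEmbeddingType.none):
new_style = {
    "rotary": {
        "type": "default",
        "theta": 10000.0,  # exp(-(-log(10000))) == 10000
        "triton": True,
    },
}
assert math.isclose(math.exp(-old_style["rotary_embedding_scale"]), new_style["rotary"]["theta"])
```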

@@ -245,24 +322,11 @@ def setup_tensor_space(self, tensor_space: TensorSpace):
)
)

@property
def complex_rotary_embeddings(self):
return self.use_rotary_position_embeddings and not self.triton_rotary

@property
def rotary_position_embedding_scale(self):
# TODO: Set through rotary theta instead.
return self.rotary_embedding_scale if self.use_rotary_position_embeddings else None

@property
def use_rotary_position_embeddings(self):
# TODO: Set through rotary theta instead.
return self.use_rotary_embeddings


@config_class()
class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)
rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig)
# Default: hidden_size**-0.5
# TODO: Allow custom initialization (InitializationConfig?)
init_method_std: float = Field(
@@ -492,8 +556,6 @@ def _validate(self):
if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None:
Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2)
super()._validate()
if self.triton_rotary and not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")
Assert.geq(self.attention_dropout, 0)
Assert.geq(self.hidden_dropout, 0)
Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts))