Support block-modular architecture #277
…oleksiy/apriel-ssm
…to modular_hybrids
I had a quick look; I'll go deeper once this is updated with main.
@@ -380,8 +380,8 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:

if expected_class is not None:
    # Should be handled in `from_dict`, but can fail if instantiating directly.
    Assert.is_(self.__class__, expected_class)

# TODO: is this ok? i.e. we want the assigned class to be a subclass of the expected class, not neccessarily exactly the same class.
No, this is handled in from_dict. The expected class is not the same as the type hint.
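To make the distinction in this thread concrete, here is a minimal sketch of how an exact-class check differs from a subclass check. The classes and helper names below are hypothetical stand-ins for illustration; they are not Fast-LLM's actual `Config` machinery, and the real `from_dict` logic may work differently.

```python
class BaseConfig:
    """Hypothetical stand-in for a config base class."""


class ChildConfig(BaseConfig):
    """Hypothetical subclass of the stand-in base class."""


def is_exact_class(instance, expected_class):
    # Identity check: passes only when the instance's class is exactly
    # the expected class (subclasses are rejected).
    return type(instance) is expected_class


def is_subclass_instance(instance, expected_class):
    # Looser check: passes for the expected class and any of its subclasses.
    return isinstance(instance, expected_class)


child = ChildConfig()
exact_against_base = is_exact_class(child, BaseConfig)
loose_against_base = is_subclass_instance(child, BaseConfig)
```

In this sketch the exact check rejects a `ChildConfig` when `BaseConfig` is expected, while the looser check accepts it.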
@config_class()
class BaseBlockConfig(BaseModelConfig):
This doesn't really belong in common. Maybe a base_block submodule?
@@ -41,10 +42,10 @@ class LanguageModelKwargs:

@config_class()
class LanguageModelBaseConfig(BaseModelConfig):
    transformer: TransformerConfig = Field(
Where has this gone?
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
head_normalization: NormalizationConfig | None = Field(
Implicit convention: Put sub-configs on top. I don't think we want None in the type hint since it's not a valid value after validation.
if self.embeddings_hidden_dropout is None:
    self.embeddings_hidden_dropout = 0.0
if self.head_normalization is None:
    self.head_normalization = NormalizationConfig()
I'd rather keep the transformer normalization as the default.
@@ -0,0 +1,55 @@
import typing
Rename file to block for consistency
✨ Description
This draft PR addresses #242 by introducing a flexible, modular configuration system for hybrid model architectures.
TODOs:
Which will result in a block layout like this: ["bob", "mamba_1", "mamba_2", "bob"], where the two "bob" blocks share weights and the "mamba" blocks do not.
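As a rough illustration of the weight-sharing behaviour described for this layout, the sketch below builds one object per unique block name so that repeated names reuse the same instance. The `Block` class, the `build_layout` helper, and the name-to-kind mapping are all hypothetical; they are not the PR's actual config format or implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Block:
    """Hypothetical stand-in for a model block; `weights` mimics its parameters."""

    kind: str
    weights: list = field(default_factory=lambda: [0.0])


def build_layout(layout, kinds):
    """Create one Block per unique name, so repeated names share one object."""
    instances = {}
    blocks = []
    for name in layout:
        if name not in instances:
            instances[name] = Block(kind=kinds[name])
        blocks.append(instances[name])
    return blocks


layout = ["bob", "mamba_1", "mamba_2", "bob"]
# Assumed mapping from block name to block kind, for illustration only.
kinds = {"bob": "transformer", "mamba_1": "mamba", "mamba_2": "mamba"}

blocks = build_layout(layout, kinds)
# Updating the first "bob" also updates the last one, since they are one object.
blocks[0].weights[0] = 1.0
```

Here the first and last entries refer to the same `Block`, while `mamba_1` and `mamba_2` are separate objects with independent weights.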
🔍 Type of change
Select all that apply:
📝 Changes
Class hierarchy in the config system:
- `BaseBlock` into `BaseBlockConfig` in `layers/common`
- `BaseBlockConfig`, both holding functionality specific to their dedicated blocks (`TransformerLayer`, `LlambaBlock`)

Block-specific hyperparameters & tensor space definition:
- `HybridBlockConfig`s implemented under `models/hybrid/config`, allowing block-specific hyperparameter definitions

Layer freezing:
- `lr_scale` and component-specific scales like `norm_lr_scale`, `mlp_lr_scale`, etc. If both are passed, the resulting scale for a component is the block's `lr_scale` multiplied by the component-specific scale (see the `get_lr_scale` function).
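The scale combination described under layer freezing can be sketched as follows. This is a hypothetical helper, not the PR's `get_lr_scale` function, whose actual signature is not shown here. The handling of a missing (`None`) scale is an assumption; only the "both are passed" case (a product) comes from the description above.

```python
def combine_lr_scale(block_lr_scale=None, component_lr_scale=None):
    """Combine a block-level scale with a component-specific scale.

    Hypothetical sketch: when both are given, the result is their product.
    Falling back to whichever scale is set (or None) is an assumption.
    """
    if block_lr_scale is None:
        return component_lr_scale
    if component_lr_scale is None:
        return block_lr_scale
    return block_lr_scale * component_lr_scale


# Both scales set: the effective scale is the product.
both = combine_lr_scale(block_lr_scale=0.5, component_lr_scale=0.5)
# A block-level scale of 0 zeroes the effective scale of every component in it.
frozen = combine_lr_scale(block_lr_scale=0.0, component_lr_scale=2.0)
```

Under this sketch, a block-level scale of 0 yields an effective scale of 0 for each component, which is presumably how a whole block would be frozen.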
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact