
Support block-modular architecture #277


Draft · wants to merge 158 commits into main
Conversation

@oleksost (Contributor) commented on May 29, 2025

✨ Description

This draft PR addresses #242 by introducing a flexible, modular configuration system for hybrid model architectures.

TODOs:

  • add more testing to make sure legacy behaviour is well supported
  • implement weight sharing
  • support block-specific learning rate scales
  • make sure model serialisation/conversion works as expected
  • review and unify naming conventions (block, layer) across the codebase
  • clean up and test
An example hybrid model configuration:

```yaml
model:
  base_model:
    cross_entropy_impl: fused
    blocks:
      bob:
        type: transformer
        hidden_size: 512
        share_weights: true

      mamba:
        type: discrete_mamba2
        state_size: 16
        expansion_factor: 2
        hidden_size: 512

    hybrid_block_layout: ["bob", "mamba", "mamba", "bob"]
    num_layers: 4
```

This results in the block layout ["bob", "mamba_1", "mamba_2", "bob"], where the bob blocks share weights and the mamba blocks do not.
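
For illustration, a minimal sketch of this expansion logic; the helper name expand_layout and the dict-based block specs below are hypothetical, not this PR's actual API:

```python
# Hypothetical sketch of layout expansion with weight sharing; the helper
# name and dict-based block specs are illustrative, not this PR's API.
def expand_layout(blocks: dict[str, dict], layout: list[str]) -> list[str]:
    counters: dict[str, int] = {}
    expanded = []
    for name in layout:
        if blocks[name].get("share_weights", False):
            # Shared blocks keep a single name, hence a single set of weights.
            expanded.append(name)
        else:
            # Independent blocks get a numbered suffix and their own weights.
            counters[name] = counters.get(name, 0) + 1
            expanded.append(f"{name}_{counters[name]}")
    return expanded


blocks = {
    "bob": {"type": "transformer", "share_weights": True},
    "mamba": {"type": "discrete_mamba2"},
}
print(expand_layout(blocks, ["bob", "mamba", "mamba", "bob"]))
# ['bob', 'mamba_1', 'mamba_2', 'bob']
```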

🔍 Type of change

Class hierarchy in the config system:

  • started moving functionality specific to BaseBlock into BaseBlockConfig in layers/common
  • transformer and SSM layer configs inherit from BaseBlockConfig, each holding the functionality specific to its dedicated block (TransformerLayer, LlambaBlock); a simplified sketch follows below
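
A simplified sketch of that hierarchy, using plain dataclasses and invented field names for brevity; the real configs are built with Fast-LLM's @config_class()/Field machinery:

```python
# Simplified, hypothetical view of the config hierarchy; field names are
# invented and the real configs use @config_class()/Field, not dataclasses.
import dataclasses


@dataclasses.dataclass
class BaseBlockConfig:
    # Hyperparameters shared by every block type (layers/common).
    hidden_size: int = 512
    lr_scale: float | None = None
    share_weights: bool = False


@dataclasses.dataclass
class TransformerBlockConfig(BaseBlockConfig):
    # Settings specific to TransformerLayer blocks.
    num_attention_heads: int = 8


@dataclasses.dataclass
class SSMBlockConfig(BaseBlockConfig):
    # Settings specific to SSM (LlambaBlock) blocks.
    state_size: int = 16
    expansion_factor: int = 2
```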

Block-specific hyperparameters & tensor space definition:

  • HybridBlockConfigs implemented under models/hybrid/config, allowing block-specific hyperparameter definitions
  • the names of the elements in the tensor space now include block suffixes; no suffixes are used for non-hybrid GPT models (see the sketch after this list)
  • legacy behaviour is still supported, both for blocks defined with lists like [t,m2d,m] and for non-hybrid GPT models
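
To make the suffix convention concrete, a tiny hypothetical example of how dimension names could be derived; the actual TensorSpace/TensorDim API may differ:

```python
# Hypothetical illustration of block-suffixed tensor-space names; not the
# actual TensorSpace/TensorDim API.
def tensor_dim_name(base: str, block_name: str | None = None) -> str:
    # Hybrid models suffix dimension names per block; plain GPT models do not.
    return base if block_name is None else f"{base}_{block_name}"


print(tensor_dim_name("hidden"))           # 'hidden'        (non-hybrid GPT)
print(tensor_dim_name("hidden", "mamba"))  # 'hidden_mamba'  (hybrid block)
```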

Layer freezing:

  • with PEFT, layer freezing must be explicit: if LoRA is used, it does not automatically freeze the other layers, and lr_scale values must be used instead
  • there is a per-block lr_scale as well as component-specific scales such as norm_lr_scale, mlp_lr_scale, etc. If both are passed, the resulting scale for a component is the block's lr_scale multiplied by the component-specific scale (see the get_lr_scale function and the sketch after this list)
  • for the non-hybrid GPT model, lr_scale should not be used, as it would apply to all layers, since all layers share the same config
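
A hedged sketch of the combination rule described above; the real get_lr_scale in this PR may have a different signature and handle more cases:

```python
# Sketch of the lr-scale combination rule; the real get_lr_scale may differ.
def get_lr_scale(block_lr_scale: float | None, component_lr_scale: float | None) -> float:
    # Missing scales default to 1.0; when both are given they multiply.
    scale = 1.0
    if block_lr_scale is not None:
        scale *= block_lr_scale
    if component_lr_scale is not None:
        scale *= component_lr_scale
    return scale


print(get_lr_scale(0.5, None))  # 0.5  (block-level scale only)
print(get_lr_scale(0.5, 0.1))   # 0.05 (block scale x mlp_lr_scale, for example)
print(get_lr_scale(0.0, None))  # 0.0  (effectively freezes the block)
```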

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

@oleksost requested a review from nandahkrishna on May 29, 2025, 12:30
@oleksost requested a review from jlamypoirier on June 10, 2025, 00:58
@jlamypoirier (Collaborator) left a comment


I had a quick look, will go deeper once updated with main.

```
@@ -380,8 +380,8 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:

        if expected_class is not None:
            # Should be handled in `from_dict`, but can fail if instantiating directly.
            Assert.is_(self.__class__, expected_class)

            # TODO: is this ok? i.e. we want the assigned class to be a subclass of the expected class, not necessarily exactly the same class.
```
Collaborator: No, this is handled in `from_dict`. The expected class is not the same as the type hint.



```python
@config_class()
class BaseBlockConfig(BaseModelConfig):
```
Collaborator: This doesn't really belong in common. Maybe a base_block submodule?

```
@@ -41,10 +42,10 @@ class LanguageModelKwargs:

@config_class()
class LanguageModelBaseConfig(BaseModelConfig):
    transformer: TransformerConfig = Field(
```
Collaborator: Where did this go?

```python
        hint=FieldHint.feature,
        valid=check_field(Assert.geq, 0),
    )
    head_normalization: NormalizationConfig | None = Field(
```
Collaborator: Implicit convention: put sub-configs on top. I don't think we want None in the type hint, since it's not a valid value after validation.

```python
        if self.embeddings_hidden_dropout is None:
            self.embeddings_hidden_dropout = 0.0
        if self.head_normalization is None:
            self.head_normalization = NormalizationConfig()
```
Collaborator: I'd rather keep the transformer normalization as the default.

```
@@ -0,0 +1,55 @@
import typing
```
Collaborator: Rename the file to block for consistency.
