Overview
During my work towards #248 I have converged on a solution that requires a refactor of the model initialisation.
UPDATE: This issue is now part of the 5-PR checkpoint architecture refactor. It implements the Model Transformation Layer that applies post-loading modifications to models.
Parent Issue
Part of #248 - Checkpoint System Refactor
Related Architecture Components
- Foundation: Checkpoint Pipeline Infrastructure (Phase 1) #493 (Pipeline Infrastructure)
- Works with: Checkpoint Acquisition Layer - Multi-source checkpoint loading (S3, HTTP, local, MLFlow) #458 (Checkpoint Acquisition), Checkpoint Loading Orchestration (Phase 2) #494 (Loading Orchestration)
- Integration: Checkpoint System Integration and Migration (Phase 3) #495 (System Integration)
Original Problem
Specifically, around the if self.load_weights_only: branch in the model initialisation we have a lot of branching that is getting increasingly difficult to extend and maintain.
New Architecture Role
This issue implements the Model Transformation Layer - the third layer in our pipeline architecture:
┌─────────────────────────────────────────────────┐
│ Model Transformation Layer (THIS ISSUE) │
│ (Post-loading modifications) │
├─────────────────────────────────────────────────┤
│ Loading Orchestration Layer │
│ (Strategies for applying checkpoints) │
├─────────────────────────────────────────────────┤
│ Checkpoint Acquisition Layer │
│ (Obtaining checkpoint from sources) │
└─────────────────────────────────────────────────┘
Components to Implement
1. ModelModifier Base Class (training/src/anemoi/training/train/modify.py)
from abc import ABC, abstractmethod

import torch.nn as nn


class ModelModifier(ABC):
    """Base class for all model modifiers."""

    @abstractmethod
    def apply(self, model: nn.Module) -> nn.Module:
        """Apply the modification to the model and return the modified model."""
        ...
2. Implemented Modifiers
- FreezingModelModifier: selective parameter freezing by module name
- TransferLearningModelModifier: transfer learning with fallback handling for missing checkpoints
- Future: LoRAModifier, QuantizationModifier, etc.
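As a concrete illustration, here is a minimal sketch of what FreezingModelModifier could look like, assuming freezing means setting requires_grad = False on every parameter under the listed submodules. It subclasses the ModelModifier base class above, and the constructor argument mirrors the submodules_to_freeze key in the configuration example below; the actual implementation in the PR may differ.

import torch.nn as nn


class FreezingModelModifier(ModelModifier):
    """Freeze all parameters belonging to the named submodules."""

    def __init__(self, submodules_to_freeze: list[str]) -> None:
        self.submodules_to_freeze = list(submodules_to_freeze)

    def apply(self, model: nn.Module) -> nn.Module:
        for name, parameter in model.named_parameters():
            # Freeze any parameter whose qualified name starts with a listed prefix.
            if any(name.startswith(prefix) for prefix in self.submodules_to_freeze):
                parameter.requires_grad = False
        return model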
3. ModelModifierApplier
Orchestrates application of multiple modifiers in sequence:
class ModelModifierApplier:
    def process(self, base_model: nn.Module, config: DictConfig) -> nn.Module:
        ...
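A minimal sketch of how the applier might chain modifiers, assuming each entry under modifiers is instantiated via Hydra from its _target_ (as in the configuration example below). This illustrates the intended behaviour rather than the code in PR #442.

from hydra.utils import instantiate
from omegaconf import DictConfig
import torch.nn as nn


class ModelModifierApplier:
    """Apply the configured ModelModifiers to a model, in order."""

    def process(self, base_model: nn.Module, config: DictConfig) -> nn.Module:
        # Build each modifier from its _target_ entry and apply them in sequence,
        # preserving the order given in the configuration.
        for modifier_config in config.get("modifiers", []):
            modifier = instantiate(modifier_config)
            base_model = modifier.apply(base_model)
        return base_model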
Configuration
training:
  model_modifier:
    modifiers:
      - _target_: "anemoi.training.train.modify.FreezingModelModifier"
        submodules_to_freeze:
          - "encoder"
          - "processor.0"
      - _target_: "anemoi.training.train.modify.TransferLearningModelModifier"
        checkpoint_path: "/path/to/pretrained.ckpt"
        strict: false
        skip_mismatched: true
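For reference, a hedged sketch of how this configuration might be consumed during model initialisation; the attribute path training.model_modifier follows the YAML above, while the surrounding trainer code is illustrative only.

# Illustrative only: build the base model first, then hand it to the applier
# together with the model_modifier section of the training configuration.
applier = ModelModifierApplier()
model = applier.process(base_model, config.training.model_modifier)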
Benefits
- Modular: Each modification type is a separate, composable modifier
- Configurable: Full Hydra configuration support with validation
- Extensible: Easy addition of new modifier types
- Standalone: Works independently without external dependencies
- Order-aware: Modifiers applied in specified sequence
Implementation Status (PR #442)
- Base ModelModifier class
- FreezingModelModifier
- TransferLearningModelModifier
- ModelModifierApplier
- Configuration integration
- Tests
- Documentation
Future Extensions
This modular system enables:
- PEFT adapters (LoRA, QLoRA)
- Quantization
- On-the-fly renaming (see #249: transfer learning broken for models trained before #182)
- Conditional loading into modified architectures
- Pre-training on one decoder, fine-tuning on multi-decoder
Testing
- Unit tests for each modifier
- Integration tests with training pipeline
- Configuration validation tests
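As an example of the unit-test granularity intended here, a minimal pytest-style sketch for the freezing modifier, assuming the FreezingModelModifier sketched above; the test name and toy model are illustrative.

import torch.nn as nn


def test_freezing_modifier_freezes_only_listed_submodules():
    model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
    modifier = FreezingModelModifier(submodules_to_freeze=["encoder"])

    model = modifier.apply(model)

    # Only the encoder parameters should be frozen; the decoder stays trainable.
    assert all(not p.requires_grad for p in model["encoder"].parameters())
    assert all(p.requires_grad for p in model["decoder"].parameters())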
Success Criteria
- All existing functionality preserved
- Clean separation of concerns
- Extensible for future modifiers
- Full test coverage
- Documentation complete