
Model Transformation Layer - Post-loading modifications (freezing, transfer learning, adapters) #410

@JesperDramsch

Description


Overview

During my work towards #248 I have converged on a solution that requires a refactor of the model initialisation.

UPDATE: This issue is now part of the 5-PR checkpoint architecture refactor. It implements the Model Transformation Layer that applies post-loading modifications to models.

Parent Issue

Part of #248 - Checkpoint System Refactor

Related Architecture Components

Original Problem

Specifically, in

if self.load_weights_only:

we have accumulated a lot of branching that is becoming increasingly difficult to extend and maintain.

New Architecture Role

This issue implements the Model Transformation Layer - the third layer in our pipeline architecture:

┌─────────────────────────────────────────────────┐
│    Model Transformation Layer (THIS ISSUE)      │
│         (Post-loading modifications)            │
├─────────────────────────────────────────────────┤
│         Loading Orchestration Layer             │
│    (Strategies for applying checkpoints)        │
├─────────────────────────────────────────────────┤
│        Checkpoint Acquisition Layer             │
│      (Obtaining checkpoint from sources)        │
└─────────────────────────────────────────────────┘

Components to Implement

1. ModelModifier Base Class (training/src/anemoi/training/train/modify.py)

from abc import ABC, abstractmethod

from torch import nn


class ModelModifier(ABC):
    """Base class for all model modifiers."""

    @abstractmethod
    def apply(self, model: nn.Module) -> nn.Module:
        """Apply the modification to the model and return it."""
        ...

2. Implemented Modifiers

  • FreezingModelModifier: selective parameter freezing by module name (a sketch follows this list)
  • TransferLearningModelModifier: transfer learning from a pretrained checkpoint, with a fallback for a missing checkpoint system
  • Future: LoRAModifier, QuantizationModifier, etc.
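
As an illustration, a minimal sketch of what FreezingModelModifier could look like, assuming it disables gradients for parameters whose qualified name falls under one of the configured submodules; the matching logic here is illustrative, not the actual implementation:

from torch import nn


class FreezingModelModifier(ModelModifier):
    """Freeze the parameters of the named submodules."""

    def __init__(self, submodules_to_freeze: list[str]) -> None:
        self.submodules_to_freeze = list(submodules_to_freeze)

    def apply(self, model: nn.Module) -> nn.Module:
        for name, parameter in model.named_parameters():
            # "encoder" matches encoder.*, "processor.0" matches processor.0.*
            if any(name == prefix or name.startswith(prefix + ".") for prefix in self.submodules_to_freeze):
                parameter.requires_grad = False
        return model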

3. ModelModifierApplier

Orchestrates application of multiple modifiers in sequence:

class ModelModifierApplier:
    def process(self, base_model: nn.Module, config: DictConfig) -> nn.Module:
        """Instantiate the configured modifiers and apply them in order."""
        ...
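
A minimal sketch of how process could be implemented, assuming the modifiers are instantiated with Hydra from the configuration shown in the next section (the exact config path is an assumption based on that example):

from hydra.utils import instantiate
from omegaconf import DictConfig
from torch import nn


class ModelModifierApplier:
    """Apply the configured modifiers to a model, in the order they are listed."""

    def process(self, base_model: nn.Module, config: DictConfig) -> nn.Module:
        model = base_model
        # Each entry carries a `_target_`; Hydra builds the modifier instance from it.
        for modifier_config in config.training.model_modifier.modifiers:
            modifier = instantiate(modifier_config)
            model = modifier.apply(model)
        return model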

Configuration

training:
  model_modifier:
    modifiers:
      - _target_: "anemoi.training.train.modify.FreezingModelModifier"
        submodules_to_freeze:
          - "encoder"
          - "processor.0"
      - _target_: "anemoi.training.train.modify.TransferLearningModelModifier"
        checkpoint_path: "/path/to/pretrained.ckpt"
        strict: false
        skip_mismatched: true
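
For illustration, instantiating the two entries above with Hydra is roughly equivalent to constructing the modifiers directly; this is a sketch, and argument handling in the real code may differ:

from anemoi.training.train.modify import (
    FreezingModelModifier,
    TransferLearningModelModifier,
)

# What hydra.utils.instantiate produces for the two `_target_` entries above.
modifiers = [
    FreezingModelModifier(submodules_to_freeze=["encoder", "processor.0"]),
    TransferLearningModelModifier(
        checkpoint_path="/path/to/pretrained.ckpt",
        strict=False,
        skip_mismatched=True,
    ),
]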

Benefits

  • Modular: Each modification type is a separate, composable modifier
  • Configurable: Full Hydra configuration support with validation
  • Extensible: Easy addition of new modifier types
  • Standalone: Works independently without external dependencies
  • Order-aware: Modifiers applied in specified sequence

Implementation Status (PR #442)

  • Base ModelModifier class
  • FreezingModelModifier
  • TransferLearningModelModifier
  • ModelModifierApplier
  • Configuration integration
  • Tests
  • Documentation

Future Extensions

This modular system enables future modifiers such as LoRAModifier and QuantizationModifier (see above) to be added without changing the core training code.

Testing

  • Unit tests for each modifier (an example sketch follows this list)
  • Integration tests with training pipeline
  • Configuration validation tests
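
For instance, a unit test for FreezingModelModifier could look roughly like this, assuming the class lives at the anemoi.training.train.modify path used in the configuration above; the toy model is illustrative:

from torch import nn

from anemoi.training.train.modify import FreezingModelModifier


def test_freezing_modifier_freezes_only_named_submodules():
    # Toy model with two named submodules (illustrative).
    model = nn.Module()
    model.encoder = nn.Linear(4, 4)
    model.decoder = nn.Linear(4, 4)

    frozen = FreezingModelModifier(submodules_to_freeze=["encoder"]).apply(model)

    assert all(not p.requires_grad for p in frozen.encoder.parameters())
    assert all(p.requires_grad for p in frozen.decoder.parameters())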

Success Criteria

  • All existing functionality preserved
  • Clean separation of concerns
  • Extensible for future modifiers
  • Full test coverage
  • Documentation complete
