Checkpoint Pipeline Infrastructure (Phase 1) #493

@JesperDramsch

Overview

Implement the core pipeline infrastructure for the new checkpoint system architecture. This is Phase 1 of the 5-PR implementation plan.

Parent Issue

Part of #248 - Checkpoint System Refactor

Architecture

Three-Layer Pipeline Design

The checkpoint system is being refactored into three distinct, composable layers that work together through a pipeline pattern:

  1. Checkpoint Acquisition Layer (PR #458: Checkpoint Acquisition Layer - Multi-source checkpoint loading (S3, HTTP, local, MLFlow))

    • Responsible for fetching checkpoint files from various sources (local, S3, HTTP, etc.)
    • Returns a local path to the checkpoint file
    • Handles caching, retries, and network failures
  2. Loading Orchestration Layer (New PR - Phase 2)

    • Implements different strategies for loading checkpoints into models
    • Strategies include: weights-only, transfer learning, warm start, cold start
    • Manages optimizer and scheduler state restoration
  3. Model Transformation Layer (PR #410: Model Transformation Layer - Post-loading modifications (freezing, transfer learning, adapters))

    • Applies post-loading modifications to the model
    • Includes freezing layers, adding LoRA adapters, quantization
    • Executed after checkpoint loading is complete

This Issue's Role

This issue tracks the foundation layer that enables all three layers above to work together:

┌─────────────────────────────────────────────────┐
│   User Configuration (Hydra YAML)               │
└─────────────────────────────────────────────────┘
                    ↓
┌─────────────────────────────────────────────────┐
│   Pipeline Infrastructure (THIS ISSUE)          │
│   - CheckpointContext (carries state)           │
│   - PipelineStage (base abstraction)            │
│   - CheckpointPipeline (orchestrator)           │
│   - ComponentRegistry (factory pattern)         │
└─────────────────────────────────────────────────┘
                    ↓
    Orchestrates the following stages in order:
                    ↓
    1. Acquisition → 2. Loading → 3. Transformation

Data Flow Example

# 1. User provides configuration
config = {
    'checkpoint': {
        'source': {'type': 's3', 'bucket': 'models', 'key': 'model.ckpt'},
        'loading': {'type': 'transfer_learning', 'skip_mismatched': True}
    },
    'model_modifier': {
        'modifiers': [{'type': 'freeze', 'layers': ['encoder']}]
    }
}

# 2. Pipeline creates context
context = CheckpointContext(model=model, config=config)

# 3. Pipeline executes stages (inside a coroutine, since process() is async)
context = await acquisition_stage.process(context)  # Downloads from S3
context = await loading_stage.process(context)      # Loads with transfer learning
context = await modifier_stage.process(context)     # Freezes encoder

# 4. Result: modified model with checkpoint loaded
model = context.model

Key Components to Implement

1. Core Abstractions (training/src/anemoi/training/checkpoint/base.py)

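from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

import torch.nn as nn
from omegaconf import DictConfig
from torch.optim import Optimizer


@dataclass
class CheckpointContext:
    """Carries state through pipeline stages"""
    checkpoint_path: Optional[Path] = None   # local path to the checkpoint file
    checkpoint_data: Optional[Dict] = None   # contents of the loaded checkpoint
    model: Optional[nn.Module] = None
    optimizer: Optional[Optimizer] = None
    scheduler: Optional[Any] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    config: Optional[DictConfig] = None


class PipelineStage(ABC):
    """Base class for all pipeline stages"""
    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        """Consume a context and return the context for the next stage."""
        ...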

2. Pipeline Orchestrator (training/src/anemoi/training/checkpoint/pipeline.py)

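from typing import List

# Assumes base.py (above) lives in the same package as pipeline.py
from .base import CheckpointContext, PipelineStage


class CheckpointPipeline:
    """Orchestrates checkpoint processing through stages"""

    def __init__(self, stages: List[PipelineStage]):
        self.stages = stages

    async def execute(self, initial_context: CheckpointContext) -> CheckpointContext:
        # Run the stages sequentially; each one receives the previous stage's output
        context = initial_context
        for stage in self.stages:
            context = await stage.process(context)
        return context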
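
To show the orchestrator running end to end, here is a minimal self-contained sketch. The context is trimmed down to avoid a torch dependency, and the stub stages `FakeAcquisition` and `FakeLoading` are hypothetical stand-ins for the real acquisition and loading layers, which are tracked in other PRs.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional


@dataclass
class CheckpointContext:
    """Trimmed-down context for illustration (no model or optimizer)."""
    checkpoint_path: Optional[Path] = None
    checkpoint_data: Optional[Dict[str, Any]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class PipelineStage(ABC):
    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        ...


class FakeAcquisition(PipelineStage):
    """Hypothetical stage: pretends to fetch a checkpoint to a local path."""
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        context.checkpoint_path = Path("/tmp/model.ckpt")
        context.metadata["acquired"] = True
        return context


class FakeLoading(PipelineStage):
    """Hypothetical stage: pretends to read the file acquired earlier."""
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        if context.checkpoint_path is None:
            raise RuntimeError("loading requires an acquired checkpoint")
        context.checkpoint_data = {"state_dict": {}}
        return context


class CheckpointPipeline:
    def __init__(self, stages: List[PipelineStage]):
        self.stages = stages

    async def execute(self, initial_context: CheckpointContext) -> CheckpointContext:
        context = initial_context
        for stage in self.stages:
            context = await stage.process(context)
        return context


def run_demo() -> CheckpointContext:
    """Build a two-stage pipeline and drive it from synchronous code."""
    pipeline = CheckpointPipeline([FakeAcquisition(), FakeLoading()])
    return asyncio.run(pipeline.execute(CheckpointContext()))
```

Synchronous callers, such as a training entry point, can wrap `execute` in `asyncio.run` as `run_demo` does. Swapping the order of the two stubs makes `FakeLoading` raise, which illustrates why stage order matters.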

3. Registry Pattern (training/src/anemoi/training/checkpoint/registry.py)

  • Component registration for sources, loaders, and modifiers
  • Factory methods for creating components from config
  • Extensibility for adding new component types
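
One possible shape for the registry is sketched below. It assumes that components are keyed by a category (`source`, `loader`, `modifier`) plus the `type` field from the config, as in the Data Flow Example. The method names `register` and `create` and the `S3Source` class are hypothetical, not the final API.

```python
from typing import Any, Callable, Dict, Mapping, Type


class ComponentRegistry:
    """Maps (category, type name) to a component class."""

    def __init__(self) -> None:
        self._components: Dict[str, Dict[str, Type[Any]]] = {}

    def register(self, category: str, name: str) -> Callable[[Type[Any]], Type[Any]]:
        """Class decorator that registers a component under category/name."""
        def decorator(cls: Type[Any]) -> Type[Any]:
            bucket = self._components.setdefault(category, {})
            if name in bucket:
                raise ValueError(f"{category}/{name} is already registered")
            bucket[name] = cls
            return cls
        return decorator

    def create(self, category: str, config: Mapping[str, Any]) -> Any:
        """Instantiate a component from a config mapping with a 'type' key."""
        params = dict(config)  # copy so the caller's config is not mutated
        name = params.pop("type")
        try:
            cls = self._components[category][name]
        except KeyError:
            raise KeyError(f"unknown {category} type: {name!r}") from None
        return cls(**params)


registry = ComponentRegistry()


@registry.register("source", "s3")
class S3Source:
    """Hypothetical source component; it only stores its parameters."""
    def __init__(self, bucket: str, key: str) -> None:
        self.bucket = bucket
        self.key = key
```

With this sketch, `registry.create("source", {"type": "s3", "bucket": "models", "key": "model.ckpt"})` returns an `S3Source`, and a new component type needs only one decorator line.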

4. Error Handling (training/src/anemoi/training/checkpoint/exceptions.py)

  • CheckpointError - Base exception
  • CheckpointNotFoundError - File not found
  • CheckpointLoadError - Loading failures
  • CheckpointIncompatibleError - Compatibility issues
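
The hierarchy above could be sketched as follows. The four class names come from the list, while the `path` attribute and the `require_checkpoint` helper are assumptions added for illustration.

```python
from pathlib import Path


class CheckpointError(Exception):
    """Base exception for all checkpoint failures."""


class CheckpointNotFoundError(CheckpointError):
    """Raised when the checkpoint file cannot be found."""
    def __init__(self, path: Path) -> None:
        super().__init__(f"checkpoint not found: {path}")
        self.path = path


class CheckpointLoadError(CheckpointError):
    """Raised when a checkpoint exists but cannot be loaded."""


class CheckpointIncompatibleError(CheckpointError):
    """Raised when a checkpoint does not match the target model."""


def require_checkpoint(path: Path) -> Path:
    """Hypothetical helper: fail fast if the file is missing."""
    if not path.is_file():
        raise CheckpointNotFoundError(path)
    return path
```

Because every class derives from `CheckpointError`, callers can catch the base class to handle any checkpoint failure, or a subclass to react to one specific cause.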

Implementation Checklist

  • Core abstractions defined
  • Pipeline orchestrator implemented
  • Registry pattern established
  • Error handling comprehensive
  • Async utilities implemented
  • Unit tests complete (70% coverage target)
  • Integration tests complete
  • Documentation complete

Design Decisions

1. Async-First Architecture

  • All I/O operations are async for better performance
  • Support for parallel downloads from multiple sources
  • Non-blocking pipeline execution
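
As an illustration of the parallel download point, the sketch below runs several fetches concurrently with `asyncio.gather`. The `fetch` coroutine is a hypothetical placeholder that sleeps instead of performing network I/O.

```python
import asyncio
from typing import List


async def fetch(name: str, delay: float) -> str:
    """Hypothetical downloader: sleeps instead of doing network I/O."""
    await asyncio.sleep(delay)
    return f"/tmp/{name}"


async def fetch_all(names: List[str]) -> List[str]:
    """Start every download at once and wait for all of them."""
    # gather returns results in the order of its arguments,
    # regardless of which coroutine finishes first
    return await asyncio.gather(*(fetch(n, 0.05) for n in names))


def run_parallel(names: List[str]) -> List[str]:
    return asyncio.run(fetch_all(names))
```

In this sketch the total wait is roughly one delay rather than one per file, because the sleeps overlap.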

2. Context Pattern

  • CheckpointContext carries all state through the pipeline
  • Immutable context updates to simplify debugging
  • Clear data flow through stages
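
One way to get immutable updates is a frozen dataclass combined with `dataclasses.replace`, sketched below. Note that `CheckpointContext` as listed under Core Abstractions is a regular mutable dataclass, so `FrozenContext` and `with_path` are hypothetical and only show what the immutable variant might look like.

```python
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class FrozenContext:
    """Hypothetical immutable variant of CheckpointContext."""
    checkpoint_path: Optional[Path] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


def with_path(context: FrozenContext, path: Path) -> FrozenContext:
    """Return a new context; the input is left untouched."""
    return replace(
        context,
        checkpoint_path=path,
        # build a fresh dict so the old context's metadata is not shared
        metadata={**context.metadata, "acquired": True},
    )
```

Each stage then returns a fresh object, so the context seen before and after a stage can be compared directly when debugging.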

3. Registry Pattern

  • Extensible component registration
  • Factory methods for instantiation
  • Plugin-style architecture for new components

4. Error Handling Strategy

  • Fail fast on unrecoverable errors, with recovery mechanisms for transient ones
  • Comprehensive error types for different failures
  • Retry logic with exponential backoff
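
Retry with exponential backoff can be sketched as a small async helper. The name `retry_async`, its parameters, and the choice of `ConnectionError` and `TimeoutError` as the retryable errors are assumptions for illustration.

```python
import asyncio
from typing import Awaitable, Callable, Tuple, Type, TypeVar

T = TypeVar("T")


async def retry_async(
    operation: Callable[[], Awaitable[T]],
    *,
    attempts: int = 3,
    base_delay: float = 0.5,
    retry_on: Tuple[Type[BaseException], ...] = (ConnectionError, TimeoutError),
) -> T:
    """Retry `operation`, doubling the wait after each failed attempt."""
    for attempt in range(attempts):
        try:
            return await operation()
        except retry_on:
            if attempt == attempts - 1:
                raise  # out of attempts: let the caller see the error
            await asyncio.sleep(base_delay * 2 ** attempt)
    raise ValueError("attempts must be at least 1")
```

Errors outside `retry_on` propagate immediately, which keeps the fail-fast behaviour for problems that a retry cannot fix.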

Testing Strategy

Unit Tests (tests/checkpoint/test_pipeline.py)

  • Test pipeline execution order
  • Test context passing between stages
  • Test error propagation
  • Mock all external dependencies
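
A sketch of what the execution-order and error-propagation tests might look like is given below. `RecordingStage` is a hypothetical mock, and the `execute` function inlines the same loop as `CheckpointPipeline.execute` so the example stays standalone.

```python
import asyncio
from typing import List


class RecordingStage:
    """Hypothetical mock stage that records when it runs."""
    def __init__(self, name: str, log: List[str], fail: bool = False) -> None:
        self.name, self.log, self.fail = name, log, fail

    async def process(self, context: dict) -> dict:
        self.log.append(self.name)
        if self.fail:
            raise RuntimeError(f"{self.name} failed")
        return context


async def execute(stages, context):
    """Same loop as CheckpointPipeline.execute, inlined for the example."""
    for stage in stages:
        context = await stage.process(context)
    return context


def test_stages_run_in_order() -> None:
    log: List[str] = []
    stages = [RecordingStage(n, log) for n in ("acquire", "load", "modify")]
    asyncio.run(execute(stages, {}))
    assert log == ["acquire", "load", "modify"]


def test_error_stops_pipeline() -> None:
    log: List[str] = []
    stages = [RecordingStage("acquire", log, fail=True), RecordingStage("load", log)]
    try:
        asyncio.run(execute(stages, {}))
    except RuntimeError:
        pass
    else:
        raise AssertionError("expected RuntimeError")
    # the failing stage ran, the one after it did not
    assert log == ["acquire"]
```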

Integration Tests (tests/checkpoint/test_integration.py)

  • Test full pipeline with mock stages
  • Test async execution
  • Test error recovery
  • Test registry functionality

Migration Path

This infrastructure will support legacy configurations through adapters:

  • load_weights_only=True → WeightsOnlyLoader
  • transfer_learning=True → TransferLearningLoader
  • resume_from_checkpoint → WarmStartLoader
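
A legacy adapter might translate the old flags into a `loading` config like the one in the Data Flow Example. In the sketch below, the function name, the type strings other than `transfer_learning`, the precedence between flags, and the cold-start fallback are all assumptions.

```python
from typing import Any, Dict, Mapping


def adapt_legacy_config(legacy: Mapping[str, Any]) -> Dict[str, Any]:
    """Translate legacy flags into a loading config (hypothetical schema)."""
    # Assumed precedence: weights-only, then transfer learning, then resume
    if legacy.get("load_weights_only"):
        return {"type": "weights_only"}
    if legacy.get("transfer_learning"):
        return {"type": "transfer_learning"}
    if legacy.get("resume_from_checkpoint"):
        return {"type": "warm_start", "path": legacy["resume_from_checkpoint"]}
    # Assumed fallback when no legacy flag is set
    return {"type": "cold_start"}
```

How combinations of legacy flags should be resolved is not specified in this issue, so the ordering above would need to be confirmed before implementation.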

Related Issues

Branch

  • Working branch: feature/checkpoint-pipeline-infrastructure

Success Criteria

  • All tests passing
  • No breaking changes to existing functionality (yet)
  • Documentation complete
  • Ready for Phase 2 development to begin

Metadata

Assignees

Labels

enhancement (New feature or request)

Type

No type

Projects

Status

To be triaged

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests