Overview
Implement the core pipeline infrastructure for the new checkpoint system architecture - Phase 1 of the 5-PR implementation plan.
Parent Issue
Part of #248 - Checkpoint System Refactor
Architecture
Three-Layer Pipeline Design
The checkpoint system is being refactored into three distinct, composable layers that work together through a pipeline pattern:
- Checkpoint Acquisition Layer (#458)
  - Responsible for fetching checkpoint files from various sources (local, S3, HTTP, etc.)
  - Returns a local path to the checkpoint file
  - Handles caching, retries, and network failures
- Loading Orchestration Layer (new PR, Phase 2)
  - Implements different strategies for loading checkpoints into models
  - Strategies include: weights-only, transfer learning, warm start, and cold start
  - Manages optimizer and scheduler state restoration
- Model Transformation Layer (#410)
  - Applies post-loading modifications to the model
  - Includes freezing layers, adding LoRA adapters, and quantization
  - Executed after checkpoint loading is complete
This Issue's Role
This issue tracks the foundation layer that enables all three layers above to work together:
┌─────────────────────────────────────────────────┐
│ User Configuration (Hydra YAML) │
└─────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────┐
│ Pipeline Infrastructure (THIS ISSUE) │
│ - CheckpointContext (carries state) │
│ - PipelineStage (base abstraction) │
│ - CheckpointPipeline (orchestrator) │
│ - ComponentRegistry (factory pattern) │
└─────────────────────────────────────────────────┘
↓
Orchestrates the following stages in order:
↓
1. Acquisition → 2. Loading → 3. Transformation
Data Flow Example
```python
# 1. User provides configuration
config = {
    'checkpoint': {
        'source': {'type': 's3', 'bucket': 'models', 'key': 'model.ckpt'},
        'loading': {'type': 'transfer_learning', 'skip_mismatched': True},
    },
    'model_modifier': {
        'modifiers': [{'type': 'freeze', 'layers': ['encoder']}],
    },
}

# 2. Pipeline creates context
context = CheckpointContext(model=model, config=config)

# 3. Pipeline executes stages
context = await acquisition_stage.process(context)  # Downloads from S3
context = await loading_stage.process(context)      # Loads with transfer learning
context = await modifier_stage.process(context)     # Freezes encoder

# 4. Result: modified model with checkpoint loaded
model = context.model
```
Key Components to Implement
1. Core Abstractions (training/src/anemoi/training/checkpoint/base.py)

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

from omegaconf import DictConfig
from torch import nn
from torch.optim import Optimizer


@dataclass
class CheckpointContext:
    """Carries state through pipeline stages."""

    checkpoint_path: Optional[Path] = None
    checkpoint_data: Optional[Dict] = None
    model: Optional[nn.Module] = None
    optimizer: Optional[Optimizer] = None
    scheduler: Optional[Any] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    config: Optional[DictConfig] = None


class PipelineStage(ABC):
    """Base class for all pipeline stages."""

    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        """Transform the context and return it for the next stage."""
```
2. Pipeline Orchestrator (training/src/anemoi/training/checkpoint/pipeline.py)

```python
from typing import List


class CheckpointPipeline:
    """Orchestrates checkpoint processing through stages."""

    def __init__(self, stages: List[PipelineStage]):
        self.stages = stages

    async def execute(self, initial_context: CheckpointContext) -> CheckpointContext:
        context = initial_context
        for stage in self.stages:
            context = await stage.process(context)
        return context
```
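For illustration, here is a self-contained sketch of the orchestrator running stub stages in order. The `RecordingStage` class is a mock for this example only, not one of the real acquisition/loading stages:

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class CheckpointContext:
    """Simplified context carrying only metadata, for the sketch."""
    metadata: dict = field(default_factory=dict)


class PipelineStage(ABC):
    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext: ...


class RecordingStage(PipelineStage):
    """Mock stage that records its name so execution order is observable."""

    def __init__(self, name: str):
        self.name = name

    async def process(self, context: CheckpointContext) -> CheckpointContext:
        context.metadata.setdefault("visited", []).append(self.name)
        return context


class CheckpointPipeline:
    def __init__(self, stages):
        self.stages = stages

    async def execute(self, context: CheckpointContext) -> CheckpointContext:
        for stage in self.stages:
            context = await stage.process(context)
        return context


pipeline = CheckpointPipeline([RecordingStage("acquire"), RecordingStage("load")])
ctx = asyncio.run(pipeline.execute(CheckpointContext()))
# ctx.metadata["visited"] == ["acquire", "load"]
```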
3. Registry Pattern (training/src/anemoi/training/checkpoint/registry.py)
- Component registration for sources, loaders, and modifiers
- Factory methods for creating components from config
- Extensibility for adding new component types
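A minimal sketch of what such a registry could look like; the class and method names here are illustrative assumptions, not the final API:

```python
from typing import Callable, Dict


class ComponentRegistry:
    """Maps component type names (e.g. 's3') to factories."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable] = {}

    def register(self, name: str) -> Callable:
        """Decorator that registers a component factory under `name`."""
        def decorator(factory: Callable) -> Callable:
            self._factories[name] = factory
            return factory
        return decorator

    def create(self, config: dict):
        """Instantiate a component from a config dict with a 'type' key."""
        kwargs = dict(config)
        type_name = kwargs.pop("type")
        if type_name not in self._factories:
            raise ValueError(f"Unknown component type: {type_name!r}")
        return self._factories[type_name](**kwargs)


sources = ComponentRegistry()


@sources.register("s3")
class S3Source:
    def __init__(self, bucket: str, key: str) -> None:
        self.bucket, self.key = bucket, key


src = sources.create({"type": "s3", "bucket": "models", "key": "model.ckpt"})
```

The decorator style keeps registration next to each component's definition, which is what makes the plugin-style extensibility work.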
4. Error Handling (training/src/anemoi/training/checkpoint/exceptions.py)

- CheckpointError - Base exception
- CheckpointNotFoundError - File not found
- CheckpointLoadError - Loading failures
- CheckpointIncompatibleError - Compatibility issues
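The hierarchy above amounts to a small sketch like the following, with each specific error subclassing the base so callers can catch `CheckpointError` to handle all of them:

```python
class CheckpointError(Exception):
    """Base class for all checkpoint-related errors."""


class CheckpointNotFoundError(CheckpointError):
    """Raised when the checkpoint file cannot be located."""


class CheckpointLoadError(CheckpointError):
    """Raised when a checkpoint file exists but fails to load."""


class CheckpointIncompatibleError(CheckpointError):
    """Raised when a checkpoint is incompatible with the target model."""
```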
Implementation Checklist
- Core abstractions defined
- Pipeline orchestrator implemented
- Registry pattern established
- Error handling comprehensive
- Async utilities implemented
- Unit tests complete (70% coverage target)
- Integration tests complete
- Documentation complete
Design Decisions
1. Async-First Architecture
- All I/O operations are async for better performance
- Support for parallel downloads from multiple sources
- Non-blocking pipeline execution
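The parallel-download point can be sketched with `asyncio.gather`, which runs several source fetches concurrently; `fetch` here is a stand-in for a real source download, not part of the proposed API:

```python
import asyncio


async def fetch(name: str) -> str:
    """Stand-in for a source download; the sleep simulates network I/O."""
    await asyncio.sleep(0.01)
    return f"{name}-downloaded"


async def fetch_all(names):
    """Download from all sources concurrently, preserving input order."""
    return await asyncio.gather(*(fetch(n) for n in names))


results = asyncio.run(fetch_all(["s3", "http"]))
# results == ["s3-downloaded", "http-downloaded"]
```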
2. Context Pattern
- CheckpointContext carries all state through the pipeline
- Immutable context updates for debugging
- Clear data flow through stages
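Because `CheckpointContext` is a dataclass, "immutable context updates" can be done with `dataclasses.replace`: each stage returns an updated copy rather than mutating in place, so earlier contexts stay inspectable while debugging. A minimal sketch (using a pared-down, frozen version of the context):

```python
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class CheckpointContext:
    """Pared-down, frozen variant of the context for this sketch."""
    checkpoint_path: Optional[Path] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


ctx0 = CheckpointContext()
# A stage "updates" the context by returning a modified copy.
ctx1 = replace(ctx0, checkpoint_path=Path("/tmp/model.ckpt"))
# ctx0.checkpoint_path is still None; ctx1 carries the new path.
```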
3. Registry Pattern
- Extensible component registration
- Factory methods for instantiation
- Plugin-style architecture for new components
4. Error Handling Strategy
- Fail-fast with recovery mechanisms
- Comprehensive error types for different failures
- Retry logic with exponential backoff
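An illustrative async retry helper matching this strategy (the function and its parameters are assumptions for the sketch, not the issue's API):

```python
import asyncio
import random


async def retry_with_backoff(coro_factory, retries=3, base_delay=0.5, max_delay=8.0):
    """Await coro_factory(), retrying failures with exponential backoff."""
    for attempt in range(retries + 1):
        try:
            return await coro_factory()
        except Exception:
            if attempt == retries:
                raise  # fail fast once retries are exhausted
            # Exponential backoff with jitter: base, 2x, 4x, ... capped at max_delay.
            delay = min(base_delay * 2 ** attempt, max_delay)
            await asyncio.sleep(delay + random.uniform(0, base_delay))
```

A caller would wrap a flaky download as `await retry_with_backoff(lambda: source.fetch(url))`, letting transient network errors be absorbed while persistent failures still surface quickly.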
Testing Strategy
Unit Tests (tests/checkpoint/test_pipeline.py)
- Test pipeline execution order
- Test context passing between stages
- Test error propagation
- Mock all external dependencies
Integration Tests (tests/checkpoint/test_integration.py)
- Test full pipeline with mock stages
- Test async execution
- Test error recovery
- Test registry functionality
Migration Path
This infrastructure will support legacy configurations through adapters:
- load_weights_only=True → WeightsOnlyLoader
- transfer_learning=True → TransferLearningLoader
- resume_from_checkpoint → WarmStartLoader
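Such an adapter could be as simple as the sketch below. The loader type names follow the mapping above; the function itself and its config keys are illustrative assumptions:

```python
def adapt_legacy_config(legacy: dict) -> dict:
    """Translate legacy checkpoint flags into a new-style loading config."""
    if legacy.get("load_weights_only"):
        return {"type": "weights_only"}
    if legacy.get("transfer_learning"):
        return {"type": "transfer_learning"}
    if legacy.get("resume_from_checkpoint"):
        return {"type": "warm_start", "path": legacy["resume_from_checkpoint"]}
    # No legacy flags set: fall back to a fresh start.
    return {"type": "cold_start"}
```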
Related Issues
- Parent: Anemoi [Fine-tuning, Transfer Learning, Model Freezing] Roadmap #248 (Checkpoint System Refactor)
- Blocks: Loading Orchestration issue
- Related to: #458 (Checkpoint Acquisition Layer), #410 (Model Transformation Layer)
Branch
- Working branch: feature/checkpoint-pipeline-infrastructure
Success Criteria
- All tests passing
- No breaking changes to existing functionality (yet)
- Documentation complete
- Ready for Phase 2 development to begin