Description
User Story
As a researcher, I want to train a model using the MAML algorithm so that I can find an optimal weight initialization that can be rapidly adapted to new, unseen tasks with very few training examples.
Dependencies
- [Episodic Data Abstractions for Meta-Learning (N-way K-shot) #290: Episodic Data Abstractions]: This issue depends on the creation of an `EpisodicDataLoader` that can sample tasks in an N-way, K-shot format.
Phase 1: MAML Trainer Implementation
Goal: Create the main trainer class and implement the MAML algorithm, supporting both the first-order approximation and the full second-order version.
AC 1.1: Scaffolding MAMLTrainer<T> (3 points)
Requirement: Create the main class structure for the trainer.
- Create a new file: `src/Models/Meta/MAMLTrainer.cs`.
- Define a `public class MAMLTrainer<T>`.
- The constructor must accept:
  - `IModel<T> metaModel`: The initial model to be meta-trained.
  - `IOptimizer<T> metaOptimizer`: The optimizer for the outer loop meta-updates.
  - `ILossFunction<T> lossFunction`: The loss function for both inner and outer loops.
  - `int innerLoopSteps`: The number of adaptation steps on the support set.
  - `T innerLoopStepSize`: The learning rate for the inner loop adaptation.
  - `bool useFirstOrderApproximation`: A flag to switch between first-order MAML (`true`) and full second-order MAML (`false`).
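A minimal scaffolding sketch of the class and constructor described above. The namespace follows the folder-to-namespace rule from the architectural checklist; field names are illustrative and nothing beyond the constructor is implemented:

```csharp
using System;

namespace AiDotNet.Models.Meta;

/// <summary>
/// Meta-trains a model with MAML so its weights can be adapted to new tasks in a few gradient steps.
/// </summary>
public class MAMLTrainer<T>
{
    private readonly IModel<T> _metaModel;
    private readonly IOptimizer<T> _metaOptimizer;
    private readonly ILossFunction<T> _lossFunction;
    private readonly int _innerLoopSteps;
    private readonly T _innerLoopStepSize;
    private readonly bool _useFirstOrderApproximation;

    public MAMLTrainer(
        IModel<T> metaModel,
        IOptimizer<T> metaOptimizer,
        ILossFunction<T> lossFunction,
        int innerLoopSteps,
        T innerLoopStepSize,
        bool useFirstOrderApproximation)
    {
        if (innerLoopSteps <= 0)
            throw new ArgumentException("Inner loop steps must be positive.", nameof(innerLoopSteps));

        _metaModel = metaModel;
        _metaOptimizer = metaOptimizer;
        _lossFunction = lossFunction;
        _innerLoopSteps = innerLoopSteps;
        _innerLoopStepSize = innerLoopStepSize;
        _useFirstOrderApproximation = useFirstOrderApproximation;
    }
}
```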
AC 1.2: Implement the Train Method (8 points)
Requirement: Implement the main training method containing the MAML algorithm's two-level optimization loop.
- Define a public method: `public void Train(EpisodicDataLoader<T> dataLoader, int metaIterations)`.
- Outer Loop: Implement the meta-training loop: `for (int i = 0; i < metaIterations; i++)`.
- Inside the Outer Loop:
  1. Meta-Batch Collection: Create a list to store the final query set loss for each task in the meta-batch.
  2. Inner Loop (iterate through tasks in a meta-batch):
     - a. Clone Model: Create a temporary, "fast" model by deep copying the `metaModel`'s weights. This `taskModel` will be updated in the inner loop.
     - b. Inner Adaptation: `for (int j = 0; j < innerLoopSteps; j++)`. Perform a training step on the `taskModel` using a batch from the task's support set. The key challenge is to perform this update without losing the gradient history; the update operation itself must be part of the overall computational graph.
     - c. Evaluate on Query Set: After the inner loop, calculate the loss of the now-adapted `taskModel` on the task's query set.
     - d. Store Meta-Loss: Add this final query-set loss to the list of meta-losses for the meta-batch.
  3. Meta-Update:
     - Calculate the average of all losses in the meta-loss list.
     - If `useFirstOrderApproximation` is true: detach the inner-loop gradient history from the graph before this step.
     - Compute Meta-Gradient: Calculate the gradient of the average meta-loss with respect to the original `metaModel`'s parameters. This will backpropagate through the query set evaluation and all the inner-loop adaptation steps.
     - Use the `metaOptimizer` to apply this computed meta-gradient to the `metaModel`'s weights.
Developer Note: The most complex part of MAML is correctly implementing the meta-gradient calculation (AC 1.2, Step 3). The automatic differentiation engine must be able to handle differentiating through the inner-loop optimization steps. If this is not possible, only the first-order approximation can be implemented.
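The following is a structural sketch of the two-level loop in first-order mode, not a definitive implementation. `SampleMetaBatch`, `DeepCopy`, `TrainStep`, `ComputeLoss`, `SupportSet`, `QuerySet`, `Average`, and the optimizer's `Step` overload are all assumed helper members; the real APIs and the autodiff plumbing for the meta-gradient will differ:

```csharp
// Structural sketch only: gradient bookkeeping is elided and depends on the
// automatic differentiation engine.
public void Train(EpisodicDataLoader<T> dataLoader, int metaIterations)
{
    for (int i = 0; i < metaIterations; i++)                       // outer (meta) loop
    {
        var metaLosses = new List<T>();                            // 1. meta-batch collection

        foreach (var task in dataLoader.SampleMetaBatch())         // 2. loop over tasks (SampleMetaBatch is hypothetical)
        {
            var taskModel = _metaModel.DeepCopy();                 // a. clone the meta-model (DeepCopy is hypothetical)

            for (int j = 0; j < _innerLoopSteps; j++)              // b. adapt on the support set
            {
                taskModel.TrainStep(task.SupportSet, _lossFunction, _innerLoopStepSize);
            }

            T queryLoss = taskModel.ComputeLoss(task.QuerySet, _lossFunction); // c. evaluate on the query set
            metaLosses.Add(queryLoss);                             // d. store the meta-loss
        }

        // 3. Meta-update: average the stored query losses and update the metaModel.
        // First-order MAML applies the adapted models' query gradients directly to
        // the metaModel; the second-order version backpropagates through the inner
        // loop, which requires autodiff support not shown here.
        T averageLoss = Average(metaLosses);                       // hypothetical NumOps-based helper
        _metaOptimizer.Step(_metaModel, averageLoss);              // hypothetical optimizer API
    }
}
```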
Phase 2: Validation and Testing
Goal: Verify that the MAML implementation is correct and can successfully meta-learn.
AC 2.1: Unit Tests (5 points)
Requirement: Create unit tests to verify the core algorithm logic.
- Create a new test file: `tests/UnitTests/Meta/MAMLTrainerTests.cs`.
- Create a test using the first-order approximation.
- Run the `Train` method for a single meta-iteration (`metaIterations = 1`).
- Assert that the `metaModel`'s weights have been updated.
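A sketch of such a test, assuming xUnit and hypothetical helpers (`TestModels.CreateSmallNetwork`, `TestData.CreateSineWaveLoader`, `GetParameters`, `GradientDescentOptimizer`, `MeanSquaredErrorLoss`); the essential idea is to snapshot the weights before and after one meta-iteration:

```csharp
using System.Linq;
using Xunit;

public class MAMLTrainerTests
{
    [Fact]
    public void Train_SingleMetaIteration_UpdatesMetaModelWeights()
    {
        // Arrange: small model, first-order trainer, and a tiny episodic loader (hypothetical helpers).
        var metaModel = TestModels.CreateSmallNetwork<double>();
        var trainer = new MAMLTrainer<double>(
            metaModel,
            new GradientDescentOptimizer<double>(),
            new MeanSquaredErrorLoss<double>(),
            innerLoopSteps: 1,
            innerLoopStepSize: 0.01,
            useFirstOrderApproximation: true);
        var dataLoader = TestData.CreateSineWaveLoader<double>(nWay: 1, kShot: 5);

        var weightsBefore = metaModel.GetParameters().ToArray();

        // Act: one meta-iteration.
        trainer.Train(dataLoader, metaIterations: 1);

        // Assert: at least one weight changed.
        var weightsAfter = metaModel.GetParameters().ToArray();
        Assert.NotEqual(weightsBefore, weightsAfter);
    }
}
```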
AC 2.2: Integration Test (8 points)
Requirement: Create an integration test on a synthetic problem to prove meta-learning is occurring.
- Synthetic Data: Use the same few-shot sine wave regression problem defined for the Reptile integration test.
- Test Setup:
  - Instantiate a simple neural network as the `metaModel`.
  - Instantiate the `MAMLTrainer` in first-order mode.
- Test Logic:
  - Step 1 (Pre-Meta-Training): Evaluate the initial `metaModel` on a set of unseen test tasks. Store the average loss.
  - Step 2 (Meta-Training): Run the `MAMLTrainer.Train()` method for a significant number of meta-iterations.
  - Step 3 (Post-Meta-Training): Evaluate the now-trained `metaModel` on the same set of unseen test tasks.
  - Assert that the average loss after meta-training is significantly lower than the average loss before meta-training.
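A sketch of the test logic, reusing the hypothetical helpers from the unit-test example plus an assumed `EvaluateAverageLoss` utility; the thresholds and iteration counts are illustrative only:

```csharp
[Fact]
public void Train_OnSineWaveTasks_ReducesLossOnUnseenTasks()
{
    var metaModel = TestModels.CreateSmallNetwork<double>();       // hypothetical helper
    var trainer = new MAMLTrainer<double>(
        metaModel,
        new GradientDescentOptimizer<double>(),
        new MeanSquaredErrorLoss<double>(),
        innerLoopSteps: 5,
        innerLoopStepSize: 0.01,
        useFirstOrderApproximation: true);

    var trainLoader = TestData.CreateSineWaveLoader<double>(nWay: 1, kShot: 10);
    var testTasks = TestData.CreateSineWaveTestTasks<double>(count: 20);

    // Step 1: average loss on unseen tasks before meta-training.
    double lossBefore = EvaluateAverageLoss(metaModel, testTasks);

    // Step 2: meta-train.
    trainer.Train(trainLoader, metaIterations: 1000);

    // Step 3: average loss on the same unseen tasks after meta-training.
    double lossAfter = EvaluateAverageLoss(metaModel, testTasks);

    Assert.True(lossAfter < lossBefore * 0.5,
        "Meta-training should substantially reduce loss on unseen tasks.");
}
```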
Definition of Done
- All checklist items are complete.
- The `MAMLTrainer` is implemented, supporting at least the first-order approximation.
- The integration test demonstrates that the trainer can successfully meta-learn a solution to a synthetic few-shot problem.
- All new code meets the project's >= 90% test coverage requirement.
⚠️ CRITICAL ARCHITECTURAL REQUIREMENTS
Before implementing this user story, you MUST review:
- 📋 Full Requirements: `.github/USER_STORY_ARCHITECTURAL_REQUIREMENTS.md`
- 📐 Project Rules: `.github/PROJECT_RULES.md`
Mandatory Implementation Checklist
1. INumericOperations Usage (CRITICAL)
- Include `protected static readonly INumericOperations<T> NumOps = MathHelper.GetNumericOperations<T>();` in the base class
- NEVER hardcode `double`, `float`, or specific numeric types - use generic `T`
- NEVER use `default(T)` - use `NumOps.Zero` instead
- Use `NumOps.Zero`, `NumOps.One`, `NumOps.FromDouble()` for values
- Use `NumOps.Add()`, `NumOps.Multiply()`, etc. for arithmetic
- Use `NumOps.LessThan()`, `NumOps.GreaterThan()`, etc. for comparisons
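For illustration, a small excerpt showing how a trainer base class might average meta-losses using only `NumOps` (this assumes a `NumOps.Divide` method alongside the operations listed above):

```csharp
using System.Collections.Generic;

public abstract class MetaTrainerBase<T>
{
    // Generic numeric operations so the trainer works for double, float, etc.
    protected static readonly INumericOperations<T> NumOps = MathHelper.GetNumericOperations<T>();

    // Example: average a list of losses without hardcoding a numeric type.
    protected static T Average(IReadOnlyList<T> losses)
    {
        T sum = NumOps.Zero;
        foreach (var loss in losses)
        {
            sum = NumOps.Add(sum, loss);
        }

        return NumOps.Divide(sum, NumOps.FromDouble(losses.Count));
    }
}
```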
2. Inheritance Pattern (REQUIRED)
- Create `I{FeatureName}.cs` in `src/Interfaces/` (root level, NOT subfolders)
- Create `{FeatureName}Base.cs` in `src/{FeatureArea}/` inheriting from interface
{FeatureName}Base.csinsrc/{FeatureArea}/inheriting from interface - Create concrete classes inheriting from Base class (NOT directly from interface)
3. PredictionModelBuilder Integration (REQUIRED)
- Add private field: `private I{FeatureName}<T>? _{featureName};` to `PredictionModelBuilder.cs`
- Add Configure method taking ONLY the interface (no additional parameters): `public IPredictionModelBuilder<T, TInput, TOutput> Configure{FeatureName}(I{FeatureName}<T> {featureName}) { _{featureName} = {featureName}; return this; }`
- Use feature in `Build()` with a default: `var {featureName} = _{featureName} ?? new Default{FeatureName}<T>();`
- Verify feature is ACTUALLY USED in execution flow
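Instantiated for this feature with hypothetical member names, and shown as a partial-class excerpt purely for illustration (the real `PredictionModelBuilder` signature should be checked), the wiring might look like:

```csharp
// Sketch of the wiring inside PredictionModelBuilder (member names are hypothetical).
public partial class PredictionModelBuilder<T, TInput, TOutput> : IPredictionModelBuilder<T, TInput, TOutput>
{
    // Hypothetical field name for this feature.
    private IMetaLearningTrainer<T>? _metaLearningTrainer;

    // Configure method takes ONLY the interface.
    public IPredictionModelBuilder<T, TInput, TOutput> ConfigureMetaLearningTrainer(
        IMetaLearningTrainer<T> metaLearningTrainer)
    {
        _metaLearningTrainer = metaLearningTrainer;
        return this;
    }

    // In Build(), fall back to a default implementation when none was configured, e.g.
    //   var trainer = _metaLearningTrainer ?? new DefaultMetaLearningTrainer<T>();
    // and make sure the trainer is actually invoked in the execution flow.
}
```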
4. Beginner-Friendly Defaults (REQUIRED)
- Constructor parameters with defaults from research/industry standards
- Document WHY each default was chosen (cite papers/standards)
- Validate parameters and throw `ArgumentException` for invalid values
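Revisiting the AC 1.1 constructor to illustrate this item: a sketch with defaults and validation, with parameters reordered so the optional ones come last. The single-step and small-step-size choices follow values used in the original MAML paper (Finn et al., 2017), but the exact defaults should be confirmed and documented:

```csharp
// Illustrative defaults and validation only; signature details may differ from the final implementation.
public MAMLTrainer(
    IModel<T> metaModel,
    IOptimizer<T> metaOptimizer,
    ILossFunction<T> lossFunction,
    T innerLoopStepSize,                      // the MAML paper uses values such as 0.01 in several experiments
    int innerLoopSteps = 1,                   // one adaptation step, as in the original MAML setup (Finn et al., 2017)
    bool useFirstOrderApproximation = true)   // cheaper first-order variant as the beginner-friendly default
{
    if (metaModel is null) throw new ArgumentNullException(nameof(metaModel));
    if (innerLoopSteps <= 0)
        throw new ArgumentException("innerLoopSteps must be a positive integer.", nameof(innerLoopSteps));

    _metaModel = metaModel;
    _metaOptimizer = metaOptimizer;
    _lossFunction = lossFunction;
    _innerLoopStepSize = innerLoopStepSize;
    _innerLoopSteps = innerLoopSteps;
    _useFirstOrderApproximation = useFirstOrderApproximation;
}
```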
5. Property Initialization (CRITICAL)
- NEVER use `default!` operator
- String properties: `= string.Empty;`
- Collections: `= new List<T>();` or `= new Vector<T>(0);`
- Numeric properties: appropriate default or `NumOps.Zero`
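A small illustrative example (the result type and its properties are hypothetical):

```csharp
using System.Collections.Generic;

public class MetaTrainingResult<T>
{
    private static readonly INumericOperations<T> NumOps = MathHelper.GetNumericOperations<T>();

    // String properties default to string.Empty, never default!.
    public string ModelName { get; set; } = string.Empty;

    // Collections are initialized to empty instances.
    public List<T> MetaLossHistory { get; set; } = new List<T>();

    // Generic numeric properties use NumOps.Zero instead of default(T).
    public T FinalMetaLoss { get; set; } = NumOps.Zero;
}
```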
6. Class Organization (REQUIRED)
- One class/enum/interface per file
- ALL interfaces in `src/Interfaces/` (root level)
- Namespace mirrors folder structure (e.g., `src/Regularization/` → `namespace AiDotNet.Regularization`)
7. Documentation (REQUIRED)
- XML documentation for all public members
- `<b>For Beginners:</b>` sections with analogies and examples
- Document all `<param>`, `<returns>`, `<exception>` tags
- Explain default value choices
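An illustrative XML documentation block for the Train method, following the conventions above (wording and the exception condition are examples, not the final text):

```csharp
/// <summary>
/// Meta-trains the model so its weights can be adapted to a new task in only a few gradient steps.
/// </summary>
/// <remarks>
/// <b>For Beginners:</b> Think of MAML as learning a good "starting point". Instead of training a model
/// for one specific task, it trains the starting weights so that a handful of examples from any new,
/// similar task is enough to fine-tune the model quickly.
/// </remarks>
/// <param name="dataLoader">The episodic data loader that samples N-way, K-shot tasks.</param>
/// <param name="metaIterations">The number of outer-loop meta-updates to perform.</param>
/// <exception cref="ArgumentException">Thrown when <paramref name="metaIterations"/> is not positive.</exception>
public void Train(EpisodicDataLoader<T> dataLoader, int metaIterations) { /* ... */ }
```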
8. Testing (REQUIRED)
- Minimum 80% code coverage
- Test with multiple numeric types (double, float)
- Test default values are applied correctly
- Test edge cases and exceptions
- Integration tests for PredictionModelBuilder usage
See full details: .github/USER_STORY_ARCHITECTURAL_REQUIREMENTS.md