👉 Check out my blog post on MAML! ✍️
This is a custom implementation of the paper Model-Agnostic Meta-Learning (Finn et al.), using Higher for second-order optimization, which makes this framework truly model-agnostic. Unlike other implementations, the optimizee does not need to be built specifically for MAML: you can plug any PyTorch model into MAML!
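The snippet below is a minimal sketch of one second-order MAML meta-update on an arbitrary `nn.Module`. It swaps Higher for PyTorch's built-in `torch.func.functional_call` together with `torch.autograd.grad(..., create_graph=True)`, so it shows the idea rather than this repository's code. The name `maml_meta_step`, the task-tuple layout, the MSE loss and the learning rates are assumptions for illustration. `inner_steps` plays roughly the role of `-s`, and the number of tasks plays the role of `--meta-batch-size`.

```python
import torch
import torch.nn as nn

try:
    from torch.func import functional_call  # PyTorch >= 2.0
except ImportError:  # older PyTorch releases
    from torch.nn.utils.stateless import functional_call


def maml_meta_step(model, meta_opt, tasks, inner_lr=0.01, inner_steps=1):
    """Hypothetical helper: one second-order MAML outer update.

    Each task is a (x_support, y_support, x_query, y_query) tuple of tensors.
    Returns the mean query loss after adaptation.
    """
    loss_fn = nn.functional.mse_loss  # assumed regression loss (e.g. sinusoid)
    meta_opt.zero_grad()
    meta_params = dict(model.named_parameters())
    total = 0.0
    for x_s, y_s, x_q, y_q in tasks:
        params = meta_params  # every task starts from the shared initialisation
        for _ in range(inner_steps):
            support_loss = loss_fn(functional_call(model, params, (x_s,)), y_s)
            # create_graph=True keeps the inner SGD step differentiable, so the
            # outer gradient includes second-order terms.
            grads = torch.autograd.grad(
                support_loss, list(params.values()), create_graph=True)
            params = {name: p - inner_lr * g
                      for (name, p), g in zip(params.items(), grads)}
        query_loss = loss_fn(functional_call(model, params, (x_q,)), y_q)
        # Backpropagate through the adaptation into the meta-parameters;
        # gradients from all tasks accumulate in model.parameters().
        (query_loss / len(tasks)).backward()
        total += query_loss.item()
    meta_opt.step()
    return total / len(tasks)


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(1, 40), nn.ReLU(), nn.Linear(40, 1))
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    # Four random toy tasks, each with 10 support and 10 query points.
    batch = [tuple(torch.randn(10, 1) for _ in range(4)) for _ in range(4)]
    print(maml_meta_step(net, opt, batch, inner_steps=2))
```

The adapted weights are ordinary tensors computed from the meta-parameters. Backpropagating the query loss therefore differentiates through the inner updates. Passing `create_graph=False` instead would yield the cheaper first-order approximation.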
See this example from learner.py:
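```python
import torch.nn as nn


class ConvNetClassifier(nn.Module):
    def __init__(self, device, input_channels: int, n_classes: int):
        super().__init__()
        # Four 3x3 conv blocks without padding; each conv shrinks height and width by 2.
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        # Linear classifier head; 64*20*20 corresponds to 28x28 inputs (28 - 4*2 = 20).
        self.flc = nn.Sequential(
            nn.Linear(64*20*20, n_classes))
        # Move the whole model to the device, conv trunk included,
        # not only the classifier head.
        self.to(device)

    def forward(self, x):
        x = self.cnn(x)            # (B, 64, 20, 20) for 28x28 inputs
        x = x.view(x.size(0), -1)  # flatten to (B, 64*20*20)
        x = self.flc(x)            # class logits, shape (B, n_classes)
        return x
```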
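The `64*20*20` input size of the classifier head follows from convolution arithmetic. Assume 28x28 Omniglot inputs; this resolution is inferred from that constant, since the README does not state it. Each unpadded, stride-1 3x3 convolution then trims 2 pixels from each spatial dimension, so four of them give 28 → 26 → 24 → 22 → 20. BatchNorm and ReLU leave the shape unchanged. The helper names below are illustrative:

```python
def conv2d_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 0) -> int:
    """Output length of nn.Conv2d along one spatial dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size: int, n_convs: int = 4, channels: int = 64) -> int:
    """Feature count after n_convs unpadded 3x3 convs, then flattening."""
    size = input_size
    for _ in range(n_convs):
        size = conv2d_out(size)  # BatchNorm/ReLU do not change the size
    return channels * size * size


print(flattened_features(28))  # 25600 == 64 * 20 * 20
```

For other input resolutions, the `nn.Linear` input size has to change accordingly. PyTorch's `nn.LazyLinear(n_classes)` is one way to have it inferred on the first forward pass.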
This implementation ships with the Omniglot dataset for classification and simple toy regression datasets (sinusoid and harmonic). To use other datasets, you will need to write your own Dataset class following the given interface (TODO).
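The Dataset interface is still marked TODO, so the following is only a hypothetical sketch of a toy regression task source. The class name `SinusoidTasks` and its `sample_task` method are invented for illustration and are not this repository's API. The amplitude range [0.1, 5.0], phase range [0, π] and input range [-5, 5] follow the sinusoid setup of the MAML paper. Each task's `samples` points are split into `k` support points and `samples - k` query points, mirroring the `--samples` and `-k` descriptions in the usage text.

```python
import math
import random
from typing import List, Tuple

Points = List[Tuple[float, float]]


class SinusoidTasks:
    """Hypothetical task source: each task is y = A * sin(x - phase)."""

    def __init__(self, samples: int, k: int, seed: int = 0):
        if not 0 < k < samples:
            raise ValueError("need 0 < k < samples so both splits are non-empty")
        self.samples, self.k = samples, k
        self.rng = random.Random(seed)

    def sample_task(self) -> Tuple[Points, Points]:
        amplitude = self.rng.uniform(0.1, 5.0)
        phase = self.rng.uniform(0.0, math.pi)
        xs = [self.rng.uniform(-5.0, 5.0) for _ in range(self.samples)]
        points = [(x, amplitude * math.sin(x - phase)) for x in xs]
        # First k points adapt the model (support); the rest evaluate it (query).
        return points[: self.k], points[self.k :]


tasks = SinusoidTasks(samples=15, k=5)
support, query = tasks.sample_task()
print(len(support), len(query))  # 5 10
```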
```
usage: main.py [-h] [--checkpoint_path CHECKPOINT_PATH] [--load LOAD] [--eval] [--samples SAMPLES] [-k K] [-q Q] [-n N] [-s S]
               [--dataset {omniglot,sinusoid,harmonic}] [--meta-batch-size META_BATCH_SIZE] [--iterations ITERATIONS]

Model-Agnostic Meta-Learning

optional arguments:
  -h, --help            show this help message and exit
  --checkpoint_path CHECKPOINT_PATH
                        path to checkpoint saving directory
  --load LOAD           path to model checkpoint
  --eval                Evaluation mode
  --samples SAMPLES     Number of samples per task. The resulting number of test samples will be this value minus <K>.
  -k K                  Number of shots for meta-training
  -q Q                  Number of meta-testing samples
  -n N                  Number of classes for n-way classification
  -s S                  Number of inner loop optimization steps during meta-training
  --dataset {omniglot,sinusoid,harmonic}
  --meta-batch-size META_BATCH_SIZE
                        Number of tasks per meta-update
  --iterations ITERATIONS
                        Number of outer-loop iterations
```
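The Omniglot flags `-n`, `-k` and `-q` describe an N-way K-shot episode. Such an episode picks `n` classes and takes `k` labelled support examples and `q` query examples from each class. It then relabels the chosen classes 0 to n-1, so labels only carry meaning within the task. The function below is a hypothetical pure-Python sketch of that sampling, not the repository's data loader. It assumes `q` counts query examples per class, which the help text does not spell out.

```python
import random
from typing import Dict, List, Sequence, Tuple

Example = Tuple[str, int]  # (item, episode-local label)


def sample_episode(
    data: Dict[str, Sequence[str]], n: int, k: int, q: int, rng: random.Random
) -> Tuple[List[Example], List[Example]]:
    """Hypothetical helper: sample one N-way K-shot episode from {class: items}."""
    classes = rng.sample(sorted(data), n)
    support: List[Example] = []
    query: List[Example] = []
    for label, cls in enumerate(classes):
        # Draw k + q distinct items so support and query never overlap.
        items = rng.sample(list(data[cls]), k + q)
        support += [(item, label) for item in items[:k]]
        query += [(item, label) for item in items[k:]]
    return support, query


# Toy stand-in for Omniglot: 10 "characters" with 20 "drawings" each.
toy = {f"char{c}": [f"char{c}_img{i}" for i in range(20)] for c in range(10)}
support, query = sample_episode(toy, n=5, k=1, q=3, rng=random.Random(0))
print(len(support), len(query))  # 5 15
```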