- torch==1.4.0
- PyYAML==3.13
We borrow the pretrained embeddings from the deepmind/leo repo. You can download them from that repo, or run:
```bash
$ wget http://storage.googleapis.com/leo-embeddings/embeddings.zip
$ unzip embeddings.zip
```
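As a quick sanity check after unzipping, you can list the extracted pickle files and confirm that one of them loads. The sketch below assumes the archive extracts into an `embeddings/` directory of `*_embeddings.pkl` files, as in the deepmind/leo release; adjust the paths to whatever `unzip` actually produced.

```python
# Hedged sanity check of the downloaded embeddings.
# The directory layout (embeddings/miniImageNet/center/...) is an assumption
# based on the deepmind/leo release; adjust to what unzip actually produced.
import os
import pickle

EMBEDDING_DIR = "embeddings"

# List every pickle file in the extracted tree.
for root, _, files in os.walk(EMBEDDING_DIR):
    for name in files:
        if name.endswith(".pkl"):
            print(os.path.join(root, name))

# Try loading one file to confirm it unpickles cleanly.
sample = os.path.join(EMBEDDING_DIR, "miniImageNet", "center", "train_embeddings.pkl")
if os.path.exists(sample):
    with open(sample, "rb") as f:
        data = pickle.load(f, encoding="latin1")  # latin1 helps if the pickle was written by Python 2
    print(type(data))
```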
To train the model, run:

```bash
python3 main.py -train \
    -verbose \
    -N 5 \
    -K 1 \
    -embedding_dir $(EMBEDDING_DIR) \
    -dataset miniImageNet \
    -exp_name toy-example \
    -save_checkpoint
```
where `-N` and `-K` specify N-way K-shot training, `-exp_name` helps you keep track of your experiment, and `-save_checkpoint` saves the model for later testing. For the full list of arguments, see `main.py`.
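For orientation, the flags above correspond to an argument parser roughly like the sketch below. This is only an illustration of the interface; the parser in `main.py` is authoritative and may define more arguments and different defaults.

```python
# Illustrative sketch of the command-line interface described above.
# The real parser lives in main.py; names and defaults here mirror the README
# examples and are not guaranteed to match the source exactly.
import argparse

parser = argparse.ArgumentParser(description="LEO training / testing")
parser.add_argument("-train", action="store_true", help="run training")
parser.add_argument("-test", action="store_true", help="run testing")
parser.add_argument("-verbose", action="store_true", help="print progress to the console")
parser.add_argument("-N", type=int, default=5, help="N-way classification")
parser.add_argument("-K", type=int, default=1, help="K-shot support set")
parser.add_argument("-embedding_dir", type=str, help="path to the pretrained embeddings")
parser.add_argument("-dataset", type=str, default="miniImageNet",
                    help="miniImageNet or tieredImageNet")
parser.add_argument("-exp_name", type=str, help="experiment name for bookkeeping")
parser.add_argument("-save_checkpoint", action="store_true", help="save the model for later testing")
parser.add_argument("-load", type=str, help="checkpoint path to load for testing")
parser.add_argument("-disable_comet", action="store_true", help="turn off Comet.ml logging")

args = parser.parse_args()
```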
To test a trained model, run:

```bash
python3 main.py -test \
    -N 5 \
    -K 1 \
    -embedding_dir $(EMBEDDING_DIR) \
    -dataset miniImageNet \
    -verbose \
    -load $(model_path)
```
The testing result will be printed on the console.
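Conceptually, `-save_checkpoint` and `-load` are a `state_dict` save/load round trip. The sketch below illustrates the idea with a placeholder linear layer standing in for the LEO model; the actual checkpointing code lives in `solver.py` and may store additional metadata.

```python
# Hedged sketch of the save/load round trip behind -save_checkpoint and -load.
# A plain linear layer stands in for the LEO model; the 640-dim embedding size
# is an assumption, and the real code in solver.py may save extra metadata.
import torch
import torch.nn as nn

model = nn.Linear(640, 64)

# Roughly what -save_checkpoint does during training:
torch.save(model.state_dict(), "toy-example.pth")

# Roughly what -load $(model_path) does before testing:
state_dict = torch.load("toy-example.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # switch off dropout for evaluation
```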
This project comes with Comet.ml support. If you want to disable logging, just add `-disable_comet` as an argument. To enable monitoring, you will need to modify `COMET_PROJECT_NAME` and `COMET_WORKSPACE` in `config.yml`.

*If you do not save your Comet API key in `.comet.config`, you will have to specify the API key at line 147 of `solver.py`.
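As a rough illustration, enabling Comet.ml amounts to reading those two keys from `config.yml` and passing them to a `comet_ml.Experiment`. The snippet below is a sketch, not the exact code in `solver.py`; it assumes the two keys sit at the top level of `config.yml`, and the commented-out `api_key` argument is where you would paste the key if it is not stored in `.comet.config`.

```python
# Hedged sketch of the Comet.ml wiring described above; the actual setup is in
# solver.py. Assumes COMET_PROJECT_NAME and COMET_WORKSPACE are top-level keys
# in config.yml.
import yaml
from comet_ml import Experiment

with open("config.yml") as f:
    cfg = yaml.safe_load(f)

experiment = Experiment(
    # api_key="YOUR_COMET_API_KEY",  # only needed if the key is not in .comet.config
    project_name=cfg["COMET_PROJECT_NAME"],
    workspace=cfg["COMET_WORKSPACE"],
)
experiment.log_parameters(cfg)  # e.g. log the run's hyperparameters
```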
You can modify the hyperparameters in `config.yml`, where detailed descriptions are also provided. A sketch for overriding these values programmatically follows the table. The hyperparameters that yield the best results with this code are as follows:
| Hyperparameter | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
| --- | --- | --- | --- | --- |
| outer_lr | 0.0005 | 0.0006 | 0.0006 | 0.0006 |
| l2_penalty_weight | 0.0001 | 8.5e-6 | 3.6e-10 | 3.6e-10 |
| orthogonality_penalty_weight | 303.0 | 0.00152 | 0.188 | 0.188 |
| dropout | 0.3 | 0.3 | 0.3 | 0.3 |
| kl_weight | 0 | 0.001 | 0.001 | 0.001 |
| encoder_penalty_weight | 1e-9 | 2.66e-7 | 5.7e-6 | 5.7e-6 |
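If you prefer scripting sweeps over editing `config.yml` by hand, one option is to rewrite the file with PyYAML, as sketched below. The key names follow the table above, but whether they sit at the top level of `config.yml` is an assumption, so check the file's actual structure first.

```python
# Hedged sketch of overriding hyperparameters programmatically before a run.
# Key names follow the table above; their exact nesting inside config.yml is
# an assumption, so verify against the real file.
import yaml

with open("config.yml") as f:
    cfg = yaml.safe_load(f)

# Example: apply the reported miniImageNet 5-shot settings.
cfg["outer_lr"] = 0.0006
cfg["l2_penalty_weight"] = 8.5e-6
cfg["orthogonality_penalty_weight"] = 0.00152
cfg["dropout"] = 0.3
cfg["kl_weight"] = 0.001
cfg["encoder_penalty_weight"] = 2.66e-7

with open("config.yml", "w") as f:
    yaml.safe_dump(cfg, f, default_flow_style=False)
```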
| Implementation | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
| --- | --- | --- | --- | --- |
| LEO paper | 61.76 ± 0.08% | 77.59 ± 0.12% | 66.33 ± 0.05% | 81.44 ± 0.09% |
| This code | 59.46 ± 0.08% | 76.01 ± 0.09% | 66.62 ± 0.07% | 81.72 ± 0.09% |
*The results may not be directly comparable: in the paper, the model is trained on both the training and validation sets, whereas our model is trained only on the training set and validated on the validation set.
Note: This project is licensed under the terms of the MIT license.