Showing 19 changed files with 1,865 additions and 170 deletions.
Binary file not shown.
@@ -0,0 +1,108 @@
import torch
from meta_train import meta_train_one_epoch, meta_test_one_epoch
from meta_model import EncModel
import torch.optim as optim
import os
import argparse
from reg_data_loader import *
from utils import get_saved_file

torch.manual_seed(1234)


def get_arguments():
    parser = argparse.ArgumentParser(description='ST-MAML')
    parser.add_argument('--num_sample_function', type=int, default=40)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gpu_id', type=int, default=2)
    parser.add_argument('--meta_bs', type=int, default=36)
    parser.add_argument('--epoch', type=int, default=150)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--total_batches', type=int, default=500)
    parser.add_argument('--inner_loop_grad_clip', type=float, default=10.)
    parser.add_argument('--noise_std', default=0.3, type=float)
    parser.add_argument('--resume_train', default=False, action='store_true')
    parser.add_argument('--inner_lr', default=1e-3, type=float)
    parser.add_argument('--inner_step', default=3, type=int)
    parser.add_argument('--img_size', default=28, type=int)
    parser.add_argument('--loss_type', default='BCEloss', choices=['BCEloss', 'MSEloss'])
    parser.add_argument('--output_folder', default='results')
    return parser.parse_args()

args = get_arguments()

# Build train/eval loaders for each image dataset, then wrap them in a
# multimodal few-shot dataset that mixes tasks across datasets.
dataset_list = ['mnist', 'fmnist', 'kmnist']
dataset_list_train = [get_dloader('./data/', dataset, args.meta_bs, args.num_sample_function, args.img_size, train=True) for dataset in dataset_list]
dataset_list_eval = [get_dloader('./data/', dataset, args.meta_bs, args.num_sample_function, args.img_size, train=False) for dataset in dataset_list]

trainset = MultimodalFewShotDataset(
    dataset_list_train,
    num_total_batches=args.total_batches,
    mix_meta_batch=True,
    mix_mini_batch=False,
    txt_file=None,
)
testset = MultimodalFewShotDataset(
    dataset_list_eval,
    num_total_batches=100,
    mix_meta_batch=True,
    mix_mini_batch=False,
    txt_file=None,
)

if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.set_device(args.gpu_id)
else:
    device = 'cpu'
print(device)

# Results go under <output_folder>/<loss_type>_<inner_step>/.
state_str = os.path.join(args.output_folder, args.loss_type + '_' + str(args.inner_step))
args.save_epoch_loss = os.path.join(state_str, 'save_epoch_loss.txt')
if not os.path.exists(args.output_folder):
    os.mkdir(args.output_folder)
if not os.path.exists(state_str):
    os.mkdir(state_str)

# Create the per-epoch loss log if it does not exist yet.
if not os.path.exists(args.save_epoch_loss):
    open(args.save_epoch_loss, 'a+').close()

# EncModel(n_in, n_hid, n_out, n_xyaug, n_hid_emb, n_out_emb, n_xy, n_hid_aug, n_out_aug)
model = EncModel(2, [64, 128, 256, 128], 1, 64, 80, 128, 3, [20, 40], 64)
model.to(device)

print(model)

if args.resume_train:
    # Resume from the latest checkpoint found under state_str.
    path = state_str
    start_epoch = get_saved_file(path)
    start_model_str = os.path.join(state_str, 'epoch_' + str(start_epoch) + '.pt')
    model.load_state_dict(torch.load(start_model_str))
    start_epoch += 1
else:
    start_epoch = 0

opt = optim.Adam(model.parameters(), lr=args.lr)

for epoch in range(start_epoch, args.epoch):
    meta_train_one_epoch(model, trainset, opt, epoch, args, 'train', device)
    if epoch % 5 == 0:
        # Checkpoint every fifth epoch.
        torch.save(model.state_dict(), os.path.join(state_str, 'epoch_' + str(epoch) + '.pt'))
    meta_test_one_epoch(model, testset, opt, epoch, args, 'eval', device)
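
The script above is driven entirely by the argparse flags defined in get_arguments(); every flag has a default, so only values being overridden need to be passed on the command line. Assuming the file is saved as main.py (the file name is not shown in this diff), a typical run might look like:

python main.py --loss_type MSEloss --inner_step 3 --gpu_id 0 --epoch 150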
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.distributions as distributions
from collections import OrderedDict
import torch.nn.init as init


def linear_block(iput, aug, param, prefix):
    # Functional forward pass through the LR MLP using an explicit parameter
    # dict, so adapted (fast) weights can be plugged in during the inner loop.
    # The task augmentation vector `aug` is concatenated after the first layer.
    n_data, _ = iput.size()
    out = F.linear(iput, weight=param[prefix + 'linear1.weight'], bias=param[prefix + 'linear1.bias'])
    out = F.relu(out, inplace=True)
    out = torch.cat([out, aug], dim=-1)

    out = F.linear(out, weight=param[prefix + 'linear2.weight'], bias=param[prefix + 'linear2.bias'])
    out = F.relu(out, inplace=True)

    out = F.linear(out, weight=param[prefix + 'linear3.weight'], bias=param[prefix + 'linear3.bias'])
    out = F.relu(out, inplace=True)

    out = F.linear(out, weight=param[prefix + 'linear4.weight'], bias=param[prefix + 'linear4.bias'])
    out = F.relu(out, inplace=True)

    out = F.linear(out, weight=param[prefix + 'linear5.weight'], bias=param[prefix + 'linear5.bias'])
    return out


class LR(nn.Module):
    # Base learner: a 5-layer MLP whose forward pass always runs through
    # linear_block with an explicit parameter dict.
    def __init__(self, n_input, n_h, n_output, n_out_aug):
        super(LR, self).__init__()
        self.mlp = nn.Sequential()
        self.mlp.add_module('linear1', nn.Linear(n_input, n_h[0]))
        self.mlp.add_module('relu1', nn.ReLU(True))
        self.mlp.add_module('linear2', nn.Linear(n_h[0] + n_out_aug, n_h[1]))
        self.mlp.add_module('relu2', nn.ReLU(True))
        self.mlp.add_module('linear3', nn.Linear(n_h[1], n_h[2]))
        self.mlp.add_module('relu3', nn.ReLU(True))
        self.mlp.add_module('linear4', nn.Linear(n_h[2], n_h[3]))
        self.mlp.add_module('relu4', nn.ReLU(True))
        self.mlp.add_module('linear5', nn.Linear(n_h[3], n_output))

        for layer in self.mlp.modules():
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_in')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

    def forward(self, x, aug, param=None):
        if param is None:
            # Fall back to the module's own (cloned) weights when no adapted
            # parameters are supplied.
            param = self.cloned_state_dict()
        prefix = 'mlp.'
        y = linear_block(x, aug, param, prefix)
        return y

    def cloned_state_dict(self):
        cloned_state_dict = {
            key: val.clone()
            for key, val in self.state_dict().items()
        }
        return cloned_state_dict


class TaskEnc(nn.Module):
    # Task encoder: maps the aggregated support-set feature to a task embedding.
    def __init__(self, n_xyaug, n_hid, n_out):
        super(TaskEnc, self).__init__()
        self.encoder = nn.Sequential()
        self.encoder.add_module('xy2hid', nn.Linear(n_xyaug, n_hid))
        self.encoder.add_module('relu1', nn.ReLU(inplace=True))
        self.encoder.add_module('hid2out', nn.Linear(n_hid, n_out))

        for layer in self.encoder.modules():
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_in')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

    def forward(self, aug):
        out = self.encoder(aug)
        return out


class AugInfo(nn.Module):
    # Set encoder for the support set: embeds each (x, y) pair and mean-pools
    # over the support dimension into a single augmentation vector.
    def __init__(self, n_in, n_hid, n_out):
        super(AugInfo, self).__init__()
        self.aug_enc = nn.Sequential()
        self.aug_enc.add_module('xy2hid', nn.Linear(n_in, n_hid[0]))
        self.aug_enc.add_module('relu1', nn.ReLU(inplace=True))
        self.aug_enc.add_module('hid2hid', nn.Linear(n_hid[0], n_hid[1]))
        self.aug_enc.add_module('relu2', nn.ReLU(inplace=True))
        self.aug_enc.add_module('hid2out', nn.Linear(n_hid[1], n_out))

        for layer in self.aug_enc.modules():
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_in')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

    def forward(self, x, y):
        n_spt, _ = x.size()
        xy_concat = torch.cat([x, y], dim=-1)
        out = self.aug_enc(xy_concat)
        out = out.mean(dim=0, keepdim=True)
        return out


class EncModel(nn.Module):
    # ST-MAML model: a base learner (LR) whose last layer is modulated by a
    # task embedding produced from the support set.
    def __init__(self, n_in, n_hid, n_out, n_xyaug, n_hid_emb, n_out_emb, n_xy, n_hid_aug, n_out_aug, aug_feature=True):
        super(EncModel, self).__init__()

        assert (n_in + n_out == n_xy), 'dimension mismatching.'

        self.aug_feature = aug_feature

        self.augenc = AugInfo(n_xy, n_hid_aug, n_out_aug)
        self.learner = LR(n_in, n_hid, n_out, n_out_aug)
        self.encoder = TaskEnc(n_xyaug, n_hid_emb, n_out_emb)

    def task_encoder(self, x_spt, y_spt):
        # Aggregate the support set into a single augmentation vector.
        self.aug = self.augenc(x_spt, y_spt)

    def param_encoder(self):
        # Turn the augmentation vector into the task embedding used for gating.
        self.task_emb = self.encoder(self.aug)

    def encode_param(self, param=None):
        # Build an adapted parameter dict for the learner. The last linear
        # layer's weight is gated elementwise by sigmoid(task_emb); all other
        # parameters are passed through unchanged.
        adapted_state_dict = self.learner.cloned_state_dict()
        adapted_params = OrderedDict()
        source = self.learner.named_parameters() if param is None else param.items()
        for key, val in source:
            if key == 'mlp.linear5.weight':
                code = self.task_emb
                adapted_params[key] = torch.sigmoid(code) * val
            else:
                adapted_params[key] = val
            adapted_state_dict[key] = adapted_params[key]
        return adapted_state_dict

    def forward(self, x, param):
        # Broadcast the support-set augmentation vector across the batch and
        # run the learner with the (possibly adapted) parameters.
        self.aug_vec = self.aug.repeat(x.size(0), 1)
        y = self.learner.forward(x, self.aug_vec, param)
        return y
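
For orientation: EncModel is consumed by meta_train_one_epoch and meta_test_one_epoch in meta_train.py, which is not included in this listing. The following is a minimal, hypothetical sketch of how one inner-loop adaptation step could be wired up, assuming toy tensors, an MSE loss, and a plain gradient update; the actual training code may clip gradients (see --inner_loop_grad_clip) and differ in other details.

import torch
import torch.nn.functional as F
from meta_model import EncModel

# Same constructor arguments as in the training script above.
model = EncModel(2, [64, 128, 256, 128], 1, 64, 80, 128, 3, [20, 40], 64)

# Toy support/query sets: 40 points with 2-D inputs and 1-D targets.
x_spt, y_spt = torch.randn(40, 2), torch.randn(40, 1)
x_qry, y_qry = torch.randn(40, 2), torch.randn(40, 1)

model.task_encoder(x_spt, y_spt)   # aggregate the support set into model.aug
model.param_encoder()              # map the aggregate to the task embedding
adapted = model.encode_param()     # gate the last layer with sigmoid(task_emb)

inner_lr, inner_steps = 1e-3, 3
for _ in range(inner_steps):
    loss = F.mse_loss(model(x_spt, adapted), y_spt)
    grads = torch.autograd.grad(loss, list(adapted.values()), create_graph=True)
    adapted = {k: v - inner_lr * g for (k, v), g in zip(adapted.items(), grads)}

query_loss = F.mse_loss(model(x_qry, adapted), y_qry)
print(query_loss.item())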