Commit

ST-MAML
zhexjtu committed Feb 2, 2023
1 parent d51c8f8 commit bcc7097
Showing 19 changed files with 1,865 additions and 170 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -14,3 +14,18 @@ To reproduce the temperature prediction experiments, run with:
> python main.py --method MAML --model MLP_MAML

For the cross dataset image completion experiments, run with:

> cd ST-MAML-ImgCompletion
> python meta_main.py
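
The script's hyperparameters are defined in the argparse block of ST-MAML-ImgCompletion/meta_main.py (added in this commit). As an illustration only, with the defaults made explicit and the GPU index swapped to 0, a run might look like:

> python meta_main.py --gpu_id=0 --meta_bs=36 --epoch=150 --inner_step=3 --loss_type=BCEloss --output_folder=results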

For regression fitting, train a model with:

> cd ST-MAML-Reg
> python meta_main.py --aug_enc --kl_weight=2.0 --in_weight_rest=0.1 --model_type='prob' --output_folder='results'

For visualization purposes, run with:

> python visual.py --aug_enc --kl_weight=2.0 --in_weight_rest=0.1 --model_type='prob' --output_folder='results'
Binary file added ST-MAML-ImgCompletion/.DS_Store
Binary file not shown.
108 changes: 108 additions & 0 deletions ST-MAML-ImgCompletion/meta_main.py
@@ -0,0 +1,108 @@
import torch
from meta_train import meta_train_one_epoch, meta_test_one_epoch
from meta_model import EncModel
import torch.optim as optim
import os
import argparse
from reg_data_loader import *
from utils import get_saved_file
torch.manual_seed(1234)
def get_arguments():
    parser = argparse.ArgumentParser(description='ST-MAML')
    parser.add_argument('--num_sample_function', type=int, default=40)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gpu_id', type=int, default=2)
    parser.add_argument('--meta_bs', type=int, default=36)
    parser.add_argument('--epoch', type=int, default=150)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--total_batches', type=int, default=500)
    parser.add_argument('--inner_loop_grad_clip', type=float, default=10.)
    parser.add_argument('--noise_std', default=0.3, type=float)
    parser.add_argument('--resume_train', default=False, action='store_true')
    parser.add_argument('--inner_lr', default=1e-3, type=float)
    parser.add_argument('--inner_step', default=3, type=int)
    parser.add_argument('--img_size', default=28, type=int)
    parser.add_argument('--loss_type', default='BCEloss', choices=['BCEloss', 'MSEloss'])
    parser.add_argument('--output_folder', default='results')
    return parser.parse_args()

args = get_arguments()


# Build train/eval dataloaders for each image-completion dataset.
dataset_list = ['mnist', 'fmnist', 'kmnist']
dataset_list_train = [get_dloader('./data/', dataset, args.meta_bs, args.num_sample_function, args.img_size, train=True) for dataset in dataset_list]
dataset_list_eval = [get_dloader('./data/', dataset, args.meta_bs, args.num_sample_function, args.img_size, train=False) for dataset in dataset_list]
trainset = MultimodalFewShotDataset(
    dataset_list_train,
    num_total_batches=args.total_batches,
    mix_meta_batch=True,
    mix_mini_batch=False,
    txt_file=None,
)
testset = MultimodalFewShotDataset(
    dataset_list_eval,
    num_total_batches=100,
    mix_meta_batch=True,
    mix_mini_batch=False,
    txt_file=None,
)



if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.set_device(args.gpu_id)
else:
    device = 'cpu'
print(device)

state_str = os.path.join(args.output_folder, args.loss_type+'_'+str(args.inner_step))
args.save_epoch_loss = os.path.join(state_str, 'save_epoch_loss.txt')
if not os.path.exists(args.output_folder):
    os.mkdir(args.output_folder)
if not os.path.exists(state_str):
    os.mkdir(state_str)

# Create the per-epoch loss log file if it does not exist yet.
if not os.path.exists(args.save_epoch_loss):
    open(args.save_epoch_loss, 'a+').close()


# EncModel(n_in, n_hid, n_out, n_xyaug, n_hid_emb, n_out_emb, n_xy, n_hid_aug, n_out_aug)
model = EncModel(2, [64, 128, 256, 128], 1, 64, 80, 128, 3, [20, 40], 64)
model.to(device)


print(model)
if args.resume_train:
    # Resume from the most recent checkpoint in the output folder.
    path = state_str
    start_epoch = get_saved_file(path)
    start_model_str = os.path.join(state_str, 'epoch_'+str(start_epoch)+'.pt')
    model.load_state_dict(torch.load(start_model_str))
    start_epoch += 1
else:
    start_epoch = 0

opt = optim.Adam(model.parameters(), lr=args.lr)

for epoch in range(start_epoch, args.epoch):
    meta_train_one_epoch(model, trainset, opt, epoch, args, 'train', device)
    # Save a checkpoint every 5 epochs.
    if epoch % 5 == 0:
        torch.save(model.state_dict(), os.path.join(state_str, 'epoch_'+str(epoch)+'.pt'))

    meta_test_one_epoch(model, testset, opt, epoch, args, 'eval', device)
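
get_saved_file is imported from utils, which is not part of this diff; from its usage above, it must return the index N of the latest saved checkpoint 'epoch_N.pt' in the given folder. Purely as an illustrative assumption, an implementation matching that contract could look like:

import os
import re

def get_saved_file(path):
    # Return the largest epoch index N such that 'epoch_N.pt' exists in `path`.
    epochs = []
    for name in os.listdir(path):
        match = re.fullmatch(r'epoch_(\d+)\.pt', name)
        if match:
            epochs.append(int(match.group(1)))
    if not epochs:
        raise FileNotFoundError('no checkpoints found in ' + path)
    return max(epochs)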


195 changes: 195 additions & 0 deletions ST-MAML-ImgCompletion/meta_model.py
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.distributions as distributions
from collections import OrderedDict
import torch.nn.init as init



def linear_block(iput, aug, param, prefix):

    n_data, _ = iput.size()
    out = F.linear(iput, weight=param[prefix+'linear1.weight'], bias=param[prefix+'linear1.bias'])
    out = F.relu(out, inplace=True)
    # Concatenate the task-augmentation vector after the first layer.
    out = torch.cat([out, aug], dim=-1)

    out = F.linear(out, weight=param[prefix+'linear2.weight'], bias=param[prefix+'linear2.bias'])
    out = F.relu(out, inplace=True)

    out = F.linear(out, weight=param[prefix+'linear3.weight'], bias=param[prefix+'linear3.bias'])
    out = F.relu(out, inplace=True)

    out = F.linear(out, weight=param[prefix+'linear4.weight'], bias=param[prefix+'linear4.bias'])
    out = F.relu(out, inplace=True)

    out = F.linear(out, weight=param[prefix+'linear5.weight'], bias=param[prefix+'linear5.bias'])
    return out


class LR(nn.Module):
    def __init__(self, n_input, n_h, n_output, n_out_aug):
        super(LR, self).__init__()
        self.mlp = nn.Sequential()
        self.mlp.add_module('linear1', nn.Linear(n_input, n_h[0]))
        self.mlp.add_module('relu1', nn.ReLU(True))
        self.mlp.add_module('linear2', nn.Linear(n_h[0]+n_out_aug, n_h[1]))
        self.mlp.add_module('relu2', nn.ReLU(True))
        self.mlp.add_module('linear3', nn.Linear(n_h[1], n_h[2]))
        self.mlp.add_module('relu3', nn.ReLU(True))
        self.mlp.add_module('linear4', nn.Linear(n_h[2], n_h[3]))
        self.mlp.add_module('relu4', nn.ReLU(True))
        self.mlp.add_module('linear5', nn.Linear(n_h[3], n_output))

        for layer in self.mlp.modules():
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_in')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

    def forward(self, x, aug, param=None):
        prefix = 'mlp.'
        y = linear_block(x, aug, param, prefix)
        return y

    def cloned_state_dict(self):
        cloned_state_dict = {
            key: val.clone()
            for key, val in self.state_dict().items()
        }
        return cloned_state_dict



class TaskEnc(nn.Module):
    def __init__(self, n_xyaug, n_hid, n_out):
        super(TaskEnc, self).__init__()
        self.encoder = nn.Sequential()
        self.encoder.add_module('xy2hid', nn.Linear(n_xyaug, n_hid))
        self.encoder.add_module('relu1', nn.ReLU(inplace=True))
        self.encoder.add_module('hid2out', nn.Linear(n_hid, n_out))

        for layer in self.encoder.modules():
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_in')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

    def forward(self, aug):
        out = self.encoder(aug)
        return out

class AugInfo(nn.Module):
    def __init__(self, n_in, n_hid, n_out):
        super(AugInfo, self).__init__()
        self.aug_enc = nn.Sequential()
        self.aug_enc.add_module('xy2hid', nn.Linear(n_in, n_hid[0]))
        self.aug_enc.add_module('relu1', nn.ReLU(inplace=True))
        self.aug_enc.add_module('hid2hid', nn.Linear(n_hid[0], n_hid[1]))
        self.aug_enc.add_module('relu2', nn.ReLU(inplace=True))
        self.aug_enc.add_module('hid2out', nn.Linear(n_hid[1], n_out))

        for layer in self.aug_enc.modules():
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_in')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

    def forward(self, x, y):
        n_spt, _ = x.size()
        xy_concat = torch.cat([x, y], dim=-1)
        out = self.aug_enc(xy_concat)
        # Mean-pool over the support points: a permutation-invariant task summary.
        out = out.mean(dim=0, keepdim=True)
        return out

class EncModel(nn.Module):
    def __init__(self, n_in, n_hid, n_out, n_xyaug, n_hid_emb, n_out_emb, n_xy, n_hid_aug, n_out_aug, aug_feature=True):
        super(EncModel, self).__init__()

        assert (n_in+n_out == n_xy), 'dimension mismatch.'

        self.aug_feature = aug_feature

        self.augenc = AugInfo(n_xy, n_hid_aug, n_out_aug)
        self.learner = LR(n_in, n_hid, n_out, n_out_aug)
        self.encoder = TaskEnc(n_xyaug, n_hid_emb, n_out_emb)

    def task_encoder(self, x_spt, y_spt):
        # Summarize the support set into the task-augmentation vector.
        self.aug = self.augenc(x_spt, y_spt)

    def param_encoder(self):
        # Map the task-augmentation vector to a task embedding.
        self.task_emb = self.encoder(self.aug)

    def encode_param(self, param=None):
        if param is None:
            adapted_state_dict = self.learner.cloned_state_dict()
            adapted_params = OrderedDict()
            for (key, val) in self.learner.named_parameters():
                if key == 'mlp.linear5.weight':
                    # Gate the last-layer weight with the task embedding.
                    code = self.task_emb
                    adapted_params[key] = torch.sigmoid(code)*val
                    adapted_state_dict[key] = adapted_params[key]
                else:
                    adapted_params[key] = val
                    adapted_state_dict[key] = adapted_params[key]
            return adapted_state_dict
        else:
            adapted_state_dict = self.learner.cloned_state_dict()
            adapted_params = OrderedDict()
            for (key, val) in param.items():
                if key == 'mlp.linear5.weight':
                    code = self.task_emb
                    adapted_params[key] = torch.sigmoid(code)*val
                    adapted_state_dict[key] = adapted_params[key]
                else:
                    adapted_params[key] = val
                    adapted_state_dict[key] = adapted_params[key]
            return adapted_state_dict

    def forward(self, x, param):
        # Broadcast the task-augmentation vector across all query points.
        self.aug_vec = self.aug.repeat(x.size(0), 1)
        y = self.learner.forward(x, self.aug_vec, param)
        return y
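
The functional forward pass above (linear_block reads weights from an explicit param dict, and encode_param returns a task-conditioned copy of the learner's parameters) is what enables MAML-style inner-loop adaptation. The sketch below is purely illustrative and not part of this commit: the name inner_adapt and the BCE-with-logits objective are assumptions, the latter chosen to match the default 'BCEloss' option in meta_main.py.

import torch
import torch.nn.functional as F
from collections import OrderedDict

def inner_adapt(model, x_spt, y_spt, inner_lr=1e-3, inner_steps=3):
    # Hypothetical inner loop: adapt a copy of the learner's weights on the support set.
    model.task_encoder(x_spt, y_spt)   # task-augmentation vector from the support set
    model.param_encoder()              # task embedding that gates the last layer
    params = model.encode_param()      # task-conditioned copy of the learner weights
    for _ in range(inner_steps):
        pred = model(x_spt, params)
        loss = F.binary_cross_entropy_with_logits(pred, y_spt)
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        # Manual SGD step on the copied parameters (keeps the graph for the outer update).
        params = OrderedDict(
            (name, p - inner_lr * g)
            for (name, p), g in zip(params.items(), grads)
        )
    return params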



