Skip to content

Commit

Permalink
Merge pull request #2 from ASSANDHOLE/sine_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ASSANDHOLE authored Oct 19, 2022
2 parents 88357a7 + d59a97a commit ad16852
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 1 deletion.
42 changes: 42 additions & 0 deletions example_sinewave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch

from sine_dataset import create_dataset_sinewave

from utils import NamedDict


def get_args_maml_regression():
    """Assemble the settings used by the sine-wave MAML regression example.

    Returns
    -------
    NamedDict
        Container carrying the attributes ``problem_dim``, ``train_test``,
        ``epoch``, ``update_lr``, ``meta_lr``, ``k_spt``, ``k_qry``,
        ``update_step``, ``update_step_test`` and ``device``.
    """
    # Use the GPU when torch reports one; swap in 'cpu' here to force CPU runs.
    if torch.cuda.is_available():
        device_name = 'cuda'
    else:
        device_name = 'cpu'

    settings = {
        'problem_dim': (1, 3),
        'train_test': (20, 3),
        'epoch': 100,
        # Author's note: with the other settings unchanged, update_lr=0.01
        # makes the loss become nan (exploding gradients).
        'update_lr': 0.00001,
        'meta_lr': 0.001,
        'k_spt': 20,
        'k_qry': 100,
        'update_step': 5,
        'update_step_test': 10,
        'device': torch.device(device_name),
    }

    args = NamedDict()
    # Assign in declaration order so the container is filled exactly as before.
    for key, value in settings.items():
        setattr(args, key, value)
    return args


def get_network_structure_maml_regression():
    """Describe the regression network as a list of ``(layer_name, params)`` pairs.

    The result alternates three ``'linear'`` entries with two ``'leakyrelu'``
    entries. Presumably the linear params are ``[out_features, in_features]``
    and the leaky-ReLU params are ``[negative_slope, inplace]`` -- TODO confirm
    against the code that consumes this config.

    Returns
    -------
    list
        The layer configuration, freshly built on every call.
    """
    linear_params = ([40, 1], [40, 40], [1, 40])

    layers = []
    for params in linear_params:
        # Put an activation between consecutive linear layers only.
        if layers:
            layers.append(('leakyrelu', [0.1, True]))
        layers.append(('linear', params))
    return layers


def get_dataset_sinewave(args, **kwargs):
    """Create the sine-wave meta-learning dataset described by ``args``.

    Parameters
    ----------
    args
        Settings object providing ``problem_dim``, ``train_test``, ``k_spt``
        and ``k_qry``.
    **kwargs
        Extra keyword arguments passed straight to ``create_dataset_sinewave``.

    Returns
    -------
    Whatever ``create_dataset_sinewave`` returns for these settings.
    """
    return create_dataset_sinewave(
        args.problem_dim,
        args.train_test,
        spt_qry=(args.k_spt, args.k_qry),
        **kwargs,
    )


19 changes: 18 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Unified MAML for this project
from __future__ import annotations

from typing import Tuple, List

Expand All @@ -9,6 +10,7 @@

from utils import NamedDict
from example import get_args, get_network_structure, get_dataset
from example_sinewave import get_args_maml_regression, get_network_structure_maml_regression, get_dataset_sinewave


class Sol:
Expand Down Expand Up @@ -128,5 +130,20 @@ def main():
print(f'Random loss: {random_loss:.4f}')


def main_sinewave():
    """Train and evaluate on the sine-wave dataset, then report an untrained baseline.

    Prints the test loss of a trained ``Sol`` followed by the test loss of a
    freshly constructed ``Sol`` that is never trained ("Random loss").
    """
    args = get_args_maml_regression()
    network_structure = get_network_structure_maml_regression()
    dataset = get_dataset_sinewave(args)

    # Trained model: fit first, then evaluate.
    trained = Sol(dataset, args, network_structure)
    trained.train(explicit=5)
    _, test_loss = trained.test()
    print(f'Test loss: {test_loss:.4f}')

    # NOTE(review): get_args_maml_regression() defines `update_step_test`, while
    # this sets `test_update_step` -- confirm which attribute Sol actually reads.
    args.test_update_step = 30

    # Baseline: same data and structure, but evaluated without any training.
    untrained = Sol(dataset, args, network_structure)
    _, random_loss = untrained.test()
    print(f'Random loss: {random_loss:.4f}')

if __name__ == '__main__':
main()
# main()
main_sinewave()
95 changes: 95 additions & 0 deletions sine_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import math
import random
import numpy as np
from typing import Tuple, List, Callable


def get_sine_wave_sampling(amp: float = 1, phase: float = 0, num_points: int = 100,
                           domain: Tuple[float, float] = (0, 1)) -> \
        Tuple[List[Tuple[float, float]], Callable[[float], float]]:
    """Sample ``amp * sin(x + phase)`` at evenly spaced points of ``domain``.

    The domain is cut into ``num_points`` equal cells and the function is
    evaluated at the midpoint of each cell.

    Parameters
    ----------
    amp : float
        Amplitude of the sine wave.
    phase : float
        Phase offset added to ``x`` before taking the sine.
    num_points : int
        Number of samples. Zero or a negative value yields an empty sample.
    domain : Tuple[float, float]
        ``(lower, upper)`` bounds of the sampled interval.

    Returns
    -------
    Tuple[List[Tuple[float, float]], Callable[[float], float]]
        The ``(x, y)`` samples in increasing ``x`` order, and the sampled
        function itself.
    """

    def sine_wave_func(x: float) -> float:
        return math.sin(x + phase) * amp

    # Negative counts already produced an empty sample (empty range); guard the
    # zero case too so it does not divide by zero when computing the cell width.
    if num_points <= 0:
        return [], sine_wave_func

    # uniform sampling at cell midpoints
    interval = (domain[1] - domain[0]) / num_points
    half_interval = interval / 2
    dataset = []
    for i in range(num_points):
        x = domain[0] + i * interval + half_interval
        dataset.append((x, sine_wave_func(x)))

    return dataset, sine_wave_func


def create_dataset_sinewave(problem_dim: Tuple[int, int], train_test: Tuple[int, int], spt_qry: Tuple[int, int],
                            amp: Tuple[float, float] = (0.1, 5.0), phase: Tuple[float, float] = (0.1, math.pi),
                            domain: Tuple[float, float] = (-5.0, 5.0)) -> \
        Tuple[
            Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
            Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        ]:
    """
    Sample sine wave functions and create a dataset for meta-learning.
    Hyperparameters from MAML paper: https://arxiv.org/pdf/1703.03400.pdf.
    Using consistent hyperparameters for comparison.

    Each split holds ``problem_dim[1] * n`` independently sampled sine waves,
    where ``n`` is the matching entry of ``train_test``. Every wave is sampled
    at ``spt_qry[0] + spt_qry[1]`` points, shuffled with the ``random`` module,
    and divided into support and query points.

    Parameters
    ----------
    problem_dim : Tuple[int, int]
        The number of variables and number of objectives for each problem.
        The number of variables must be 1.
    train_test : Tuple[int, int]
        [n_train, n_test]
    spt_qry : Tuple[int, int]
        The number of support and query points for each problem
    amp : Tuple[float, float]
        The amplitude range; each wave draws from [amp[0], amp[1])
    phase : Tuple[float, float]
        The phase range; each wave draws from [phase[0], phase[1])
    domain : Tuple[float, float]
        The domain of input

    Returns
    -------
    Tuple[
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
    ]
        The first element is the training set [support set, support label, query set, query label]
        The second element is the test set [support set, support label, query set, query label]
        All arrays are float32. Inputs have shape (n_waves, n_points, 1) and
        labels have shape (n_waves, n_points).

    Raises
    ------
    SystemExit
        With exit code -1, after printing a message, when ``problem_dim[0]``
        is not 1.
    """
    if problem_dim[0] != 1:
        # Format the value into the message; a bare '%d' passed as a separate
        # print argument is never substituted.
        print('The objective function of the dataset only support 1 variable. \n'
              f'Current number of variables: {problem_dim[0]}')
        # Same effect as exit(-1) without depending on the site-provided helper.
        raise SystemExit(-1)

    n_spt, n_qry = spt_qry[0], spt_qry[1]

    def to_array(rows: list, width: int) -> np.ndarray:
        # Explicit reshape keeps the array 2-D even when there are no rows.
        return np.array(rows).astype(np.float32).reshape(len(rows), width)

    def create_dataset_inner(n_problems: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        set_spt_x, set_spt_y, set_qry_x, set_qry_y = [], [], [], []
        for _ in range(problem_dim[1] * n_problems):
            _amp = amp[0] + random.random() * (amp[1] - amp[0])  # [amp[0], amp[1])
            _phase = phase[0] + random.random() * (phase[1] - phase[0])  # [phase[0], phase[1])
            dataset, _func = get_sine_wave_sampling(_amp, _phase, n_spt + n_qry, domain)
            # Shuffle so the support/query split is a random partition of the grid.
            random.shuffle(dataset)
            obj_spt, obj_qry = dataset[:n_spt], dataset[n_spt:]
            set_spt_x.append([x for x, _y in obj_spt])
            set_spt_y.append([y for _x, y in obj_spt])
            set_qry_x.append([x for x, _y in obj_qry])
            set_qry_y.append([y for _x, y in obj_qry])
        # Inputs get a trailing axis of size 1 (one variable); labels stay 2-D.
        arr_spt_x = to_array(set_spt_x, n_spt)[:, :, np.newaxis]
        arr_spt_y = to_array(set_spt_y, n_spt)
        arr_qry_x = to_array(set_qry_x, n_qry)[:, :, np.newaxis]
        arr_qry_y = to_array(set_qry_y, n_qry)
        return arr_spt_x, arr_spt_y, arr_qry_x, arr_qry_y

    train_set = create_dataset_inner(train_test[0])
    test_set = create_dataset_inner(train_test[1])
    return train_set, test_set


def test():
    """Smoke check: build a small dataset and print the shape of every array.

    Prints one line for the training split and one for the test split, each
    listing support inputs, support labels, query inputs and query labels.
    """
    # With these settings the training support inputs have shape (12, 5, 1).
    train_set, test_set = create_dataset_sinewave(problem_dim=(1, 3), train_test=(4, 2), spt_qry=(5, 20))
    for split in (train_set, test_set):
        spt_x, spt_y, qry_x, qry_y = split
        print(spt_x.shape, spt_y.shape, qry_x.shape, qry_y.shape)


if __name__ == '__main__':
test()

0 comments on commit ad16852

Please sign in to comment.