more flexible code based on github.com/GokuMohandas/fast-weights #1

Open · wants to merge 2 commits into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/

associative-retrieval.pkl
13 changes: 9 additions & 4 deletions generator.py
@@ -1,6 +1,9 @@
import numpy as np
import random
import cPickle as pickle
try:
    import cPickle as pickle  # Python 2
except ImportError:
    import pickle  # Python 3

num_train = 60000
num_val = 10000
@@ -36,7 +39,7 @@ def generate_one():

for i in range(0, step_num):
c = random.randint(0, 25)
while d.has_key(c):
while c in d:
c = random.randint(0, 25)
b = random.randint(0, 9)
d[c] = b
@@ -45,14 +48,16 @@ def generate_one():
a[i*2] = get_one_hot(s)
a[i*2+1] = get_one_hot(t)

s = random.choice(d.keys())
s = random.choice(list(d.keys()))
t = chr(s + ord('a'))
r = chr(d[s] + ord('0'))
a[step_num * 2] = get_one_hot('?')
a[step_num * 2 + 1] = get_one_hot('?')
a[step_num * 2 + 2] = get_one_hot(t)
st += '??' + t + r

e = get_one_hot(r)

return a, e

if __name__ == '__main__':
@@ -74,4 +79,4 @@ def generate_one():
'y_val': y_val
}
with open('associative-retrieval.pkl', 'wb') as f:
pickle.dump(d, f, protocol=2)
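For reference, a minimal sketch of loading the dataset this script writes. It assumes the dict keys mirror the 'y_val' entry visible above (i.e. 'x_train', 'y_train', 'x_val', 'y_val'); protocol 2 keeps the pickle readable from both Python 2 and Python 3:

import pickle

with open('associative-retrieval.pkl', 'rb') as f:
    data = pickle.load(f)

x_train, y_train = data['x_train'], data['y_train']
x_val, y_val = data['x_val'], data['y_val']
print(x_train.shape, y_train.shape)  # one-hot input sequences and their targets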
83 changes: 52 additions & 31 deletions model.py
@@ -1,11 +1,12 @@
from __future__ import print_function

import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from config import cfg
from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
import time
from retrieval import read_data
@@ -23,34 +24,53 @@ def softmax_cross_entropy_with_logits(logits, labels):

class fast_weights_model(nn.Module):
"""docstring for fast_weights_model"""
def __init__(self, batch_size, step_num, elem_num, hidden_num):
def __init__(self, args):
super(fast_weights_model, self).__init__()
self.x = Variable(torch.randn(batch_size, step_num, elem_num).type(torch.float32))
self.y = Variable(torch.randn(batch_size, elem_num).type(torch.float32))
self.l = torch.tensor([0.9], dtype=torch.float32)
self.e = torch.tensor([0.5], dtype=torch.float32)

self.w1 = Variable(torch.empty(elem_num, 50).uniform_(-np.sqrt(0.02), np.sqrt(0.02)))
self.b1 = Variable(torch.zeros([1, 50]).type(torch.float32))
self.w2 = Variable(torch.empty(500, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)))
self.b2 = Variable(torch.zeros([1, 100]).type(torch.float32))
self.w3 = Variable(torch.empty(hidden_num, 100).uniform_(-np.sqrt(0.01), np.sqrt(0.01)))
self.b3 = Variable(torch.zeros([1, 100]).type(torch.float32))
self.w4 = Variable(torch.empty(100, elem_num).uniform_(-np.sqrt(1.0 / elem_num), np.sqrt(1.0 / elem_num)))
self.b4 = Variable(torch.zeros([1, elem_num]).type(torch.float32))

self.w = Variable(torch.tensor(0.05 * np.identity(hidden_num)).type(torch.float32))

self.c = Variable(torch.empty(100, hidden_num).uniform_(-np.sqrt(hidden_num), np.sqrt(hidden_num)))

self.g = Variable(torch.ones([1, hidden_num]).type(torch.float32))
self.b = Variable(torch.ones([1, hidden_num]).type(torch.float32))

def forward(self, bx, by)
self.batch_size = args.batch_size
self.hidden_size = args.hidden_size
# Inputs
self.X = Variable(torch.randn(args.batch_size, args.input_dim, args.num_classes).type(torch.float32))
# Targets
self.y = Variable(torch.randn(args.batch_size, args.num_classes).type(torch.float32))
# Learning Rate
self.l = torch.tensor([args.learning_rate], dtype=torch.float32)
# Decay Rate
self.e = torch.tensor([args.decay_rate], dtype=torch.float32)

# Input weights. Note: Variable() takes no dtype kwarg, so the original
# `Variable(..., dtype=torch.float32)` calls would raise a TypeError;
# .type(torch.float32) matches the idiom used elsewhere in this file.
self.W_x = Variable(torch.empty(
    args.num_classes,
    args.hidden_size).uniform_(
    -np.sqrt(2.0/args.num_classes),
    np.sqrt(2.0/args.num_classes)
).type(torch.float32))
self.b_x = Variable(torch.zeros(
    [args.hidden_size]
).type(torch.float32))

# Hidden weights (initialization explained in Hinton video)
self.W_h = Variable(torch.tensor(
    0.5 * np.identity(args.hidden_size)
).type(torch.float32))

# Softmax weights
self.W_softmax = Variable(torch.empty(
    args.hidden_size,
    args.num_classes).uniform_(
    -np.sqrt(2.0/args.hidden_size),
    np.sqrt(2.0/args.hidden_size)
).type(torch.float32))
self.b_softmax = Variable(torch.zeros(
    args.num_classes
).type(torch.float32))

# Scale and shift everything for layernorm
self.gain = Variable(torch.ones(args.hidden_size).type(torch.float32))
self.bias = Variable(torch.zeros(args.hidden_size).type(torch.float32))


def forward(self, bx, by):
self.x = bx
self.y = by
a = torch.zeros([batch_size, hidden_num, hidden_num]).type(torch.float32)
h = torch.zeros([batch_size, hidden_num]).type(torch.float32)
a = torch.zeros([self.batch_size, self.hidden_size, self.hidden_size]).type(torch.float32)
h = torch.zeros([self.batch_size, self.hidden_size]).type(torch.float32)

la = []

@@ -60,7 +80,7 @@ def forward(self, bx, by)

h = torch.relu(torch.matmul(h, self.w) + torch.matmul(z, self.c))

hs = torch.reshape(h, [batch_size, 1, hidden_num])
hs = torch.reshape(h, [self.batch_size, 1, self.hidden_size])

hh = hs

@@ -75,7 +95,7 @@ def forward(self, bx, by)
sig = torch.sqrt(torch.mean(torch.pow((hs - mu), 2), 0))
hs = torch.relu(torch.div(torch.mul(self.g, (hs - mu)), sig) + self.b)

h = torch.reshape(hs, [batch_size, hidden_num])
h = torch.reshape(hs, [self.batch_size, self.hidden_size])

h = torch.relu(torch.matmul(h, self.w3) + self.b3)
logits = torch.matmul(h, self.w4) + self.b4
@@ -85,12 +105,13 @@ def forward(self, bx, by)

return self.loss, self.acc

def train(self, save = 0, verbose = 0):
model = fast_weights_model(STEP_NUM, ELEM_NUM, HIDDEN_NUM)
def train(args, save = 0, verbose = 0):
    args.batch_size = 60000  # full training set in a single batch
    model = fast_weights_model(args)  # new __init__ takes args, not positional dims
model.train()
batch_size = cfg.train.batch_size
start_time = time.time()
optimizer = torch.optim.Adam(model.paramters(), lr=cfg.train.model_lr)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train.model_lr)
writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name), flush_secs=30)
checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name))
start_epoch = 0
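For reviewers, a hedged sketch of the recurrence forward() above implements (Ba et al., "Using Fast Weights to Attend to the Recent Past"): the fast memory decays and accumulates the outer product of the hidden state, A(t) = lam * A(t-1) + eta * h(t) h(t)^T, and an inner loop applies A with layer normalization. The names lam, eta, and inner_steps are illustrative, not taken from this PR; the defaults mirror the 0.9 and 0.5 constants in the removed code. Note forward() above still references the old attribute names (self.w, self.c, self.g, self.b), which is worth reconciling with the new init before merge.

import torch

def fast_weights_step(h, z, A, W_h, C, gain, bias, lam=0.9, eta=0.5, inner_steps=1):
    # Slow path: candidate hidden state from recurrent and input projections.
    h0 = torch.relu(h @ W_h + z @ C)                     # [B, H]
    # Fast memory: decay, then add the outer product of the current hidden state.
    A = lam * A + eta * h.unsqueeze(2) @ h.unsqueeze(1)  # [B, H, H]
    hs = h0.unsqueeze(1)                                 # [B, 1, H]
    for _ in range(inner_steps):
        # Apply the fast weights, then layer-normalize over the hidden dimension.
        hs = h0.unsqueeze(1) + hs @ A
        mu = hs.mean(dim=2, keepdim=True)
        sig = hs.std(dim=2, keepdim=True)
        hs = torch.relu(gain * (hs - mu) / (sig + 1e-6) + bias)
    return hs.squeeze(1), A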
12 changes: 6 additions & 6 deletions util.py
@@ -6,7 +6,7 @@
import torch
from torch import nn
from torch import optim
from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

class Checkpointer:
def __init__(self, path, max_num=3):
@@ -19,8 +19,8 @@ def __init__(self, path, max_num=3):
with open(self.listfile, 'wb') as f:
model_list = []
pickle.dump(model_list, f)


def save(self, model, optimizer, epoch):
checkpoint = {
'model': model.state_dict(),
@@ -38,10 +38,10 @@ def save(self, model, optimizer, epoch):
model_list.append(filename)
with open(self.listfile, 'rb+') as f:
pickle.dump(model_list, f)

with open(filename, 'wb') as f:
torch.save(checkpoint, f)

def load(self, model, optimizer):
"""
Return starting epoch
@@ -56,4 +56,4 @@ def load(self, model, optimizer):
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print('Load checkpoint from {}.'.format(model_list[-1]))
return checkpoint['epoch']
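A hedged usage sketch for the Checkpointer class above. The path and the tiny stand-in model are illustrative (train() in model.py wires these up from cfg instead), and load() is assumed to return 0 when no checkpoint exists yet:

import torch
from util import Checkpointer

model = torch.nn.Linear(4, 2)  # stand-in for fast_weights_model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpointer = Checkpointer('checkpoints/demo', max_num=3)
start_epoch = checkpointer.load(model, optimizer)

for epoch in range(start_epoch, 3):
    # ... one epoch of training would go here ...
    checkpointer.save(model, optimizer, epoch)  # rotates files, keeping at most max_num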