# -*- coding: utf-8 -*-
# @Author : William
# @Project : TextGAN-william
# @FileName : RelGAN_G.py
# @Time : Created at 2019-04-25
# @Blog : http://zhiweil.ml/
# @Description : RelGAN generator (LSTM or relational-memory backbone with Gumbel-Softmax sampling)
# Copyright (C) 2018. All Rights Reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F

import config as cfg
from models.generator import LSTMGenerator
from models.relational_rnn_general import RelationalMemory


class RelGAN_G(LSTMGenerator):
    def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx,
                 gpu=False):
        super(RelGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu)
        self.name = 'relgan'

        self.temperature = 1.0  # init value is 1.0
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

        if cfg.model_type == 'LSTM':
            # LSTM
            self.hidden_dim = hidden_dim
            self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, batch_first=True)
            self.lstm2out = nn.Linear(self.hidden_dim, vocab_size)
        else:
            # RMC (relational memory core)
            self.hidden_dim = mem_slots * num_heads * head_size
            self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=embedding_dim,
                                         num_heads=num_heads, return_all_outputs=True)
            self.lstm2out = nn.Linear(self.hidden_dim, vocab_size)

        self.init_params()
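
    # Note: in the RMC branch the flattened memory width doubles as the hidden size.
    # For example, with mem_slots=1, num_heads=2, head_size=256 (arbitrary illustrative
    # values), hidden_dim = 1 * 2 * 256 = 512, so lstm2out maps a 512-dim memory
    # read-out to vocab_size logits.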

    def init_hidden(self, batch_size=cfg.batch_size):
        if cfg.model_type == 'LSTM':
            h = torch.zeros(1, batch_size, self.hidden_dim)
            c = torch.zeros(1, batch_size, self.hidden_dim)

            if self.gpu:
                return h.cuda(), c.cuda()
            else:
                return h, c
        else:
            # init RMC memory
            memory = self.lstm.initial_state(batch_size)
            memory = self.lstm.repackage_hidden(memory)  # detach memory at first
            return memory.cuda() if self.gpu else memory
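
    # Shape note: for the LSTM branch, init_hidden returns a tuple (h, c), each of shape
    # (1, batch_size, hidden_dim). For the RMC branch it returns the relational memory
    # produced by RelationalMemory.initial_state, expected here to have shape
    # (batch_size, mem_slots, num_heads * head_size).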

    def step(self, inp, hidden):
        """
        RelGAN step forward
        :param inp: [batch_size]
        :param hidden: previous hidden (LSTM) or memory (RMC) state
        :return: pred, hidden, next_token, next_token_onehot, next_o
            - pred: batch_size * vocab_size, used for the adversarial training backward pass
            - hidden: next hidden state
            - next_token: [batch_size], next token of each sequence
            - next_token_onehot: batch_size * vocab_size, not used yet
            - next_o: batch_size * vocab_size, not used yet
        """
        emb = self.embeddings(inp).unsqueeze(1)
        out, hidden = self.lstm(emb, hidden)
        gumbel_t = self.add_gumbel(self.lstm2out(out.squeeze(1)))
        next_token = torch.argmax(gumbel_t, dim=1).detach()
        # next_token_onehot = F.one_hot(next_token, cfg.vocab_size).float()  # not used yet
        next_token_onehot = None

        pred = F.softmax(gumbel_t * self.temperature, dim=-1)  # batch_size * vocab_size
        # next_o = torch.sum(next_token_onehot * pred, dim=1)  # not used yet
        next_o = None

        return pred, hidden, next_token, next_token_onehot, next_o
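
    # Background: adding Gumbel(0, 1) noise g to the logits o_t and taking
    # argmax(o_t + g) draws an exact sample from softmax(o_t) (the Gumbel-max trick).
    # step() therefore returns both that hard sample (next_token) and a differentiable
    # relaxation, pred = softmax(temperature * (o_t + g)), which adversarial training
    # backpropagates through.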

    def sample(self, num_samples, batch_size, one_hot=False, start_letter=cfg.start_letter):
        """
        Sample from the RelGAN generator.
        - one_hot: if True, return the Gumbel-Softmax predictions instead of token indices
          (used for adversarial training)
        :return:
            - all_preds: batch_size * seq_len * vocab_size, only meaningful when
              num_samples == batch_size (holds the predictions of the last batch)
            - samples: num_samples * seq_len token indices
        """
        num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1
        samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long()
        all_preds = None
        if one_hot:
            all_preds = torch.zeros(batch_size, self.max_seq_len, self.vocab_size)
            if self.gpu:
                all_preds = all_preds.cuda()

        for b in range(num_batch):
            hidden = self.init_hidden(batch_size)
            inp = torch.LongTensor([start_letter] * batch_size)
            if self.gpu:
                inp = inp.cuda()

            for i in range(self.max_seq_len):
                pred, hidden, next_token, _, _ = self.step(inp, hidden)
                samples[b * batch_size:(b + 1) * batch_size, i] = next_token
                if one_hot:
                    all_preds[:, i] = pred
                inp = next_token

        samples = samples[:num_samples]  # num_samples * seq_len

        if one_hot:
            return all_preds  # batch_size * seq_len * vocab_size
        return samples

    @staticmethod
    def add_gumbel(o_t, eps=1e-10, gpu=cfg.CUDA):
        """Add a vector sampled from Gumbel(0, 1) to the logits o_t."""
        u = torch.zeros(o_t.size())
        if gpu:
            u = u.cuda()

        u.uniform_(0, 1)
        g_t = -torch.log(-torch.log(u + eps) + eps)
        gumbel_t = o_t + g_t
        return gumbel_t
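

# Illustrative usage sketch: a minimal smoke test of sample(), assuming the modules
# imported above (config, models.generator, models.relational_rnn_general) are on the
# path and that cfg.model_type / cfg.CUDA are consistent with the flags used here. The
# hyperparameter values are arbitrary examples, not necessarily the repo defaults.
if __name__ == '__main__':
    gen = RelGAN_G(mem_slots=1, num_heads=2, head_size=256,
                   embedding_dim=32, hidden_dim=32, vocab_size=5000,
                   max_seq_len=20, padding_idx=0, gpu=False)
    # Draw 16 sequences of token indices; start_letter=1 is an assumed BOS index.
    samples = gen.sample(num_samples=16, batch_size=16, one_hot=False, start_letter=1)
    print(samples.shape)  # expected: torch.Size([16, 20])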