Commit cf3f628

Soohwan Kim authored and committed
Updates
1 parent e940d19 commit cf3f628

File tree

8 files changed: +792 -348 lines changed


README.md

Lines changed: 23 additions & 3 deletions
````diff
@@ -9,6 +9,27 @@ This repository contains only model code, but you can train with speech transfor
 I appreciate any kind of [feedback or contribution](https://github.com/sooftware/Speech-Transformer/issues)
 
 ## Usage
+- Training
+```python
+import torch
+from speech_transformer import SpeechTransformer
+
+BATCH_SIZE, SEQ_LENGTH, DIM, NUM_CLASSES = 3, 12345, 80, 4
+
+cuda = torch.cuda.is_available()
+device = torch.device('cuda' if cuda else 'cpu')
+
+inputs = torch.rand(BATCH_SIZE, SEQ_LENGTH, DIM).to(device)
+input_lengths = torch.IntTensor([100, 50, 8])
+targets = torch.LongTensor([[2, 3, 3, 3, 3, 3, 2, 2, 1, 0],
+                            [2, 3, 3, 3, 3, 3, 2, 1, 2, 0],
+                            [2, 3, 3, 3, 3, 3, 2, 2, 0, 1]]).to(device)  # 1 means <eos_token>
+target_lengths = torch.IntTensor([10, 9, 8])
+
+model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM)
+predictions, logits = model(inputs, input_lengths, targets, target_lengths)
+```
+- Beam Search Decoding
 ```python
 import torch
 from speech_transformer import SpeechTransformer
@@ -20,10 +41,9 @@ device = torch.device('cuda' if cuda else 'cpu')
 
 inputs = torch.rand(BATCH_SIZE, SEQ_LENGTH, DIM).to(device)  # BxTxD
 input_lengths = torch.LongTensor([SEQ_LENGTH, SEQ_LENGTH - 10, SEQ_LENGTH - 20]).to(device)
-targets = torch.LongTensor([1, 2, 3, 4, 5]).to(device)
 
-model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM, extractor='vgg')
-output = model(inputs, input_lengths, targets, return_attns=False)
+model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM)
+predictions, _ = model(inputs, input_lengths)
 ```
 
 ## Troubleshoots and Contributing
````
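
A note on the decoding snippet above: `predictions` holds token ids, so turning it into text needs the dataset's vocabulary. Below is a minimal sketch that reuses the eos id of 1 from the training snippet; the `id_to_char` mapping and the pad id of 0 are illustrative assumptions, not part of the library.

```python
# Illustrative only: map predicted token ids back to text.
# `id_to_char` and `pad_id` are assumptions; use your dataset's vocabulary.
id_to_char = {0: '<pad>', 1: '<eos>', 2: 'a', 3: 'b'}
pad_id, eos_id = 0, 1

def ids_to_text(prediction_ids):
    chars = []
    for token_id in prediction_ids.tolist():
        if token_id == eos_id:  # stop at end-of-sentence
            break
        if token_id != pad_id:
            chars.append(id_to_char[token_id])
    return ''.join(chars)

sentences = [ids_to_text(row) for row in predictions]
```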

speech_transformer/beam_decoder.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
```python
import torch
import torch.nn as nn
from torch import Tensor

from speech_transformer.decoder import SpeechTransformerDecoder


class BeamTransformerDecoder(nn.Module):
    """Beam-search wrapper around a trained SpeechTransformerDecoder."""

    def __init__(self, decoder: SpeechTransformerDecoder, batch_size: int, beam_size: int = 3) -> None:
        super(BeamTransformerDecoder, self).__init__()
        self.decoder = decoder
        self.beam_size = beam_size
        self.sos_id = decoder.sos_id
        self.pad_id = decoder.pad_id
        self.eos_id = decoder.eos_id
        self.ongoing_beams = None
        self.cumulative_ps = None
        self.finished = [[] for _ in range(batch_size)]
        self.finished_ps = [[] for _ in range(batch_size)]
        self.forward_step = decoder.forward_step
        self.use_cuda = torch.cuda.is_available()

    def _inflate(self, tensor: Tensor, n_repeat: int, dim: int) -> Tensor:
        # repeat `tensor` n_repeat times along `dim`, e.g. (B, T, D) -> (B * n_repeat, T, D) for dim=0
        repeat_dims = [1] * len(tensor.size())
        repeat_dims[dim] *= n_repeat

        return tensor.repeat(*repeat_dims)

    def _get_successor(
            self,
            current_ps: Tensor,
            current_vs: Tensor,
            finished_ids: tuple,
            num_successor: int,
            eos_count: int,
            k: int,
    ) -> int:
        # a beam just ended with <eos>; promote the next-best candidate so the
        # beam width stays at k, recursing if the successor also ends with <eos>
        finished_batch_idx, finished_idx = finished_ids

        successor_ids = current_ps.topk(k + num_successor)[1]
        successor_idx = successor_ids[finished_batch_idx, -1]

        successor_p = current_ps[finished_batch_idx, successor_idx]
        successor_v = current_vs[finished_batch_idx, successor_idx]

        prev_status_idx = (successor_idx // k)
        prev_status = self.ongoing_beams[finished_batch_idx, prev_status_idx]
        prev_status = prev_status.view(-1)[:-1]

        successor = torch.cat([prev_status, successor_v.view(1)])

        if int(successor_v) == self.eos_id:
            self.finished[finished_batch_idx].append(successor)
            self.finished_ps[finished_batch_idx].append(successor_p)
            eos_count = self._get_successor(
                current_ps=current_ps,
                current_vs=current_vs,
                finished_ids=finished_ids,
                num_successor=num_successor + eos_count,
                eos_count=eos_count + 1,
                k=k,
            )
        else:
            self.ongoing_beams[finished_batch_idx, finished_idx] = successor
            self.cumulative_ps[finished_batch_idx, finished_idx] = successor_p

        return eos_count

    def _get_hypothesis(self):
        predictions = list()

        for batch_idx, batch in enumerate(self.finished):
            # if no sentence has finished, fall back to the ongoing sentence with the highest probability
            if len(batch) == 0:
                prob_batch = self.cumulative_ps[batch_idx]
                top_beam_idx = int(prob_batch.topk(1)[1])
                predictions.append(self.ongoing_beams[batch_idx, top_beam_idx])

            # otherwise, take the finished sentence with the highest probability
            else:
                top_beam_idx = int(torch.FloatTensor(self.finished_ps[batch_idx]).topk(1)[1])
                predictions.append(self.finished[batch_idx][top_beam_idx])

        predictions = self._fill_sequence(predictions)
        return predictions

    def _is_all_finished(self, k: int) -> bool:
        for done in self.finished:
            if len(done) < k:
                return False

        return True

    def _fill_sequence(self, y_hats: list) -> Tensor:
        # pad variable-length hypotheses into one (batch, max_length) tensor
        batch_size = len(y_hats)
        max_length = -1

        for y_hat in y_hats:
            if len(y_hat) > max_length:
                max_length = len(y_hat)

        matched = torch.zeros((batch_size, max_length), dtype=torch.long)

        for batch_idx, y_hat in enumerate(y_hats):
            matched[batch_idx, :len(y_hat)] = y_hat
            matched[batch_idx, len(y_hat):] = int(self.pad_id)

        return matched

    def forward(self, encoder_outputs: Tensor, encoder_output_lengths: Tensor) -> Tensor:
        batch_size = encoder_outputs.size(0)

        decoder_inputs = torch.IntTensor(batch_size, self.decoder.max_length).fill_(self.sos_id).long()
        decoder_input_lengths = torch.IntTensor(batch_size).fill_(1)

        # first decoding step: seed the beams from the top-k tokens after <sos>
        outputs = self.forward_step(
            decoder_inputs=decoder_inputs[:, :1],
            decoder_input_lengths=decoder_input_lengths,
            encoder_outputs=encoder_outputs,
            encoder_output_lengths=encoder_output_lengths,
            positional_encoding_length=1,
        )
        step_outputs = self.decoder.fc(outputs).log_softmax(dim=-1)
        self.cumulative_ps, self.ongoing_beams = step_outputs.topk(self.beam_size)

        self.ongoing_beams = self.ongoing_beams.view(batch_size * self.beam_size, 1)
        self.cumulative_ps = self.cumulative_ps.view(batch_size * self.beam_size, 1)

        # .long() so the <sos> column matches the dtype of the int64 topk indices
        decoder_inputs = torch.IntTensor(batch_size * self.beam_size, 1).fill_(self.sos_id).long()
        decoder_inputs = torch.cat((decoder_inputs, self.ongoing_beams), dim=-1)  # bsz * beam x 2

        # inflate encoder outputs so each beam attends to its own copy
        encoder_dim = encoder_outputs.size(2)
        encoder_outputs = self._inflate(encoder_outputs, self.beam_size, dim=0)
        encoder_outputs = encoder_outputs.view(self.beam_size, batch_size, -1, encoder_dim)
        encoder_outputs = encoder_outputs.transpose(0, 1)
        encoder_outputs = encoder_outputs.reshape(batch_size * self.beam_size, -1, encoder_dim)

        encoder_output_lengths = encoder_output_lengths.unsqueeze(1).repeat(1, self.beam_size).view(-1)

        for di in range(2, self.decoder.max_length):
            if self._is_all_finished(self.beam_size):
                break

            decoder_input_lengths = torch.LongTensor(batch_size * self.beam_size).fill_(di)

            step_outputs = self.forward_step(
                decoder_inputs=decoder_inputs[:, :di],
                decoder_input_lengths=decoder_input_lengths,
                encoder_outputs=encoder_outputs,
                encoder_output_lengths=encoder_output_lengths,
                positional_encoding_length=di,
            )
            step_outputs = self.decoder.fc(step_outputs).log_softmax(dim=-1)

            # the last dim is num_classes (the original hard-coded a vocabulary size of 10 here)
            step_outputs = step_outputs.view(batch_size, self.beam_size, -1, step_outputs.size(-1))
            current_ps, current_vs = step_outputs.topk(self.beam_size)

            # TODO: Check transformer's beam search
            current_ps = current_ps[:, :, -1, :]
            current_vs = current_vs[:, :, -1, :]

            self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size)
            self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1)

            # add cumulative scores, then flatten the k x k candidates per batch element
            current_ps = (current_ps.permute(0, 2, 1) + self.cumulative_ps.unsqueeze(1)).permute(0, 2, 1)
            current_ps = current_ps.view(batch_size, self.beam_size ** 2)
            current_vs = current_vs.contiguous().view(batch_size, self.beam_size ** 2)

            topk_current_ps, topk_status_ids = current_ps.topk(self.beam_size)
            prev_status_ids = (topk_status_ids // self.beam_size)

            topk_current_vs = torch.zeros((batch_size, self.beam_size), dtype=torch.long)
            prev_status = torch.zeros(self.ongoing_beams.size(), dtype=torch.long)

            for batch_idx, batch in enumerate(topk_status_ids):
                for idx, topk_status_idx in enumerate(batch):
                    topk_current_vs[batch_idx, idx] = current_vs[batch_idx, topk_status_idx]
                    prev_status[batch_idx, idx] = self.ongoing_beams[batch_idx, prev_status_ids[batch_idx, idx]]

            self.ongoing_beams = torch.cat([prev_status, topk_current_vs.unsqueeze(2)], dim=2)
            self.cumulative_ps = topk_current_ps

            if torch.any(topk_current_vs == self.eos_id):
                finished_ids = torch.where(topk_current_vs == self.eos_id)
                num_successors = [1] * batch_size

                for (batch_idx, idx) in zip(*finished_ids):
                    self.finished[batch_idx].append(self.ongoing_beams[batch_idx, idx])
                    self.finished_ps[batch_idx].append(self.cumulative_ps[batch_idx, idx])

                    if self.beam_size != 1:
                        eos_count = self._get_successor(
                            current_ps=current_ps,
                            current_vs=current_vs,
                            finished_ids=(batch_idx, idx),
                            num_successor=num_successors[batch_idx],
                            eos_count=1,
                            k=self.beam_size,
                        )
                        num_successors[batch_idx] += eos_count

            ongoing_beams = self.ongoing_beams.clone().view(batch_size * self.beam_size, -1)
            decoder_inputs = torch.cat((decoder_inputs, ongoing_beams[:, :-1]), dim=-1)

        return self._get_hypothesis()
```
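
For orientation, here is a hypothetical sketch of wiring this decoder up for inference. It is an assumption, not part of this commit: the `model.encoder` and `model.decoder` attributes and the encoder's `(outputs, output_lengths)` return value are guesses about the surrounding library; check `speech_transformer/model.py` for the real interfaces.

```python
# Hypothetical usage sketch (assumptions noted above, not from this commit)
import torch
from speech_transformer import SpeechTransformer
from speech_transformer.beam_decoder import BeamTransformerDecoder

BATCH_SIZE, SEQ_LENGTH, DIM, NUM_CLASSES = 3, 12345, 80, 4

model = SpeechTransformer(num_classes=NUM_CLASSES, d_model=512, num_heads=8, input_dim=DIM).eval()

inputs = torch.rand(BATCH_SIZE, SEQ_LENGTH, DIM)
input_lengths = torch.LongTensor([SEQ_LENGTH, SEQ_LENGTH - 10, SEQ_LENGTH - 20])

with torch.no_grad():
    # assumed encoder API: returns (outputs, output_lengths)
    encoder_outputs, encoder_output_lengths = model.encoder(inputs, input_lengths)
    beam_decoder = BeamTransformerDecoder(model.decoder, batch_size=BATCH_SIZE, beam_size=3)
    predictions = beam_decoder(encoder_outputs, encoder_output_lengths)  # (batch, hypothesis_length)
```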
