Commit c2d28f3

Added correct dotprod option, renamed old dotprod to general
bpopeters committed Jul 23, 2017
1 parent 3c909cc commit c2d28f3
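
The rename follows the terminology of Luong et al. (2015): "dot" attention scores each source state directly against the decoder state, while "general" first passes the decoder state through a learned weight matrix. The old "dotprod" branch applied linear_in, so it was effectively the general form; this commit gives it that name and adds a true dot-product option. A minimal sketch of the two score functions, with toy tensor sizes assumed for illustration (not code from the repository):

import torch
import torch.nn as nn

batch, src_len, dim = 2, 5, 8                # assumed toy sizes
h_t = torch.randn(batch, dim)                # decoder hidden state
h_s = torch.randn(batch, src_len, dim)       # source (encoder) hidden states

# "dotprod": plain dot product between h_t and every source state
dot_scores = torch.bmm(h_s, h_t.unsqueeze(2)).squeeze(2)        # batch x src_len

# "general": multiply h_t by a learned matrix W_a before the dot product
linear_in = nn.Linear(dim, dim, bias=False)                     # W_a
gen_scores = torch.bmm(h_s, linear_in(h_t).unsqueeze(2)).squeeze(2)

This also explains why the new constructor only builds linear_in for the "general" type: the true dot product needs no parameters.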
Showing 1 changed file with 59 additions and 53 deletions.
onmt/modules/GlobalAttention.py (112 changes: 59 additions & 53 deletions)
@@ -39,36 +39,60 @@ def __init__(self, dim, coverage=False, attn_type="dotprod"):
 
         self.dim = dim
         self.attn_type = attn_type
-        assert (self.attn_type in ["dotprod", "mlp"]), (
+        assert (self.attn_type in ["dotprod", "general", "mlp"]), (
                 "Please select a valid attention type.")
 
-        if self.attn_type == "dotprod":
+        if self.attn_type == "general":
             self.linear_in = nn.Linear(dim, dim, bias=False)
-            self.linear_out = nn.Linear(dim*2, dim, bias=False)
-        elif self.attn_type == "mlp":
+        elif self.attn_type == 'mlp':
             self.linear_context = BottleLinear(dim, dim, bias=False)
             self.linear_query = nn.Linear(dim, dim, bias=True)
             self.mlp_tanh = nn.Tanh()
             self.v = BottleLinear(dim, 1, bias=False)
-            self.linear_out = nn.Linear(dim*2, dim, bias=True)
+        out_bias = self.attn_type == 'mlp'  # mlp wants it with bias
+        self.linear_out = nn.Linear(dim*2, dim, bias=out_bias)
 
         self.sm = nn.Softmax()
         self.tanh = nn.Tanh()
         self.mask = None
 
         if coverage:
             self.linear_cover = nn.Linear(1, dim, bias=False)
 
     def applyMask(self, mask):
         self.mask = mask
 
+    def score(self, h_t, h_s):
+        """
+        h_t (FloatTensor): batch x dim
+        h_s (FloatTensor): batch x src_len x dim
+        returns scores (FloatTensor): batch x src_len:
+            raw attention scores for each src index
+        """
+
+        if self.attn_type in ["general", "dotprod"]:
+            if self.attn_type == "general":
+                h_t = self.linear_in(h_t)
+            return torch.bmm(h_s, h_t.unsqueeze(2)).squeeze(2)
+        else:
+            # MLP
+            # batch x 1 x dim
+            wq = self.linear_query(h_t).unsqueeze(1)
+            # batch x src_len x dim
+            uh = self.linear_context(h_s.contiguous())
+            # batch x src_len x dim
+            wquh = uh + wq.expand_as(uh)
+            # batch x src_len x dim
+            wquh = self.tanh(wquh)
+            # batch x src_len
+            return self.v(wquh.contiguous()).squeeze(2)
+
     def forward(self, input, context, coverage=None):
         """
-        input (FloatTensor): batch x dim
-        context (FloatTensor): batch x sourceL x dim
-        coverage (FloatTensor): batch x sourceL
+        input (FloatTensor): batch x dim: decoder's rnn's output.
+        context (FloatTensor): batch x src_len x dim: src hidden states
+        coverage (FloatTensor): batch x src_len
         """
 
         # Check input sizes
         batch, sourceL, dim = context.size()
         batch_, dim_ = input.size()
@@ -84,54 +84,108 @@ def forward(self, input, context, coverage=None):
             beam_, batch_, sourceL_ = self.mask.size()
             aeq(batch, batch_*beam_)
             aeq(sourceL, sourceL_)
 
         if coverage is not None:
-            context += self.linear_cover(coverage.view(-1).unsqueeze(1)) \
-                           .view_as(context)
+            cover = coverage.view(-1).unsqueeze(1)
+            context += self.linear_cover(cover).view_as(context)
             context = self.tanh(context)
 
-        # Alignment/Attention Function
-        if self.attn_type == "dotprod":
-            # batch x dim x 1
-            targetT = self.linear_in(input).unsqueeze(2)
-            # batch x sourceL
-            attn = torch.bmm(context, targetT).squeeze(2)
-        elif self.attn_type == "mlp":
-            # batch x 1 x dim
-            wq = self.linear_query(input).unsqueeze(1)
-            # batch x sourceL x dim
-            uh = self.linear_context(context.contiguous())
-            # batch x sourceL x dim
-            wquh = uh + wq.expand_as(uh)
-            # batch x sourceL x dim
-            wquh = self.mlp_tanh(wquh)
-            # batch x sourceL
-            attn = self.v(wquh.contiguous()).squeeze(2)
-
-
+        # compute attention scores, as in Luong et al.
+        a_t = self.score(input, context)
 
         if self.mask is not None:
-            attn.data.masked_fill_(self.mask, -float('inf'))
-
-        # SoftMax
-        attn = self.sm(attn)
-
-        # Compute context weighted by attention.
-        # batch x 1 x sourceL
-        attn3 = attn.view(attn.size(0), 1, attn.size(1))
-        # batch x dim
-        weightedContext = torch.bmm(attn3, context).squeeze(1)
-
-        # Concatenate the input to context (Luong only)
-        weightedContext = torch.cat((weightedContext, input), 1)
-        weightedContext = self.linear_out(weightedContext)
-        if self.attn_type == "dotprod":
-            weightedContext = self.tanh(weightedContext)
+            a_t.data.masked_fill_(self.mask, -float('inf'))
+
+        # Softmax to normalize attention weights
+        align_vector = self.sm(a_t)
+
+        # the context vector c_t is the weighted average
+        # over all the source hidden states
+        c_t = torch.bmm(align_vector.unsqueeze(1), context).squeeze(1)
+
+        # concatenate
+        attn_h_t = self.linear_out(torch.cat([c_t, input], 1))
+        if self.attn_type in ["general", "dotprod"]:
+            attn_h_t = self.tanh(attn_h_t)
 
         # Check output sizes
-        batch_, sourceL_ = attn.size()
+        batch_, sourceL_ = align_vector.size()
         aeq(batch, batch_)
         aeq(sourceL, sourceL_)
-        batch_, dim_ = weightedContext.size()
+        batch_, dim_ = attn_h_t.size()
         aeq(batch, batch_)
         aeq(dim, dim_)
 
-        return weightedContext, attn
+        return attn_h_t, align_vector

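After this change, forward() follows the usual Luong-style pipeline: score every source position, normalize with a softmax, take the attention-weighted average of the source states as the context vector c_t, then combine c_t with the decoder state through linear_out (plus a tanh for the dot/general types). A rough self-contained sketch of that pipeline, shown for the dot-product score with assumed toy sizes (an illustration, not the repository's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, src_len, dim = 2, 5, 8
h_t = torch.randn(batch, dim)                   # decoder output ("input" above)
context = torch.randn(batch, src_len, dim)      # source hidden states

# 1. raw attention scores for every source position (dot-product variant)
a_t = torch.bmm(context, h_t.unsqueeze(2)).squeeze(2)            # batch x src_len

# 2. normalize the scores into an alignment distribution
align_vector = F.softmax(a_t, dim=-1)                            # batch x src_len

# 3. context vector: weighted average of the source states
c_t = torch.bmm(align_vector.unsqueeze(1), context).squeeze(1)   # batch x dim

# 4. combine with the decoder state and squash
linear_out = nn.Linear(dim * 2, dim, bias=False)
attn_h_t = torch.tanh(linear_out(torch.cat([c_t, h_t], 1)))      # batch x dim

When a mask is applied, masked positions are filled with -inf before the softmax so they receive zero attention weight.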