Skip to content

Commit

Permalink
Have EnsembleDecoder set attentional property. (OpenNMT#1381)
Browse files Browse the repository at this point in the history
  • Loading branch information
flauted authored and vince62s committed Apr 3, 2019
1 parent f09cc8c commit 9809c4c
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions onmt/decoders/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn as nn

from onmt.encoders.encoder import EncoderBase
from onmt.decoders.decoder import DecoderBase
from onmt.models import NMTModel
import onmt.model_builder

Expand Down Expand Up @@ -42,11 +43,13 @@ def forward(self, src, lengths=None):
return enc_hidden, memory_bank, lengths


class EnsembleDecoder(nn.Module):
class EnsembleDecoder(DecoderBase):
"""Dummy Decoder that delegates to individual real Decoders."""
def __init__(self, model_decoders):
super(EnsembleDecoder, self).__init__()
self.model_decoders = nn.ModuleList(model_decoders)
model_decoders = nn.ModuleList(model_decoders)
attentional = any([dec.attentional for dec in model_decoders])
super(EnsembleDecoder, self).__init__(attentional)
self.model_decoders = model_decoders

def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
Expand All @@ -65,7 +68,8 @@ def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
def combine_attns(self, attns):
result = {}
for key in attns[0].keys():
result[key] = torch.stack([attn[key] for attn in attns]).mean(0)
result[key] = torch.stack(
[attn[key] for attn in attns if attn[key] is not None]).mean(0)
return result

def init_state(self, src, memory_bank, enc_hidden):
Expand Down

0 comments on commit 9809c4c

Please sign in to comment.