Skip to content

Commit

Permalink
Add GNMT (dmlc#64)
Browse files Browse the repository at this point in the history
* Add gnmt

* fix lint

* fix

* fix lint

* fix lint

* move to scripts

* fix

* fix

* fix

* specify beam_size
  • Loading branch information
sxjscience authored and szha committed Apr 22, 2018
1 parent b19f820 commit 043e300
Show file tree
Hide file tree
Showing 11 changed files with 1,587 additions and 24 deletions.
4 changes: 2 additions & 2 deletions gluonnlp/model/attention_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __call__(self, query, key, value=None, mask=None): # pylint: disable=argume
att_weights : Symbol or NDArray
Attention weights. Shape (batch_size, query_length, memory_length)
"""
return super(AttentionCell, self).__call__(query, key, value, mask)
return self.forward(query, key, value, mask)

def forward(self, query, key, value=None, mask=None): # pylint: disable=arguments-differ
if value is None:
Expand Down Expand Up @@ -242,7 +242,7 @@ def __call__(self, query, key, value=None, mask=None):
Attention weights of multiple heads.
Shape (batch_size, num_heads, query_length, memory_length)
"""
return super(MultiHeadAttentionCell, self).__call__(query, key, value, mask)
return self.forward(query, key, value, mask)

def _compute_weight(self, F, query, key, mask=None):
query = self.proj_query(query) # Shape (batch_size, query_length, query_units)
Expand Down
22 changes: 22 additions & 0 deletions scripts/nmt/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constants used in the NMT examples."""
import os

# Absolute, symlink-resolved path of the `cached` directory located next to this file.
CACHE_PATH = os.path.realpath(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cached'))
20 changes: 11 additions & 9 deletions scripts/nmt/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ def _ngrams(segment, n):
return ngram_counts


def compute_bleu(reference_corpus, translation_corpus, max_n=4, smooth=False, lower_case=False):
def compute_bleu(reference_corpus_list, translation_corpus,
max_n=4, smooth=False, lower_case=False):
"""Compute bleu score of translation against references.
Parameters
----------
reference_corpus: list(list(list(str)))
List of lists of references for each translation.
reference_corpus_list: list of list(list(str))
List of references for each translation.
translation_corpus: list(list(str))
List of translations to score.
Translations to score.
max_n: int, default 4
Maximum n-gram order to use when computing BLEU score.
smooth: bool, default False
Expand All @@ -65,13 +66,14 @@ def compute_bleu(reference_corpus, translation_corpus, max_n=4, smooth=False, lo
5-Tuple with the BLEU score, n-gram precisions, brevity penalty,
reference length, and translation length
"""
precision_numerators = Counter()
precision_denominators = Counter()
precision_numerators = [0 for _ in range(max_n)]
precision_denominators = [0 for _ in range(max_n)]
ref_length, trans_length = 0, 0
assert len(reference_corpus) == len(translation_corpus), \
'The number of translations and their references do not match'
for references in reference_corpus_list:
assert len(references) == len(translation_corpus), \
'The number of translations and their references do not match'

for references, translation in zip(reference_corpus, translation_corpus):
for references, translation in zip(zip(*reference_corpus_list), translation_corpus):
if lower_case:
references = [list(map(str.lower, reference)) for reference in references]
translation = list(map(str.lower, translation))
Expand Down
Loading

0 comments on commit 043e300

Please sign in to comment.