-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add docs of CLI scripts * Small fix of docstring * Add a document of metric customization * Add documents * Use directives * Add documents of timer * Highlight which lines are changed
- Loading branch information
Showing
17 changed files
with
636 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
Python interface | ||
================ | ||
|
||
mbrs is implemented in Python and PyTorch. | ||
|
||
.. seealso:: | ||
|
||
:doc:`References of Python API <./source/mbrs>` | ||
Detailed documentation of Python API. | ||
|
||
Examples | ||
-------- | ||
This is a Python API example of COMET-MBR. | ||
|
||
.. code:: python | ||
from mbrs.metrics import MetricCOMET | ||
from mbrs.decoders import DecoderMBR | ||
SOURCE = "ありがとう" | ||
HYPOTHESES = ["Thanks", "Thank you", "Thank you so much", "Thank you.", "thank you"] | ||
# Setup COMET. | ||
metric_cfg = MetricCOMET.Config( | ||
model="Unbabel/wmt22-comet-da", | ||
batch_size=64, | ||
fp16=True, | ||
) | ||
metric = MetricCOMET(metric_cfg) | ||
# Setup MBR decoding. | ||
decoder_cfg = DecoderMBR.Config() | ||
decoder = DecoderMBR(decoder_cfg, metric) | ||
# Decode by COMET-MBR. | ||
# This example regards the hypotheses themselves as the pseudo-references. | ||
# Args: (hypotheses, pseudo-references, source) | ||
output = decoder.decode(HYPOTHESES, HYPOTHESES, source=SOURCE, nbest=1) | ||
print(f"Selected index: {output.idx}") | ||
print(f"Output sentence: {output.sentence}") | ||
print(f"Expected score: {output.score}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
Command-line interface | ||
====================== | ||
|
||
mbrs provides useful command-line interface (CLI) scripts. | ||
|
||
.. seealso:: | ||
|
||
:doc:`Manual of CLI options <./cli_help>` | ||
Detailed documentation of CLI options. | ||
|
||
Overview | ||
-------- | ||
|
||
Command-line interface can run MBR decoding from command-line. | ||
Before running MBR decoding, you can generate hypothesis sentences with :code:`mbrs-generate`: | ||
|
||
.. code:: bash | ||
mbrs-generate \ | ||
sources.txt \ | ||
--output hypotheses.txt \ | ||
--lang_pair en-de \ | ||
--model facebook/m2m100_418M \ | ||
--num_candidates 1024 \ | ||
--sampling eps --epsilon 0.02 \ | ||
--batch_size 8 --sampling_size 8 --fp16 \ | ||
--report_format rounded_outline | ||
Beam search can also be used by replacing :code:`--sampling eps --epsilon 0.02` with :code:`--beam_size 10`. | ||
|
||
Next, MBR decoding and other decoding methods can be executed with :code:`mbrs-decode`. | ||
This example regards the hypothesis set as the pseudo-reference set. | ||
|
||
.. code:: bash | ||
mbrs-decode \ | ||
hypotheses.txt \ | ||
--num_candidates 1024 \ | ||
--nbest 1 \ | ||
--source sources.txt \ | ||
--references hypotheses.txt \ | ||
--output translations.txt \ | ||
--report report.txt --report_format rounded_outline \ | ||
--decoder mbr \ | ||
--metric comet \ | ||
--metric.model Unbabel/wmt22-comet-da \ | ||
--metric.batch_size 64 --metric.fp16 true | ||
Finally, you can evaluate the score with :code:`mbrs-score`: | ||
|
||
.. code:: bash | ||
mbrs-score \ | ||
hypotheses.txt \ | ||
--sources sources.txt \ | ||
--references hypotheses.txt \ | ||
--format json \ | ||
--metric bleurt \ | ||
--metric.batch_size 64 --metric.fp16 true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
Manual of CLI options | ||
===================== | ||
|
||
mbrs-generate | ||
~~~~~~~~~~~~~ | ||
|
||
.. argparse:: | ||
:module: mbrs.cli.generate | ||
:func: get_argparser | ||
:prog: mbrs-generate | ||
|
||
mbrs-decode | ||
~~~~~~~~~~~ | ||
|
||
.. argparse:: | ||
:module: mbrs.cli.decode | ||
:func: get_argparser | ||
:prog: mbrs-decode | ||
|
||
mbrs-score | ||
~~~~~~~~~~ | ||
|
||
.. argparse:: | ||
:module: mbrs.cli.score | ||
:func: get_argparser | ||
:prog: mbrs-score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
How to define a new decoder | ||
=========================== | ||
|
||
Examples | ||
~~~~~~~~ | ||
|
||
This tutorial explains how to define a new decoder. | ||
The below example implements the naive MBR decoding and extends the output object to return other features. | ||
|
||
1. Inherit an abstract class defined in :code:`mbrs.decoders.base`. | ||
|
||
- :code:`DecoderReferenceBased` is mainly used for MBR decoding that returns the N most probable hypotheses using sets of hypotheses and pseudo-references. | ||
|
||
.. code-block:: python | ||
:emphasize-lines: 1- | ||
from mbrs.decoders.base import DecoderReferenceBased | ||
class DecoderMBRWithAllScores(DecoderReferenceBased): | ||
"""Naive MBR decoder class.""" | ||
2. Define the configuration dataclass if you need to add options. | ||
|
||
- Configuration dataclass :code:`DecoderMBRWithAllScores.Config` should inherit that of the parent class for consistency. | ||
|
||
.. code-block:: python | ||
:emphasize-lines: 1,9- | ||
from dataclasses import dataclass | ||
from mbrs.decoders.base import DecoderReferenceBased | ||
class DecoderMBRWithAllScores(DecoderReferenceBased): | ||
"""Naive MBR decoder class.""" | ||
@dataclass | ||
class Config(DecoderMBRWithAllScores.Config): | ||
"""Naive MBR decoder configuration.""" | ||
sort_scores: bool = False | ||
3. Child classes of :code:`DecoderReferenceBased` requires to implement the :code:`decode()` method. | ||
|
||
.. code-block:: python | ||
:emphasize-lines: 2,16- | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
from mbrs.decoders.base import DecoderReferenceBased | ||
class DecoderMBRWithAllScores(DecoderReferenceBased): | ||
"""Naive MBR decoder class.""" | ||
@dataclass | ||
class Config(DecoderMBRWithAllScores.Config): | ||
"""Naive MBR decoder configuration.""" | ||
sort_scores: bool = False | ||
def decode( | ||
self, | ||
hypotheses: list[str], | ||
references: list[str], | ||
source: Optional[str] = None, | ||
nbest: int = 1, | ||
reference_lprobs: Optional[Tensor] = None, | ||
) -> DecoderMBRWithAllScores.Output: | ||
expected_scores = self.metric.expected_scores( | ||
hypotheses, references, source, reference_lprobs=reference_lprobs | ||
) | ||
topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest) | ||
return self.Output( | ||
idx=topk_indices, | ||
sentence=[hypotheses[idx] for idx in topk_indices], | ||
score=topk_scores, | ||
) | ||
4. In this example, we extend the output dataclass to include all expected scores. | ||
|
||
- :code:`DecoderMBRWithAllScores.Output` needs to inherit the parent output dataclass. | ||
|
||
.. code-block:: python | ||
:emphasize-lines: 4,16-18,33-36,42 | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
from torch import Tensor | ||
from mbrs.decoders.base import DecoderReferenceBased | ||
class DecoderMBRWithAllScores(DecoderReferenceBased): | ||
"""Naive MBR decoder class.""" | ||
@dataclass | ||
class Config(DecoderMBRWithAllScores.Config): | ||
sort_scores: bool = False | ||
@dataclass | ||
class Output(DecoderReferenceBased.Output): | ||
all_scores: Optional[Tensor] = None | ||
def decode( | ||
self, | ||
hypotheses: list[str], | ||
references: list[str], | ||
source: Optional[str] = None, | ||
nbest: int = 1, | ||
reference_lprobs: Optional[Tensor] = None, | ||
) -> DecoderMBRWithAllScores.Output: | ||
expected_scores = self.metric.expected_scores( | ||
hypotheses, references, source, reference_lprobs=reference_lprobs | ||
) | ||
topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest) | ||
if self.cfg.sort_scores: | ||
all_scores = expected_scores.sort(dim=-1, descending=self.metric.HIGH_IS_BETTER) | ||
else: | ||
all_scores = expected_scores | ||
return self.Output( | ||
idx=topk_indices, | ||
sentence=[hypotheses[idx] for idx in topk_indices], | ||
score=topk_scores, | ||
all_scores=all_scores, | ||
) | ||
5. Finally, register the class to be called from CLI. | ||
|
||
- Just add :code:`@register("mbr_with_all_scores")` to the class definition. | ||
|
||
.. code-block:: python | ||
:emphasize-lines: 9 | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
from torch import Tensor | ||
from mbrs.decoders.base import DecoderReferenceBased, register | ||
@register("mbr_with_all_scores") | ||
class DecoderMBRWithAllScores(DecoderReferenceBased): | ||
"""Naive MBR decoder class.""" | ||
@dataclass | ||
class Config(DecoderMBRWithAllScores.Config): | ||
sort_scores: bool = False | ||
@dataclass | ||
class Output(DecoderReferenceBased.Output): | ||
all_scores: Optional[Tensor] = None | ||
def decode( | ||
self, | ||
hypotheses: list[str], | ||
references: list[str], | ||
source: Optional[str] = None, | ||
nbest: int = 1, | ||
reference_lprobs: Optional[Tensor] = None, | ||
) -> DecoderMBRWithAllScores.Output: | ||
expected_scores = self.metric.expected_scores( | ||
hypotheses, references, source, reference_lprobs=reference_lprobs | ||
) | ||
topk_scores, topk_indices = self.metric.topk(expected_scores, k=nbest) | ||
if self.cfg.sort_scores: | ||
all_scores = expected_scores.sort(dim=-1, descending=self.metric.HIGH_IS_BETTER) | ||
else: | ||
all_scores = expected_scores | ||
return self.Output( | ||
idx=topk_indices, | ||
sentence=[hypotheses[idx] for idx in topk_indices], | ||
score=topk_scores, | ||
all_scores=all_scores, | ||
) | ||
.. note:: | ||
|
||
All methods should have the same types for both inputs and outputs as the base class. |
Oops, something went wrong.