Skip to content

Commit

Permalink
update doc about transform stats (OpenNMT#2095)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zenglinxiao authored Sep 22, 2021
1 parent 546c244 commit 71026eb
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions docs/source/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,6 @@ data:
tgt_prefix: __some_tgt_prefix__
```



### Tokenization

Common options for the tokenization transforms are the following:
Expand All @@ -278,7 +276,7 @@ Common options for the tokenization transforms are the following:

Transform name: `onmt_tokenize`

Class: `onmt.transforms.misc.ONMTTokenizerTransform`
Class: `onmt.transforms.tokenize.ONMTTokenizerTransform`

Additional options are available:
- `src_subword_type`: type of subword model for source side (from `["none", "sentencepiece", "bpe"]`);
Expand All @@ -290,15 +288,15 @@ Additional options are available:

Transform name: `sentencepiece`

Class: `onmt.transforms.misc.SentencePieceTransform`
Class: `onmt.transforms.tokenize.SentencePieceTransform`

The `src_subword_model` and `tgt_subword_model` should be valid sentencepiece models.

#### BPE ([subword-nmt](https://github.com/rsennrich/subword-nmt))

Transform name: `bpe`

Class: `onmt.transforms.misc.BPETransform`
Class: `onmt.transforms.tokenize.BPETransform`

The `src_subword_model` and `tgt_subword_model` should be valid BPE models.

Expand All @@ -323,7 +321,7 @@ These different types of noise can be controlled with the following options:

Transform name: `switchout`

Class: `onmt.transforms.misc.SwitchOutTransform`
Class: `onmt.transforms.sampling.SwitchOutTransform`

Options:

Expand All @@ -333,7 +331,7 @@ Options:

Transform name: `tokendrop`

Class: `onmt.transforms.misc.TokenDropTransform`
Class: `onmt.transforms.sampling.TokenDropTransform`

Options:

Expand All @@ -343,7 +341,7 @@ Options:

Transform name: `tokenmask`

Class: `onmt.transforms.misc.TokenMaskTransform`
Class: `onmt.transforms.sampling.TokenMaskTransform`

Options:

Expand All @@ -360,11 +358,6 @@ You can for instance have a look at the `FilterTooLongTransform` class as a temp
class FilterTooLongTransform(Transform):
"""Filter out sentence that are too long."""
def __init__(self, opts):
super().__init__(opts)
self.src_seq_length = opts.src_seq_length
self.tgt_seq_length = opts.tgt_seq_length
@classmethod
def add_options(cls, parser):
"""Avalilable options relate to this Transform."""
Expand All @@ -374,12 +367,16 @@ class FilterTooLongTransform(Transform):
group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200,
help="Maximum target sequence length.")
def _parse_opts(self):
self.src_seq_length = self.opts.src_seq_length
self.tgt_seq_length = self.opts.tgt_seq_length
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if too long else return as is."""
if (len(example['src']) > self.src_seq_length or
len(example['tgt']) > self.tgt_seq_length):
if stats is not None:
stats.filter_too_long()
stats.update(FilterTooLongStats())
return None
else:
return example
Expand All @@ -394,11 +391,33 @@ class FilterTooLongTransform(Transform):

Methods:
- `add_options` allows to add custom options that would be necessary for the transform configuration;
- `_parse_opts` allows to parse options introduced in `add_options` when initialize;
- `apply` is where the transform happens;
- `_repr_args` is for clean logging purposes.

As you can see, there is the `@register_transform` wrapper before the class definition. This will allow for the class to be automatically detected (if put in the proper `transforms` folder) and usable in your training configurations through its `name` argument.

You could also collect statistics for your custom transform by creating a class inheriting `ObservableStats`:

```python
class FilterTooLongStats(ObservableStats):
"""Runing statistics for FilterTooLongTransform."""
__slots__ = ["filtered"]
def __init__(self):
self.filtered = 1
def update(self, other: "FilterTooLongStats"):
self.filtered += other.filtered
```

NOTE:
- Add elements to keep track in the `__init__` and also `__slot__` to make it lightweight;
- Supply update logic in `update` method;
- (Optional) override `__str__` to change default log message format;
- Instantiate and passing the statistic object in the `apply` method of the corresponding transform class;
- statistics will be gathered per corpus per worker, but only first worker will report for its shard by default.

The `example` argument of `apply` is a `dict` of the form:
```
{
Expand Down

0 comments on commit 71026eb

Please sign in to comment.