Skip to content

Commit

Permalink
TF generate refactor - Greedy Search (huggingface#15562)
Browse files Browse the repository at this point in the history
* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvain's suggestions

* Update tests/test_generation_tf_logits_process.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
  • Loading branch information
patrickvonplaten and Rocketknight1 authored Feb 15, 2022
1 parent a3dbbc3 commit 2e12b90
Show file tree
Hide file tree
Showing 56 changed files with 1,475 additions and 206 deletions.
18 changes: 18 additions & 0 deletions docs/source/internal/generation_utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ generation.
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__

[[autodoc]] TFLogitsProcessor
- __call__

[[autodoc]] TFLogitsProcessorList
- __call__

[[autodoc]] TFMinLengthLogitsProcessor
- __call__

[[autodoc]] TFNoBadWordsLogitsProcessor
- __call__

[[autodoc]] TFNoRepeatNGramLogitsProcessor
- __call__

[[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__

[[autodoc]] FlaxLogitsProcessor
- __call__

Expand Down
17 changes: 17 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,14 @@
_import_structure["activations_tf"] = []
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
_import_structure["generation_tf_logits_process"] = [
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
_import_structure["modeling_tf_outputs"] = []
Expand Down Expand Up @@ -2046,6 +2054,7 @@
]
)
_import_structure["optimization_tf"] = ["AdamWeightDecay", "GradientAccumulator", "WarmUp", "create_optimizer"]
_import_structure["tf_utils"] = []
_import_structure["trainer_tf"] = ["TFTrainer"]

else:
Expand Down Expand Up @@ -3572,6 +3581,14 @@

# Benchmarks
from .benchmark.benchmark_tf import TensorFlowBenchmark
from .generation_tf_logits_process import (
TFLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
)
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
from .modeling_tf_layoutlm import (
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/generation_flax_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import inspect
from abc import ABC

import jax
import jax.lax as lax
Expand Down Expand Up @@ -48,7 +47,7 @@
"""


class FlaxLogitsProcessor(ABC):
class FlaxLogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand All @@ -59,7 +58,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
)


class FlaxLogitsWarper(ABC):
class FlaxLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/generation_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import inspect
import math
from abc import ABC
from typing import Callable, Iterable, List, Optional

import numpy as np
Expand Down Expand Up @@ -49,7 +48,7 @@
"""


class LogitsProcessor(ABC):
class LogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand All @@ -60,7 +59,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
)


class LogitsWarper(ABC):
class LogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand Down
Loading

0 comments on commit 2e12b90

Please sign in to comment.