Flax t5 Encoder (huggingface#17784)
* first draft adding Flax-t5-encoder and Flax-mt5-encoder

* imports

* after make fixup

* flax t5 encoder test

* black on test

* make fix-copies

* clean

* all_model_classes -> tuple

* clean test

* is_encoder_decoder=False in t5-enc tester

* remove file docstring before FlaxT5Encoder

* black

* isort

* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* remove _get_encoder_module

* self.decoder_seq_length -> self.encoder_seq_length as t5-enc does not have decoder

* bugfix - self.module_class is class itself, not instance;

* docs for mt5 and t5

* call -> __call__ in t5 doc

* FlaxMT5EncoderModel to TYPE_HINT

* run doc-builder to allow changing the files

Co-authored-by: Suraj Patil <surajp815@gmail.com>
crystina-z and patil-suraj authored Jun 29, 2022
1 parent eb1493b commit 692e61e
Showing 9 changed files with 454 additions and 15 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/mt5.mdx
@@ -96,3 +96,7 @@ See [`T5TokenizerFast`] for all details.
## FlaxMT5ForConditionalGeneration

[[autodoc]] FlaxMT5ForConditionalGeneration

## FlaxMT5EncoderModel

[[autodoc]] FlaxMT5EncoderModel
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/t5.mdx
@@ -371,3 +371,8 @@ T5 is supported by several example scripts, both for pre-training and fine-tuning.
- __call__
- encode
- decode

## FlaxT5EncoderModel

[[autodoc]] FlaxT5EncoderModel
- __call__
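For context, the new encoder-only class documented above can be exercised as follows — a minimal usage sketch, assuming the public `t5-small` checkpoint (the checkpoint choice is illustrative, not part of this diff):

```python
# Minimal usage sketch for the new encoder-only class; checkpoint is illustrative.
from transformers import FlaxT5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5EncoderModel.from_pretrained("t5-small")

inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="np")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, d_model)
```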
10 changes: 6 additions & 4 deletions src/transformers/__init__.py
@@ -2704,7 +2704,7 @@
"FlaxMBartPreTrainedModel",
]
)
_import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
_import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
_import_structure["models.opt"].extend(
[
"FlaxOPTForCausalLM",
@@ -2743,7 +2743,9 @@
]
)
_import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.t5"].extend(
["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]
)
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
@@ -4974,7 +4976,7 @@
FlaxMBartModel,
FlaxMBartPreTrainedModel,
)
-from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
@@ -4997,7 +4999,7 @@
FlaxRoFormerPreTrainedModel,
)
from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
-from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
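Note that each new symbol is registered twice in this file: once in `_import_structure`, which drives lazy loading at runtime, and once under `TYPE_CHECKING`, so static analyzers can resolve the names. A simplified sketch of that pattern — not the library's exact code:

```python
# Simplified sketch of the lazy-import pattern used by transformers/__init__.py.
# Symbols listed in _import_structure are only imported when first accessed.
import sys
from typing import TYPE_CHECKING

_import_structure = {"models.t5": ["FlaxT5EncoderModel", "FlaxT5Model"]}

if TYPE_CHECKING:
    from .models.t5 import FlaxT5EncoderModel, FlaxT5Model  # real imports for type checkers
else:
    from .utils import _LazyModule

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```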
4 changes: 2 additions & 2 deletions src/transformers/models/mt5/__init__.py
@@ -67,7 +67,7 @@
except OptionalDependencyNotAvailable:
    pass
else:
-    _import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
+    _import_structure["modeling_flax_mt5"] = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]


if TYPE_CHECKING:
@@ -95,7 +95,7 @@
except OptionalDependencyNotAvailable:
    pass
else:
-    from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+    from .modeling_flax_mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model

else:
import sys
29 changes: 28 additions & 1 deletion src/transformers/models/mt5/modeling_flax_mt5.py
@@ -17,7 +17,7 @@
import numpy as np

from ...utils import logging
-from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
+from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model
from .configuration_mt5 import MT5Config


@@ -67,6 +67,33 @@ class FlaxMT5Model(FlaxT5Model):
    config_class = MT5Config


class FlaxMT5EncoderModel(FlaxT5EncoderModel):
    r"""
    This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation
    alongside usage examples.

    Examples:

    ```python
    >>> from transformers import FlaxMT5EncoderModel, T5Tokenizer

    >>> model = FlaxMT5EncoderModel.from_pretrained("google/mt5-small")
    >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
    >>> inputs = tokenizer(article, return_tensors="np")

    >>> outputs = model(input_ids=inputs["input_ids"])
    >>> hidden_states = outputs.last_hidden_state
    ```"""

    model_type = "mt5"
    config_class = MT5Config


class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
    r"""
    This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate
8 changes: 7 additions & 1 deletion src/transformers/models/t5/__init__.py
@@ -83,6 +83,7 @@
    pass
else:
    _import_structure["modeling_flax_t5"] = [
+        "FlaxT5EncoderModel",
        "FlaxT5ForConditionalGeneration",
        "FlaxT5Model",
        "FlaxT5PreTrainedModel",
@@ -143,7 +144,12 @@
except OptionalDependencyNotAvailable:
    pass
else:
-    from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+    from .modeling_flax_t5 import (
+        FlaxT5EncoderModel,
+        FlaxT5ForConditionalGeneration,
+        FlaxT5Model,
+        FlaxT5PreTrainedModel,
+    )


else:
96 changes: 90 additions & 6 deletions src/transformers/models/t5/modeling_flax_t5.py
@@ -929,18 +929,18 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        input_ids = jnp.zeros(input_shape, dtype="i4")

        attention_mask = jnp.ones_like(input_ids)
-        decoder_input_ids = jnp.ones_like(input_ids)
-        decoder_attention_mask = jnp.ones_like(input_ids)
+        args = [input_ids, attention_mask]
+        if self.module_class not in [FlaxT5EncoderModule]:
+            decoder_input_ids = jnp.ones_like(input_ids)
+            decoder_attention_mask = jnp.ones_like(input_ids)
+            args.extend([decoder_input_ids, decoder_attention_mask])

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
-            input_ids,
-            attention_mask,
-            decoder_input_ids,
-            decoder_attention_mask,
+            *args,
        )["params"]

        if params is not None:
if params is not None:
@@ -1357,6 +1357,90 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class FlaxT5EncoderModule(nn.Module):
    config: T5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
        )

        encoder_config = copy.deepcopy(self.config)
        encoder_config.is_decoder = False
        encoder_config.is_encoder_decoder = False
        encoder_config.causal = False
        self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        deterministic: bool = True,
    ):
        # Encode if needed (training, first prediction pass)
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        return encoder_outputs
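The module/model split follows the standard Flax pattern: the `nn.Module` above holds the computation, while the `FlaxT5PreTrainedModel` wrapper below owns parameters and configuration. A hedged sketch of driving the bare module directly with `init`/`apply` — toy config values, illustrative only:

```python
# Hedged sketch: exercising the bare nn.Module the way the wrapper class does
# internally -- init() to create params, then apply() for the forward pass.
import jax
import jax.numpy as jnp
from transformers import T5Config
from transformers.models.t5.modeling_flax_t5 import FlaxT5EncoderModule

config = T5Config(vocab_size=32128, d_model=64, d_kv=16, d_ff=128, num_layers=2, num_heads=4)
module = FlaxT5EncoderModule(config=config)

input_ids = jnp.zeros((1, 8), dtype="i4")
attention_mask = jnp.ones_like(input_ids)

params = module.init(jax.random.PRNGKey(0), input_ids, attention_mask)["params"]
outputs = module.apply({"params": params}, input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)  # (1, 8, 64)
```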


class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
    module_class = FlaxT5EncoderModule

    @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
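The `train`/`dropout_rng` handling above follows the usual Flax convention: dropout is only active when `deterministic` is `False` and a PRNG key is supplied. A brief sketch of a training-mode call (checkpoint name illustrative):

```python
# Hedged sketch: a training-mode forward pass. train=True flips deterministic
# to False inside the module, and dropout_rng feeds the "dropout" rng stream.
import jax
from transformers import FlaxT5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5EncoderModel.from_pretrained("t5-small")

inputs = tokenizer("A sentence to encode.", return_tensors="np")
outputs = model(
    input_ids=inputs["input_ids"],
    train=True,
    dropout_rng=jax.random.PRNGKey(0),
)
```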


@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class FlaxT5ForConditionalGenerationModule(nn.Module):
    config: T5Config
14 changes: 14 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
@@ -802,6 +802,13 @@ def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxMT5EncoderModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxMT5ForConditionalGeneration(metaclass=DummyObject):
    _backends = ["flax"]

@@ -970,6 +977,13 @@ def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxT5EncoderModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
    _backends = ["flax"]
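These stubs keep `FlaxT5EncoderModel` and friends importable when Flax is absent, deferring the failure to first use. A simplified sketch of the idea — not the library's exact `DummyObject`/`requires_backends` implementation:

```python
# Simplified sketch of the dummy-object pattern. The real requires_backends
# checks whether the backend is importable; here we raise unconditionally.
class DummyObject(type):
    def __getattr__(cls, name):
        raise ImportError(f"{cls.__name__} requires the `flax` backend.")

def requires_backends(obj, backends):
    raise ImportError(f"{type(obj).__name__} requires the {backends} backend(s).")

class FlaxT5EncoderModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
```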

