Skip to content

Commit

Permalink
The rest of the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LysandreJik committed Oct 27, 2020
1 parent 0671708 commit d64b619
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
4 changes: 1 addition & 3 deletions src/transformers/modeling_tf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,9 +1136,7 @@ def call(
training=training,
)

past = (
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
)
past = (encoder_outputs, decoder_outputs[1]) if (use_cache) else None
if not return_dict:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class TFModelTesterMixin:
test_resize_embeddings = True
is_encoder_decoder = False

def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
inputs_dict = copy.deepcopy(inputs_dict)

if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
Expand Down
6 changes: 5 additions & 1 deletion tests/test_modeling_tf_lxmert.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,11 @@ def test_saved_model_with_attentions_output(self):
model = tf.keras.models.load_model(tmpdirname)
outputs = model(class_inputs_dict)

language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])
language_attentions, vision_attentions, cross_encoder_attentions = (
outputs[-3],
outputs[-2],
outputs[-1],
)

self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
Expand Down
11 changes: 9 additions & 2 deletions tests/test_modeling_tf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tempfile
import unittest

from transformers import T5Config, is_tf_available
Expand Down Expand Up @@ -282,6 +281,14 @@ def test_model_from_pretrained(self):
model = TFT5Model.from_pretrained("t5-small")
self.assertIsNotNone(model)

@slow
def test_saved_model_with_attentions_output(self):
pass

@slow
def test_saved_model_with_hidden_states_output(self):
pass


@require_tf
@require_sentencepiece
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modeling_tf_xlm_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_output_embeds_base_model(self):
"attention_mask": tf.convert_to_tensor([[1, 1, 1, 1, 1, 1]], dtype=tf.int32),
}

output = model(features)["last_hidden_state"]
output = model(features, return_dict=True)["last_hidden_state"]
expected_shape = tf.TensorShape((1, 6, 768))
self.assertEqual(output.shape, expected_shape)
# compare the actual values for a slice.
Expand Down

0 comments on commit d64b619

Please sign in to comment.