Skip to content

Commit

Permalink
Single list
Browse files Browse the repository at this point in the history
  • Loading branch information
LysandreJik committed Mar 17, 2022
1 parent 8e6aaa1 commit bb5874a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,8 @@ def check_pt_tf_models(tf_model, pt_model):
"TFFunnelForPreTraining",
"TFElectraForPreTraining",
"TFXLMWithLMHeadModel",
] + ["TFTransfoXLLMHeadModel"]:
"TFTransfoXLLMHeadModel",
]:
self.assertEqual(tf_loss is None, pt_loss is None)

tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
Expand All @@ -490,7 +491,8 @@ def check_pt_tf_models(tf_model, pt_model):
"TFFunnelForPreTraining",
"TFElectraForPreTraining",
"TFXLMWithLMHeadModel",
] + ["TFTransfoXLLMHeadModel"]:
"TFTransfoXLLMHeadModel",
]:
self.assertEqual(tf_keys, pt_keys)

# Since we deliberately make some tests pass above (regarding the `loss`), let's still try to test
Expand Down

0 comments on commit bb5874a

Please sign in to comment.