Skip to content

Commit

Permalink
Fix a TF Vision Encoder Decoder test (huggingface#15896)
Browse files Browse the repository at this point in the history
* send PyTorch inputs to the correct device

* Fix: TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh authored Mar 3, 2022
1 parent 39249c9 commit 4cd7ed4
Showing 1 changed file with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
if "labels" in pt_inputs:
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)

# send pytorch inputs to the correct device
pt_inputs = {k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

Expand All @@ -321,7 +324,7 @@ def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")

for tf_output, pt_output in zip(tf_outputs, pt_outputs):
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
self.assert_almost_equals(tf_output.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)

# PT -> TF
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
Expand All @@ -341,7 +344,7 @@ def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")

for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)

def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):

Expand Down

0 comments on commit 4cd7ed4

Please sign in to comment.