tests: fix pytorch tensor placement errors (#33485)
This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy" error

According to the PyTorch documentation, torch.Tensor.numpy(force=False)
performs the conversion only if the tensor is on the CPU (among a few other
restrictions), which is not the case here. We need force=True because we only
need the values and do not require the returned array to share memory with
the tensor.

Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html
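
A minimal sketch of the failure and of what force=True does (assuming a CUDA device is available; the tensor is illustrative):

```python
import torch

t = torch.ones(2, 3, device="cuda")

# t.numpy() raises:
#   TypeError: can't convert cuda:0 device type tensor to numpy.
#   Use Tensor.cpu() to copy the tensor to host memory first.

# force=True behaves like t.detach().cpu().resolve_conj().resolve_neg().numpy():
# it detaches from autograd and copies to the CPU, so it works wherever the
# tensor lives, at the cost of the array possibly not sharing memory with it.
a = t.numpy(force=True)
print(a.shape)  # (2, 3)
```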

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
dvrogozh authored Sep 25, 2024
1 parent 52daf4e commit 5e2916b
Showing 8 changed files with 29 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/transformers/modeling_flax_pytorch_utils.py
@@ -163,7 +163,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
if v.dtype == bfloat16:
v = v.float()
- pt_state_dict[k] = v.numpy()
+ pt_state_dict[k] = v.cpu().numpy()

model_prefix = flax_model.base_model_prefix

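The same hunk in condensed, runnable form (a sketch of the pattern, not the library function itself; the sample state dict is illustrative):

```python
import torch

def state_dict_to_numpy(pt_state_dict):
    """Convert a PyTorch state dict to numpy arrays for Flax."""
    out = {}
    for k, v in pt_state_dict.items():
        # numpy does not support bfloat16, so upcast to float32 first
        # (float32 represents every bfloat16 value exactly).
        if v.dtype == torch.bfloat16:
            v = v.float()
        # .cpu() moves accelerator tensors to host memory;
        # .numpy() only accepts CPU tensors.
        out[k] = v.cpu().numpy()
    return out

weights = {"dense.kernel": torch.randn(4, 4, dtype=torch.bfloat16)}
print({k: v.dtype for k, v in state_dict_to_numpy(weights).items()})
# {'dense.kernel': dtype('float32')}
```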
11 changes: 7 additions & 4 deletions tests/models/clip/test_modeling_clip.py
@@ -848,6 +848,7 @@ def test_equivalence_pt_to_flax(self):
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
+ pt_model.to(torch_device)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
@@ -881,7 +882,7 @@ def test_equivalence_pt_to_flax(self):
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@@ -892,7 +893,7 @@ def test_equivalence_pt_to_flax(self):
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
- self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)

# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@@ -921,6 +922,7 @@ def test_equivalence_flax_to_pt(self):
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
+ pt_model.to(torch_device)

# make sure weights are tied in PyTorch
pt_model.tie_weights()
@@ -940,11 +942,12 @@ def test_equivalence_flax_to_pt(self):
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
+ pt_model_loaded.to(torch_device)

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
@@ -953,7 +956,7 @@ def test_equivalence_flax_to_pt(self):
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

@slow
def test_model_from_pretrained(self):
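The CLIP changes follow one pattern: put the model and all of its inputs on the same device before the forward pass, and use force=True when pulling outputs back into numpy. A minimal sketch of the error being fixed (torch_device stands in for the test suite's device fixture):

```python
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(8, 2)  # parameters are created on the CPU
x = torch.randn(1, 8)

# Without these two moves a CUDA run fails with
# "RuntimeError: Expected all tensors to be on the same device".
model.to(torch_device)
x = x.to(torch_device)

with torch.no_grad():
    y = model(x)

print(y.numpy(force=True))  # force=True copies the result back to the host
```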
8 changes: 4 additions & 4 deletions tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py
@@ -297,15 +297,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
- pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
+ pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -315,7 +315,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
- self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -330,7 +330,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
- self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)

def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
2 changes: 1 addition & 1 deletion tests/models/informer/test_modeling_informer.py
@@ -170,7 +170,7 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict):

embed_positions = InformerSinusoidalPositionalEmbedding(
config.context_length + config.prediction_length, config.d_model
- )
+ ).to(torch_device)
self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))

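The Informer fix is slightly different: the test constructs a reference embedding and compares weights with torch.equal, which requires both tensors to be on the same device. A minimal sketch (the Embedding modules stand in for InformerSinusoidalPositionalEmbedding; assuming a CUDA device):

```python
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.nn.Embedding(10, 4).to(torch_device)  # the model's weights
b = torch.nn.Embedding(10, 4)                   # reference, still on the CPU
b.load_state_dict(a.state_dict())               # same values, different device

# torch.equal(a.weight, b.weight) on a CUDA run fails with a device
# mismatch error; move the reference module first.
b.to(torch_device)
print(torch.equal(a.weight, b.weight))  # True
```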
8 changes: 4 additions & 4 deletions tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
@@ -412,15 +412,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
- pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
+ pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -430,7 +430,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
- self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -445,7 +445,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
- self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)

def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
8 changes: 4 additions & 4 deletions tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py
@@ -241,15 +241,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
- pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
+ pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -259,7 +259,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
- self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -274,7 +274,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
- self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
+ self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)

def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
8 changes: 4 additions & 4 deletions tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py
@@ -160,15 +160,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

# prepare inputs
flax_inputs = inputs_dict
- pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
+ pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -178,7 +178,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
- self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -193,7 +193,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
- self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)

def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
8 changes: 4 additions & 4 deletions tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
@@ -179,15 +179,15 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas
# prepare inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
pt_inputs = inputs_dict
- flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
+ flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()

fx_outputs = fx_model(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
- self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)

# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -197,7 +197,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
- self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)

# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -212,7 +212,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mas

self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
- self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
+ self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)

def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
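The last file goes in the opposite direction: PyTorch tensors become inputs for a Flax model, which expects host-side numpy arrays. numpy(force=True) produces one no matter where the tensor lives (a minimal sketch; the input dict is illustrative):

```python
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

pt_inputs = {
    "input_ids": torch.ones(1, 5, dtype=torch.long, device=torch_device),
    "attention_mask": torch.ones(1, 5, device=torch_device),
}

# Plain .numpy() would fail for accelerator tensors; force=True detaches
# and copies to the CPU, yielding arrays a Flax model can consume.
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
print({k: type(v).__name__ for k, v in flax_inputs.items()})
# {'input_ids': 'ndarray', 'attention_mask': 'ndarray'}
```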
