Skip to content

Commit

Permalink
Merge branch 'fix-batch-data-test' into 'main'
Browse files Browse the repository at this point in the history
Fix batch data test

See merge request heka/medkit!239

changelog: Fix batch data test
  • Loading branch information
olvb committed Nov 28, 2023
2 parents 1c949ff + a13ee05 commit ef6b034
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tests/unit/training/test_batch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ def test_dict_properties():
@pytest.mark.skipif(not TEST_CUDA, reason="cuda is not available")
def test_to_device():
cpu = torch.device("cpu")
gpu = torch.device("gpu")
data = BatchData(
inputs=["hello", "world"], outputs=[torch.tensor(0), torch.tensor(1)]
)
gpu = torch.device("cuda:0")
data = BatchData(inputs=["hello", "world"], outputs=torch.tensor([0, 1]))
new_data = data.to_device(gpu)
for tensor_cpu, tensor_gpu in zip(data["outputs"], new_data["outputs"]):
assert tensor_cpu.device == cpu
Expand Down

0 comments on commit ef6b034

Please sign in to comment.