Skip to content

Commit

Permalink
Updated fix for #27
Browse files Browse the repository at this point in the history
  • Loading branch information
dbolya committed Apr 26, 2023
1 parent 8fca570 commit 2e37b87
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tomesd/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
if no_rand:
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device, generator=generator)
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)

# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
Expand Down

0 comments on commit 2e37b87

Please sign in to comment.