From 2e37b87f2f24b0c902507b8afb7ebc7ba9b7cf6b Mon Sep 17 00:00:00 2001 From: Daniel Bolya Date: Wed, 26 Apr 2023 00:27:15 -0400 Subject: [PATCH] Updated fix for #27 --- tomesd/merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tomesd/merge.py b/tomesd/merge.py index cb11281..6e8a513 100644 --- a/tomesd/merge.py +++ b/tomesd/merge.py @@ -49,7 +49,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, if no_rand: rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) else: - rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device, generator=generator) + rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device) # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)