Skip to content

Commit 4d2d2d0

Browse files
committed
[ENH] Speed up build_mask
Signed-off-by: chenhe <38067903+chAwater@users.noreply.github.com>
1 parent 6cc98fc commit 4d2d2d0

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

visualizations/animation.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,20 @@
2323

2424

2525
def build_mask(s: int, margin: float = 2., dtype=torch.float32):
26-
mask = torch.zeros(1, 1, s, s, dtype=dtype)
26+
mask = torch.ones(1, 1, s, s, dtype=dtype)
2727
c = (s - 1) / 2
2828
t = (c - margin / 100. * c) ** 2
2929
sig = 2.
30-
for x in range(s):
31-
for y in range(s):
32-
r = (x - c) ** 2 + (y - c) ** 2
33-
if r > t:
34-
mask[..., x, y] = np.exp((t - r) / sig ** 2)
35-
else:
36-
mask[..., x, y] = 1.
30+
y, x = np.ogrid[:s, :s]
31+
r = (x - c) ** 2 + (y - c) ** 2
32+
# r > t
33+
outer_mask = ((t - r) / sig ** 2)
34+
outer_mask = outer_mask ** (r > t) # To prevent overflow
35+
outer_mask = (r > t) * np.exp(outer_mask)
36+
# r <= t
37+
inner_mask = (r <= t)
38+
mask = mask * outer_mask + mask * inner_mask
39+
3740
return mask
3841

3942

0 commit comments

Comments
 (0)