Skip to content

Commit b104441

Browse files
committed
WIP: update the follow-up
1 parent 31177c5 commit b104441

File tree

2 files changed

+636
-0
lines changed

2 files changed

+636
-0
lines changed

e2e/mandelbrot/mandelbrot_qgpu3.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# -----------------------------------------------------------------------------
2+
# From Numpy to Python
3+
# Copyright (2017) Nicolas P. Rougier - BSD license
4+
# More information at https://github.com/rougier/numpy-book
5+
# -----------------------------------------------------------------------------
6+
import math
7+
import numpy as np
8+
import time
9+
10+
# need to import before torch
11+
from matplotlib import colors
12+
import matplotlib.pyplot as plt
13+
14+
import torch
15+
torch.set_default_device("cpu")
16+
17+
18+
# ### Original NumPy version. ###
19+
20+
def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0):
21+
# Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/...
22+
# .../entry/How_To_Compute_Mandelbrodt_Set_Quickly?lang=en
23+
X = np.linspace(xmin, xmax, xn, dtype=np.float32)
24+
Y = np.linspace(ymin, ymax, yn, dtype=np.float32)
25+
C = X + Y[:,None]*1j
26+
N = np.zeros(C.shape, dtype=int)
27+
Z = np.zeros(C.shape, np.complex64)
28+
for n in range(maxiter):
29+
I = np.less(abs(Z), horizon)
30+
N[I] = n
31+
Z[I] = Z[I]**2 + C[I]
32+
N[N == maxiter-1] = 0
33+
return Z, N
34+
35+
36+
37+
# ### Compiled analog. ###
38+
39+
# For torch.Dynamo, need to work around
40+
# 1. Complex numbers: add a trailing length-2 dimension for Re and Im parts.
41+
# 2. Avoid fancy indexing: use with np.where instead to avoid data dependency
42+
#
43+
# Also:
44+
# 1. Only compile the inner loop, to keep compile time and memory consumption
45+
# under control (otherwise, can run into OOM while compiling)
46+
47+
def abs2(a):
48+
r"""abs(a) replacement."""
49+
return a[..., 0]**2 + a[..., 1]**2
50+
51+
52+
def sq2(a):
53+
"""a**2 replacement."""
54+
z = np.empty_like(a)
55+
z[..., 0] = a[..., 0]**2 - a[..., 1]**2
56+
z[..., 1] = 2 * a[..., 0] * a[..., 1]
57+
return z
58+
59+
60+
@torch.compile
61+
def step(n0, c, Z, N, horizon, chunksize):
62+
for j in range(chunksize):
63+
n = n0 + j
64+
I = abs2(Z) < horizon**2
65+
N = np.where(I, n, N) # N[I] = n
66+
Z = np.where(I[..., None], sq2(Z) + c, Z) # Z[I] = Z[I]**2 + C[I]
67+
return Z, N
68+
69+
70+
def mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=2**10, maxiter=5):
71+
x = np.linspace(xmin, xmax, xn, dtype='float32')
72+
y = np.linspace(ymin, ymax, yn, dtype='float32')
73+
c = np.stack(np.broadcast_arrays(x[None, :], y[:, None]), axis=-1)
74+
75+
N = np.zeros(c.shape[:-1], dtype='int')
76+
Z = np.zeros_like(c, dtype='float32')
77+
78+
chunksize=50
79+
n_chunks = maxiter // chunksize
80+
81+
for i_chunk in range(n_chunks):
82+
n0 = i_chunk*chunksize
83+
Z, N = step(n0, c, Z, N, horizon, chunksize)
84+
85+
N = np.where(N == maxiter-1, 0, N) # N[N == maxiter-1] = 0
86+
return Z, N
87+
88+
89+
90+
# plot a nice figure
91+
def visualize(Z, N, horizon, xn, yn):
92+
log_horizon = math.log(horizon, 2)
93+
M = np.nan_to_num(N + 1 - np.log(np.log(abs(Z)))/np.log(2) + log_horizon)
94+
95+
dpi = 72
96+
width = 10
97+
height = 10*yn/xn
98+
99+
fig = plt.figure(figsize=(width, height), dpi=dpi)
100+
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False, aspect=1)
101+
102+
light = colors.LightSource(azdeg=315, altdeg=10)
103+
104+
plt.imshow(light.shade(M, cmap=plt.cm.hot, vert_exag=1.5,
105+
norm = colors.PowerNorm(0.3), blend_mode='hsv'),
106+
extent=[xmin, xmax, ymin, ymax], interpolation="bicubic")
107+
ax.set_xticks([])
108+
ax.set_yticks([])
109+
plt.savefig("mandelbrot.png")
110+
# plt.show()
111+
112+
113+
114+
if __name__ == '__main__':
115+
# start up
116+
xmax, xmin, xn = -2.25, 0.75, 3000 // 2
117+
ymax, ymin, yn = -1.25, 1.25, 2500 // 2
118+
119+
maxiter = 200
120+
horizon = 2**10
121+
122+
# time numpy
123+
start_time = time.time()
124+
Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter)
125+
end_time = time.time()
126+
numpy_time = end_time - start_time
127+
print("\n\nnumpy: elapsed=", numpy_time)
128+
129+
130+
start_time = time.time()
131+
step = torch.compile(step)
132+
end_time = time.time()
133+
print("compile: ", end_time - start_time)
134+
135+
# compile, warm up, time
136+
for _ in range(3):
137+
mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter)
138+
139+
# measure
140+
start_time = time.time()
141+
nreps = 100
142+
for _ in range(nreps):
143+
Z, N = mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter)
144+
end_time = time.time()
145+
compiled_time = (end_time - start_time) / nreps
146+
print("compiled: elapsed=", compiled_time, ' speedup = ', numpy_time / compiled_time)
147+
148+
# Visualization
149+
Z = Z[..., 0] + 1j*Z[..., 1]
150+
visualize(Z, N, horizon, xn, yn)
151+

0 commit comments

Comments
 (0)