Skip to content

Commit 6b4cfb1

Browse files
committed
feat: Optimize
it is so slow that 5000 is about max loops/points. this optimization makes it calculate 50000 points with ease, and much more if needed.
1 parent 8511d03 commit 6b4cfb1

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

fern.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,42 @@
55

66
class AffineTransform:
77
def __init__(self, a=0., b=0., c=0., d=0., e=0., f=0.):
8-
self._abcd = self._ef = np.asarray([[a, b], [c, d]])
8+
self._abcd = np.asarray([[a, b], [c, d]])
99
self._ef = np.asarray([[e], [f]])
1010

1111
def __call__(self, x, y):
1212
# np.dot maybe introducing overhead?
13+
# probably not, it's much faster now after optimizing, by parametrizing the scatter plot
1314
return np.dot(self._abcd, [[x], [y]]) + self._ef
1415

1516

16-
f1 = AffineTransform(0,0,0,0.16,0,0)
17+
f1 = AffineTransform(0, 0, 0, 0.16, 0, 0)
1718
f2 = AffineTransform(0.85, 0.04, -0.04, 0.85, 0, 1.60)
1819
f3 = AffineTransform(0.20, -0.26, 0.23, 0.22, 0, 1.6)
1920
f4 = AffineTransform(-0.15, 0.28, 0.26, 0.24, 0, 0.44)
2021

2122
functions = [f1, f2, f3, f4]
2223
p_cumulative = [0.01, 0.86, 0.92, 1.00]
24+
25+
2326
def select(funcs, probcum):
2427
r = np.random.random()
2528
for j, p in enumerate(probcum):
2629
if r < p:
2730
return funcs[j]
2831

2932

33+
#--------plotting and saving pic
3034
plt.axis("equal")
3135
plt.axis('off')
32-
x = [[0], [0]]
33-
plt.scatter(0, 0)
34-
# TODO: optimize
35-
for i in range(5000):
36-
x = select(functions, p_cumulative)(x[0][0], x[1][0])
37-
plt.scatter(x[0], x[1])
36+
n = 50000
37+
fern = np.zeros((n + 1, 2), dtype=float)
38+
39+
for i in range(n):
40+
fern[i + 1] = (select(functions, p_cumulative)(fern[i][0], fern[i][1])).reshape(2)
3841

42+
plt.scatter(fern[:, 0], fern[:, 1], s=0.2, c="g", marker=",")
3943
plt.savefig(pathlib.Path(__file__).parent.resolve().__str__() + '\\figures\\' + "barnsley_fern",
40-
dpi=400, s=0.1)
44+
dpi=400)
4145

46+
plt.show()

0 commit comments

Comments
 (0)