Skip to content

Commit

Permalink
Speed up mandelbrot (modularml#742)
Browse files Browse the repository at this point in the history
Add suggestions from @zbosons to significantly speed up mandelbrot from 2.68 ms to 1.30 on 8 core machine
  • Loading branch information
jackos authored Sep 12, 2023
1 parent 4b9b26b commit 6ecab31
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 39 deletions.
32 changes: 17 additions & 15 deletions examples/mandelbrot.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ from tensor import Tensor
from utils.index import Index

alias float_type = DType.float64
alias simd_width = simdwidthof[float_type]()
alias simd_width = 2 * simdwidthof[float_type]()

alias width = 960
alias height = 960
Expand All @@ -24,17 +24,22 @@ fn mandelbrot_kernel_SIMD[
simd_width: Int
](c: ComplexSIMD[float_type, simd_width]) -> SIMD[float_type, simd_width]:
"""A vectorized implementation of the inner mandelbrot computation."""
var z = ComplexSIMD[float_type, simd_width](0, 0)
let cx = c.re
let cy = c.im
var x = SIMD[float_type, simd_width](0)
var y = SIMD[float_type, simd_width](0)
var y2 = SIMD[float_type, simd_width](0)
var iters = SIMD[float_type, simd_width](0)

var in_set_mask: SIMD[DType.bool, simd_width] = True
var t: SIMD[DType.bool, simd_width] = True
for i in range(MAX_ITERS):
if not in_set_mask.reduce_or():
if not t.reduce_or():
break
in_set_mask = z.squared_norm() <= 4
iters = in_set_mask.select(iters + 1, iters)
z = z.squared_add(c)

y2 = y*y
y = x.fma(y + y, cy)
t = x.fma(x, y2) <= 4
x = x.fma(x, cx - y2)
iters = t.select(iters + 1, iters)
return iters


Expand All @@ -48,7 +53,7 @@ fn main():

@parameter
fn compute_vector[simd_width: Int](col: Int):
"""Each time we oeprate on a `simd_width` vector of pixels."""
"""Each time we operate on a `simd_width` vector of pixels."""
let cx = min_x + (col + iota[float_type, simd_width]()) * scale_x
let cy = min_y + row * scale_y
let c = ComplexSIMD[float_type, simd_width](cx, cy)
Expand All @@ -65,20 +70,17 @@ fn main():
worker(row)

let vectorized_ms = Benchmark().run[bench[simd_width]]() / 1e6
print("Number of hardware cores:", num_cores())
print("Number of threads:", num_cores())
print("Vectorized:", vectorized_ms, "ms")

# Parallelized
with Runtime() as rt:

@parameter
fn bench_parallel[simd_width: Int]():
parallelize[worker](rt, height, 5 * num_cores())
parallelize[worker](rt, height, height)

alias simd_width = simdwidthof[DType.float64]()
let parallelized_ms = Benchmark().run[
bench_parallel[simd_width]
]() / 1e6
let parallelized_ms = Benchmark().run[bench_parallel[simd_width]]() / 1e6
print("Parallelized:", parallelized_ms, "ms")
print("Parallel speedup:", vectorized_ms / parallelized_ms)

Expand Down
51 changes: 27 additions & 24 deletions examples/notebooks/Mandelbrot.ipynb

Large diffs are not rendered by default.

0 comments on commit 6ecab31

Please sign in to comment.