Closed as not planned
Testing out the Immutable Arrays from JuliaLang/julia#44381 with #7
TLDR: Performance is a pain point right now (broadcasting appears to be the cause), but supporting these should be very straightforward once the functionality is available in Base.
EDIT: Code updated to work for Lux 0.4.*
Trial 1: From the Usage Example
using Lux, Random, Functors, BenchmarkTools
make_immutable(x::AbstractArray) = ImmutableArray(copy(x))
make_immutable(x) = x
# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps)
st_immutable = fmap(make_immutable, st)
# Dummy Input
x = randn(Float32, 128, 1024)
x_immutable = make_immutable(x)
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 1296 samples with 1 evaluation.
Range (min … max): 2.125 ms … 26.658 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 3.096 ms ┊ GC (median): 0.00%
Time (mean ± σ): 3.836 ms ± 2.313 ms ┊ GC (mean ± σ): 2.58% ± 7.71%
▂█
▆▄██▇▆▄▄▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▂▂▁▁▂▂▂▂▂▁▂▂▂▂▁▂▂▂ ▃
2.13 ms Histogram: frequency by time 14.1 ms <
Memory estimate: 3.60 MiB, allocs estimate: 144.
Immutable Arrays
BenchmarkTools.Trial: 41 samples with 1 evaluation.
Range (min … max): 107.855 ms … 159.665 ms ┊ GC (min … max): 3.98% … 2.64%
Time (median): 119.911 ms ┊ GC (median): 3.54%
Time (mean ± σ): 123.706 ms ± 10.746 ms ┊ GC (mean ± σ): 3.54% ± 0.67%
▂█▄
▄▁▁▁▁▁▁▁▄▆▄█████▄▁▄▆▄▆▁▁▄▁▁▄▁▁▁▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
108 ms Histogram: frequency by time 160 ms <
Memory estimate: 58.32 MiB, allocs estimate: 3418558.
Trial 2: Only a Dense Layer
# Construct the layer
model = Dense(128, 256)
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);
# Dummy Input
x = randn(Float32, 128, 1024);
x_immutable = make_immutable(x);
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 4469 samples with 1 evaluation.
Range (min … max): 483.810 μs … 30.894 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 716.669 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.100 ms ± 1.501 ms ┊ GC (mean ± σ): 5.01% ± 12.19%
█▆▆▅▄▃▂▂▂▂▃▃▃▂▁ ▁
█████████████████▇▇▇▆▇▆▅▅▃▃▄▅▅▄▃▅▁▁▆▄▅▁▃▃▃▃▅▁▃▃▃▃▁▃▁▁▃▁▁▁▁▃▅ █
484 μs Histogram: log(frequency) by time 7.69 ms <
Memory estimate: 2.00 MiB, allocs estimate: 4.
Immutable Arrays
BenchmarkTools.Trial: 259 samples with 1 evaluation.
Range (min … max): 15.392 ms … 52.229 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 17.997 ms ┊ GC (median): 0.00%
Time (mean ± σ): 19.327 ms ± 4.194 ms ┊ GC (mean ± σ): 1.72% ± 4.44%
▃▆█ ▂
▃▆███▆█▇▅▇▇▄▆▃▆▄▄▅▄▄▄▄▄▃▄▄▃▂▁▃▃▂▁▃▂▁▁▂▂▁▂▂▂▁▃▁▃▂▂▁▁▁▂▂▁▂▂▁▂ ▃
15.4 ms Histogram: frequency by time 32.6 ms <
Memory estimate: 7.00 MiB, allocs estimate: 262153.
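Most of the time appears to be spent broadcasting the bias, which suggests a problem with broadcasting over immutable arrays in general. The matrix multiply alone has a median of about 540 μs with 5 allocations, while adding the bias raises the median to about 13.7 ms with 262153 allocations: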
julia> @benchmark $ps_immutable.weight * $x_immutable
BenchmarkTools.Trial: 4032 samples with 1 evaluation.
Range (min … max): 346.287 μs … 51.079 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 540.489 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.224 ms ± 1.854 ms ┊ GC (mean ± σ): 2.36% ± 8.18%
█▆▄▄▃▁▁▁ ▂▂▁▁▁▂▂▁▁ ▁▁ ▁
█████████████████████████▇▇▇▆▇▆▇▆▆▃▆▆▆▅▅▅▅▄▅▅▅▆▅▅▅▅▅▅▄▃▁▁▁▃▃ █
346 μs Histogram: log(frequency) by time 8.78 ms <
Memory estimate: 1.00 MiB, allocs estimate: 5.
julia> @benchmark $ps_immutable.weight * $x_immutable .+ $ps_immutable.bias
BenchmarkTools.Trial: 338 samples with 1 evaluation.
Range (min … max): 11.177 ms … 33.105 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 13.699 ms ┊ GC (median): 0.00%
Time (mean ± σ): 14.792 ms ± 3.901 ms ┊ GC (mean ± σ): 2.43% ± 5.87%
█▃
▅██▇▇▅▅▇▅▇▅▅▄▅▅▄▃▃▄▄▃▂▃▃▁▂▃▁▃▂▃▃▃▁▃▂▂▂▁▃▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▃▂▁▂ ▃
11.2 ms Histogram: frequency by time 30.9 ms <
Memory estimate: 7.00 MiB, allocs estimate: 262153.
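To check whether the slowdown comes from the broadcasting machinery rather than from element access, one option is to compare the broadcast against an explicit loop. The sketch below is only an illustration: it uses plain `Array`s so that it runs on a released Julia, and `add_bias_loop` is a hypothetical helper, not part of Lux or Base.

```julia
# Hypothetical helper: add a bias vector to every column of `y`
# with an explicit loop instead of `y .+ b`.
function add_bias_loop(y::AbstractMatrix, b::AbstractVector)
    size(y, 1) == length(b) ||
        throw(DimensionMismatch("bias length must match the number of rows of y"))
    # The output is always a plain, mutable Matrix.
    out = Matrix{promote_type(eltype(y), eltype(b))}(undef, size(y))
    # Column-major traversal: the inner loop runs over rows.
    for (jo, j) in enumerate(axes(y, 2)), (io, i) in enumerate(axes(y, 1))
        out[io, jo] = y[i, j] + b[firstindex(b) + io - 1]
    end
    return out
end

# Same shapes as the Dense(128, 256) layer in Trial 2.
W = randn(Float32, 256, 128)
b = randn(Float32, 256)
x = randn(Float32, 128, 1024)

y = W * x
# The loop must agree with the broadcast version.
@assert add_bias_loop(y, b) ≈ y .+ b
```

On a Julia build that includes the ImmutableArray PR, benchmarking `add_bias_loop($ps_immutable.weight * $x_immutable, $ps_immutable.bias)` against the broadcast expression would indicate whether the extra allocations come from broadcasting itself. That comparison has not been run here.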
Trial 3: No broadcasting
model = Dense(128, 256; bias=false)
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 5501 samples with 1 evaluation.
Range (min … max): 295.161 μs … 23.801 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 451.402 μs ┊ GC (median): 0.00%
Time (mean ± σ): 899.925 μs ± 1.386 ms ┊ GC (mean ± σ): 3.10% ± 8.68%
█▆▆▄▃▂▁▂▁▁▁▂▂▂▂▁ ▁ ▁
██████████████████▇█▇█▇▇▆▆▇▇▆▆▆▆▆▆▅▅▅▆▅▅▁▆▄▆▅▃▅▄▅▄▆▄▅▁▄▆▅▅▃▅ █
295 μs Histogram: log(frequency) by time 6.98 ms <
Memory estimate: 1.00 MiB, allocs estimate: 2.
Immutable Arrays
BenchmarkTools.Trial: 5303 samples with 1 evaluation.
Range (min … max): 311.574 μs … 26.953 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 436.316 μs ┊ GC (median): 0.00%
Time (mean ± σ): 930.509 μs ± 1.488 ms ┊ GC (mean ± σ): 3.23% ± 8.75%
█▆▅▃▂▁ ▁▁▂▁▁ ▁
█████████████████▆█▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▅▅▅▅▂▅▂▄▅▄▅▄▄▃▂▃▄▄▂▃▂▃ █
312 μs Histogram: log(frequency) by time 7.61 ms <
Memory estimate: 1.00 MiB, allocs estimate: 5.
Trial 4: Chain of Dense Layers without Bias
model = Chain(Dense(128, 256; bias=false), Chain(Dense(256, 512; bias=false),
                                                 Dense(512, 10; bias=false)))
# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);
# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)
Standard Abstract Arrays
BenchmarkTools.Trial: 1372 samples with 1 evaluation.
Range (min … max): 1.380 ms … 49.871 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 2.918 ms ┊ GC (median): 0.00%
Time (mean ± σ): 3.615 ms ± 3.116 ms ┊ GC (mean ± σ): 2.42% ± 7.94%
▅█ ▃
███▇▆▇██▇▆▅▄▄▄▃▃▃▃▂▃▃▃▂▃▂▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂ ▃
1.38 ms Histogram: frequency by time 15.8 ms <
Memory estimate: 3.04 MiB, allocs estimate: 6.
Immutable Arrays
BenchmarkTools.Trial: 894 samples with 1 evaluation.
Range (min … max): 1.505 ms … 66.104 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 4.153 ms ┊ GC (median): 0.00%
Time (mean ± σ): 5.561 ms ± 5.432 ms ┊ GC (mean ± σ): 1.87% ± 7.54%
█▆▅▅▅▄▅▆▆▅▄▄▂▂▂▂▁ ▁ ▁ ▁
█████████████████▇█▆███▆▇█▅▆▇███▆▇█▄▇▇▇▅▄▆▅▅▁▄▁▆▄▁▅▇▅▄▄▆▁▅ █
1.5 ms Histogram: log(frequency) by time 23.1 ms <
Memory estimate: 3.04 MiB, allocs estimate: 17.
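To put the four trials side by side, the median times reported above can be turned into slowdown factors. This is a small sketch that only does arithmetic on the numbers already shown; the trial labels and the `slowdown` helper are illustrative names.

```julia
# Median times in milliseconds, copied from the benchmark output above:
# (label, standard arrays, immutable arrays)
medians_ms = [
    ("Trial 1: BatchNorm + Dense chain", 3.096, 119.911),
    ("Trial 2: Dense with bias", 0.716669, 17.997),
    ("Trial 3: Dense without bias", 0.451402, 0.436316),
    ("Trial 4: Dense chain without bias", 2.918, 4.153),
]

# Ratio of the immutable median to the standard median.
slowdown(standard, immutable) = immutable / standard

for (label, standard, immutable) in medians_ms
    println(rpad(label, 36), round(slowdown(standard, immutable); digits=2), "x")
end
```

By this measure the trials that broadcast (1 and 2) are roughly 39x and 25x slower with immutable arrays, while the bias-free trials (3 and 4) come out at roughly 0.97x and 1.4x.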