Immutable Arrays #8

Closed as not planned
Enhancement
@avik-pal


Testing the immutable arrays (`ImmutableArray`) from JuliaLang/julia#44381 together with #7.

TL;DR: Performance is somewhat painful right now (broadcasting seems to be the culprit), but supporting these arrays will be very straightforward once the functionality is available in Base.

EDIT: Code updated to work for Lux 0.4.*
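Support is straightforward because Lux keeps parameters and states in plain nested `NamedTuple`s, so converting a model only requires mapping over every array leaf, which is what `fmap(make_immutable, ps)` does in the trials. The sketch below is a Base-only illustration of that traversal, not Functors' implementation: `freeze_leaves` is a hypothetical name, the toy tree only mimics the shape of `Lux.setup` output, and `copy` stands in for `ImmutableArray(copy(x))` so it runs on a released Julia.

```julia
# Hypothetical Base-only stand-in for `fmap(make_immutable, ps)`:
# recurse through NamedTuples/Tuples and apply `f` to every array leaf.
freeze_leaves(f, x::AbstractArray) = f(x)
freeze_leaves(f, x::Union{NamedTuple,Tuple}) = map(y -> freeze_leaves(f, y), x)
freeze_leaves(f, x) = x  # non-array leaves (e.g. a `Val` flag) are returned unchanged

# A toy parameter/state tree shaped like the output of `Lux.setup` (assumed shape).
ps = (layer_1 = (weight = rand(Float32, 4, 3), bias = zeros(Float32, 4, 1)),
      layer_2 = (weight = rand(Float32, 2, 4),))
st = (layer_1 = (training = Val(true),), layer_2 = NamedTuple())

# On the #44381 branch `f` would be `x -> ImmutableArray(copy(x))`;
# `copy` is used here so the sketch runs on a released Julia.
ps_frozen = freeze_leaves(copy, ps)
st_frozen = freeze_leaves(copy, st)
```

The tree structure and key names are preserved; only the array leaves are replaced.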

Trial 1: From the Usage Example

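# Requires a Julia build that includes JuliaLang/julia#44381 (for `ImmutableArray`)
using Lux, Random, Functors, BenchmarkTools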

make_immutable(x::AbstractArray) = ImmutableArray(copy(x))
make_immutable(x) = x

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps)
st_immutable = fmap(make_immutable, st)

# Dummy Input
x = randn(Float32, 128, 1024)
x_immutable = make_immutable(x)

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1296 samples with 1 evaluation.
 Range (min … max):  2.125 ms … 26.658 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     3.096 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.836 ms ±  2.313 ms  ┊ GC (mean ± σ):  2.58% ± 7.71%

    ▂█                                                        
  ▆▄██▇▆▄▄▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▂▂▁▁▂▂▂▂▂▁▂▂▂▂▁▂▂▂ ▃
  2.13 ms        Histogram: frequency by time        14.1 ms <

 Memory estimate: 3.60 MiB, allocs estimate: 144.

Immutable Arrays

BenchmarkTools.Trial: 41 samples with 1 evaluation.
 Range (min … max):  107.855 ms … 159.665 ms  ┊ GC (min … max): 3.98% … 2.64%
 Time  (median):     119.911 ms               ┊ GC (median):    3.54%
 Time  (mean ± σ):   123.706 ms ±  10.746 ms  ┊ GC (mean ± σ):  3.54% ± 0.67%

              ▂█▄                                                
  ▄▁▁▁▁▁▁▁▄▆▄█████▄▁▄▆▄▆▁▁▄▁▁▄▁▁▁▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  108 ms           Histogram: frequency by time          160 ms <

 Memory estimate: 58.32 MiB, allocs estimate: 3418558.

Trial 2: Only a Dense Layer

# Construct the layer
model = Dense(128, 256)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Dummy Input
x = randn(Float32, 128, 1024);
x_immutable = make_immutable(x);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 4469 samples with 1 evaluation.
 Range (min … max):  483.810 μs … 30.894 ms  ┊ GC (min … max): 0.00% …  0.00%
 Time  (median):     716.669 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.100 ms ±  1.501 ms  ┊ GC (mean ± σ):  5.01% ± 12.19%

  █▆▆▅▄▃▂▂▂▂▃▃▃▂▁                                              ▁
  █████████████████▇▇▇▆▇▆▅▅▃▃▄▅▅▄▃▅▁▁▆▄▅▁▃▃▃▃▅▁▃▃▃▃▁▃▁▁▃▁▁▁▁▃▅ █
  484 μs        Histogram: log(frequency) by time      7.69 ms <

 Memory estimate: 2.00 MiB, allocs estimate: 4.

Immutable Arrays

BenchmarkTools.Trial: 259 samples with 1 evaluation.
 Range (min … max):  15.392 ms … 52.229 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.997 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   19.327 ms ±  4.194 ms  ┊ GC (mean ± σ):  1.72% ± 4.44%

    ▃▆█ ▂                                                      
  ▃▆███▆█▇▅▇▇▄▆▃▆▄▄▅▄▄▄▄▄▃▄▄▃▂▁▃▃▂▁▃▂▁▁▂▂▁▂▂▂▁▃▁▃▂▂▁▁▁▂▂▁▂▂▁▂ ▃
  15.4 ms         Histogram: frequency by time        32.6 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.

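Much of the time seems to be spent broadcasting the bias (this looks like a problem with broadcasting in general): the matrix multiply alone is fast (median ≈ 540 μs, 5 allocations), while adding `.+ bias` raises the median to ≈ 13.7 ms with 262153 allocations.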
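One way to probe this is to time the two halves of a Dense-style forward pass separately. The sketch below uses standard `Array`s so it runs on a released Julia; on the #44381 branch one would pass `ImmutableArray(copy(...))` versions of `W`, `x`, and `b` instead. `dense_cost_breakdown` is a hypothetical helper, the 256×1 bias layout is an assumption, and the minimum over repeated `@elapsed` calls is only a rough stand-in for `@benchmark`.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not part of Lux): time the matmul and the bias
# broadcast of a Dense-style forward pass separately, taking the minimum
# of `samples` runs of `@elapsed` as a rough stand-in for `@benchmark`.
function dense_cost_breakdown(W, x, b; samples = 20)
    y = W * x      # first calls also trigger compilation
    out = y .+ b
    t_matmul = minimum(@elapsed(W * x) for _ in 1:samples)
    t_bias = minimum(@elapsed(y .+ b) for _ in 1:samples)
    return (; matmul = t_matmul, bias_broadcast = t_bias, out = out)
end

rng = MersenneTwister(0)
W = randn(rng, Float32, 256, 128)   # same shapes as Trial 2
b = randn(rng, Float32, 256, 1)     # bias stored as a 256×1 column (assumed layout)
x = randn(rng, Float32, 128, 1024)

r = dense_cost_breakdown(W, x, b)
println("matmul: ", r.matmul, " s, bias broadcast: ", r.bias_broadcast, " s")
```

With standard arrays the broadcast should be cheap relative to the matmul; if the immutable version inverts that relationship, the broadcasting path is the one to investigate.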

julia> @benchmark $ps_immutable.weight * $x_immutable
BenchmarkTools.Trial: 4032 samples with 1 evaluation.
 Range (min … max):  346.287 μs … 51.079 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     540.489 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.224 ms ±  1.854 ms  ┊ GC (mean ± σ):  2.36% ± 8.18%

  █▆▄▄▃▁▁▁ ▂▂▁▁▁▂▂▁▁  ▁▁                                       ▁
  █████████████████████████▇▇▇▆▇▆▇▆▆▃▆▆▆▅▅▅▅▄▅▅▅▆▅▅▅▅▅▅▄▃▁▁▁▃▃ █
  346 μs        Histogram: log(frequency) by time      8.78 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

julia> @benchmark $ps_immutable.weight * $x_immutable .+ $ps_immutable.bias
BenchmarkTools.Trial: 338 samples with 1 evaluation.
 Range (min … max):  11.177 ms … 33.105 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.699 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.792 ms ±  3.901 ms  ┊ GC (mean ± σ):  2.43% ± 5.87%

   █▃                                                          
  ▅██▇▇▅▅▇▅▇▅▅▄▅▅▄▃▃▄▄▃▂▃▃▁▂▃▁▃▂▃▃▃▁▃▂▂▂▁▃▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▃▂▁▂ ▃
  11.2 ms         Histogram: frequency by time        30.9 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.

Trial 3: No broadcasting

model = Dense(128, 256; bias=false)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 5501 samples with 1 evaluation.
 Range (min … max):  295.161 μs … 23.801 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     451.402 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   899.925 μs ±  1.386 ms  ┊ GC (mean ± σ):  3.10% ± 8.68%

  █▆▆▄▃▂▁▂▁▁▁▂▂▂▂▁ ▁                                           ▁
  ██████████████████▇█▇█▇▇▆▆▇▇▆▆▆▆▆▆▅▅▅▆▅▅▁▆▄▆▅▃▅▄▅▄▆▄▅▁▄▆▅▅▃▅ █
  295 μs        Histogram: log(frequency) by time      6.98 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 2.

Immutable Arrays

BenchmarkTools.Trial: 5303 samples with 1 evaluation.
 Range (min … max):  311.574 μs … 26.953 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     436.316 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   930.509 μs ±  1.488 ms  ┊ GC (mean ± σ):  3.23% ± 8.75%

  █▆▅▃▂▁   ▁▁▂▁▁                                               ▁
  █████████████████▆█▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▅▅▅▅▂▅▂▄▅▄▅▄▄▃▂▃▄▄▂▃▂▃ █
  312 μs        Histogram: log(frequency) by time      7.61 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

Trial 4: Nested Chain without Bias

model = Chain(Dense(128, 256; bias=false),
              Chain(Dense(256, 512; bias=false), Dense(512, 10; bias=false)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1372 samples with 1 evaluation.
 Range (min … max):  1.380 ms … 49.871 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.918 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.615 ms ±  3.116 ms  ┊ GC (mean ± σ):  2.42% ± 7.94%

  ▅█    ▃                                                     
  ███▇▆▇██▇▆▅▄▄▄▃▃▃▃▂▃▃▃▂▃▂▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂ ▃
  1.38 ms        Histogram: frequency by time        15.8 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 6.

Immutable Arrays

BenchmarkTools.Trial: 894 samples with 1 evaluation.
 Range (min … max):  1.505 ms … 66.104 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.153 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.561 ms ±  5.432 ms  ┊ GC (mean ± σ):  1.87% ± 7.54%

  █▆▅▅▅▄▅▆▆▅▄▄▂▂▂▂▁     ▁  ▁     ▁                            
  █████████████████▇█▆███▆▇█▅▆▇███▆▇█▄▇▇▇▅▄▆▅▅▁▄▁▆▄▁▅▇▅▄▄▆▁▅ █
  1.5 ms       Histogram: log(frequency) by time     23.1 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 17.
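To put the four trials side by side, the slowdown of the immutable version can be computed from the reported medians. The numbers below are copied from the outputs above (converted to milliseconds); medians are used rather than means because the distributions have long tails. This is plain arithmetic on the posted results, not a new measurement.

```julia
using Printf

# (trial, standard-array median, immutable-array median), all in milliseconds,
# copied from the BenchmarkTools output above.
trials = [
    ("Trial 1: Chain with BatchNorm",  3.096,    119.911),
    ("Trial 2: Dense with bias",       0.716669, 17.997),
    ("Trial 3: Dense, no bias",        0.451402, 0.436316),
    ("Trial 4: nested Chain, no bias", 2.918,    4.153),
]

# Slowdown of ImmutableArray relative to Array (> 1 means slower).
ratios = [immut / std for (_, std, immut) in trials]

for ((name, _, _), r) in zip(trials, ratios)
    @printf("%-32s %6.2fx\n", name, r)
end
```

The trials that broadcast (1 and 2) are roughly 25 to 40 times slower, while the bias-free trials (3 and 4) stay within about 1.5 times of standard arrays.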

cc @ChrisRackauckas @ianatol @aviatesk
