Skip to content

Commit

Permalink
refactor: use setfield and make make_zero!! type-stable
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 13, 2024
1 parent 315e4e8 commit f60db4d
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 99 deletions.
2 changes: 2 additions & 0 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module LuxEnzymeExt
using ADTypes: AutoEnzyme
using Enzyme: Enzyme, Active, Const, Duplicated
using EnzymeCore: EnzymeCore
using Setfield: @set!
using Static: False, True

using Lux: Lux
using Lux.Training: TrainingBackendCache, TrainState
Expand Down
51 changes: 25 additions & 26 deletions ext/LuxEnzymeExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,42 +1,43 @@
function Lux.Training.compute_gradients(
::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Enzyme.make_zero(ts.parameters)
ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Lux.recursive_make_zero(ts.parameters)

obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.wrap_objective_function(
obj_fn, ts.model, ts.parameters, ts.states, data, Val(true))
obj_fn, ts.model, ts.parameters, ts.states, data, True())

_, loss = Enzyme.autodiff(
EnzymeCore.ReverseWithPrimal, Const(obj_fn_wrap), Active, Const(ts.model),
Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

cache = TrainingBackendCache{:Enzyme, false}(
dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap))
ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, st_wrap[],
ts.optimizer, ts.optimizer_state, ts.step)

return dps, loss, stats_wrap[], ts_new
cache = TrainingBackendCache(
ad, False(), dps, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap))
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
@set! ts.states = st_wrap[]
return dps, loss, stats_wrap[], ts
end

const AUTODIFF_CACHE_TYPE = TrainingBackendCache{
:Enzyme, false, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}
<:AutoEnzyme, False, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}

function Lux.Training.compute_gradients(
::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
# dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
Enzyme.make_zero!(ts.cache.dparameters)
dps = ts.cache.dparameters

_, loss = Enzyme.autodiff(
EnzymeCore.ReverseWithPrimal, Const(ts.cache.extras.obj_fn), Active,
Const(ts.model), Duplicated(ts.parameters, dps), Const(ts.states), Const(data))

ts_new = TrainState(
ts.cache, obj_fn, ts.model, ts.parameters, ts.cache.extras.st_wrap[],
ts.optimizer, ts.optimizer_state, ts.step)
@set! ts.objective_function = obj_fn
@set! ts.states = ts.cache.extras.st_wrap[]

return dps, loss, ts.cache.extras.stats_wrap[], ts_new
return dps, loss, ts.cache.extras.stats_wrap[], ts
end

function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Enzyme, false}}) where {F}
ts::TrainState{<:TrainingBackendCache{<:AutoEnzyme, False}}) where {F}
@warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \
function that is changing across function calls. This can lead to the \
generation of slow code" maxlog=1
Expand All @@ -46,15 +47,14 @@ function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)},
Const{typeof(ts.states)}, Const{typeof(data)})

cache = TrainingBackendCache{:Enzyme, false}(ts.cache.dparameters, (; forward, reverse))
ts_new = TrainState(cache, obj_fn, ts.model, ts.parameters, ts.states,
ts.optimizer, ts.optimizer_state, ts.step)

return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
cache = TrainingBackendCache(ad, False(), ts.cache.dparameters, (; forward, reverse))
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{
:Enzyme, false, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}
<:AutoEnzyme, False, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}

function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F}
Expand All @@ -67,8 +67,7 @@ function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
Const(obj_fn), Const(ts.model), params, Const(ts.states), Const(data),
(one(loss), Lux.recursive_make_zero(st_), Lux.recursive_make_zero(stats)), tape)

ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, st_,
ts.optimizer, ts.optimizer_state, ts.step)

return dps, loss, stats, ts_new
@set! ts.objective_function = obj_fn
@set! ts.states = st_
return dps, loss, stats, ts
end
4 changes: 3 additions & 1 deletion ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ using ArrayInterface: ArrayInterface
using FunctionWrappers: FunctionWrapper
using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal,
@grad_from_chainrules
using Setfield: @set!
using Static: False, True

using Lux: Lux, Utils
using Lux.Training: TrainingBackendCache, TrainState
using Lux.Training: Training, TrainingBackendCache, TrainState
using LuxCore: LuxCore
using MLDataDevices: CPUDevice

Expand Down
73 changes: 32 additions & 41 deletions ext/LuxReverseDiffExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,50 @@
# Uncompiled ReverseDiff
function Lux.Training.compute_gradients(
ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, true}(grads, nothing), obj_fn, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters), nothing)
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{false}}}) where {F}
dparams = Training.dparameters(ts.cache)
tape = ReverseDiff.InstructionTape()
ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ts.parameters, dparams)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, nothing),
obj_fn, ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step)

return ts.cache.dparameters, ReverseDiff.value(loss), stats, ts_new
@set! ts.cache.first_try = False()
@set! ts.objective_function = obj_fn
@set! ts.states = st
return dparams, ReverseDiff.value(loss), stats, ts
end

# Compiled ReverseDiff
function Lux.Training.compute_gradients(
ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
data_cache = deepcopy(data)
ps_cache = deepcopy(ts.parameters)
extras = (; data_cache, ps_cache)

ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, true}(grads, extras), nothing, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return Lux.Training.compute_gradients(ad, obj_fn, data, ts_new)
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters),
(; data_cache=deepcopy(data), ps_cache=deepcopy(ts.parameters)))
@set! ts.objective_function = nothing

return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

## Tape hasn't been compiled yet / Function mismatch so recompile
function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, FT}}) where {F, FT}
function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}}) where {F}
if LuxCore.statelength(ts.states) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \
models with non-empty state `st`."))
end

if FT # do a dry run
first_try = ts.cache.first_try isa True

if first_try # do a dry run
_, st_, stats = obj_fn(ts.model, ts.parameters, ts.states, data)
if stats != NamedTuple()
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for \
Expand All @@ -59,20 +56,18 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data
end
end

dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
dparams = Training.dparameters(ts.cache)

(; ps_cache, data_cache) = ts.cache.extras
if !FT
if !first_try
Lux.recursive_copyto!(ps_cache, ts.parameters)
Lux.recursive_copyto!(data_cache, data)
end

obj_fn_wrap = first obj_fn

tape = ReverseDiff.InstructionTape()
ps_tracked = Lux.recursive_map(Utils.Fix3(TrackedArray, tape), ps_cache, dparams)

loss = obj_fn_wrap(ts.model, ps_tracked, ts.states, data_cache)
loss = first(obj_fn(ts.model, ps_tracked, ts.states, data_cache))
loss.deriv = true
ReverseDiff.reverse_pass!(tape)

Expand All @@ -81,18 +76,14 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data
reverse_executor = [FunctionWrapper{Nothing, Tuple{}}(ReverseExecutor(tape[i]))
for i in length(tape):-1:1]

compiled_extras = (;
ps_cache, data_cache, forward_executor, reverse_executor, output=loss)
ts_new = TrainState(
TrainingBackendCache{:ReverseDiff, false}(ts.cache.dparameters, compiled_extras),
obj_fn, ts.model, ts.parameters, ts.states,
ts.optimizer, ts.optimizer_state, ts.step)

return dparams, ReverseDiff.value(loss), NamedTuple(), ts_new
@set! ts.cache = TrainingBackendCache(ad, False(), dparams,
(; ps_cache, data_cache, forward_executor, reverse_executor, output=loss))
@set! ts.objective_function = obj_fn
return dparams, ReverseDiff.value(loss), NamedTuple(), ts
end

function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:ReverseDiff, false}, F}) where {F}
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}, F}) where {F}
(; ps_cache, data_cache, output) = ts.cache.extras

dparams = Lux.recursive_make_zero!!(ts.cache.dparameters)
Expand All @@ -107,7 +98,7 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data
wrapper()
end

ts_new = TrainState(ts.cache, obj_fn, ts.model, ts.parameters, ts.states,
ts.optimizer, ts.optimizer_state, ts.step)
return dparams, ReverseDiff.value(output), NamedTuple(), ts_new
@set! ts.cache.first_try = False()
@set! ts.objective_function = obj_fn
return dparams, ReverseDiff.value(output), NamedTuple(), ts
end
4 changes: 3 additions & 1 deletion ext/LuxTrackerExt/LuxTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ module LuxTrackerExt
using ADTypes: AbstractADType, AutoTracker
using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore
using Setfield: @set!
using Static: False, True
using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules

using Lux: Lux, Utils
using Lux.Training: TrainingBackendCache, TrainState
using Lux.Training: Training, TrainingBackendCache, TrainState

const CRC = ChainRulesCore

Expand Down
24 changes: 12 additions & 12 deletions ext/LuxTrackerExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT}
dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters)
ps_tracked = construct_tracked_params(ts.parameters, dparams)
ts::TrainState{<:TrainingBackendCache{AutoTracker}}) where {F}
dps = Training.dparameters(ts.cache)
ps_tracked = construct_tracked_params(ts.parameters, dps)

loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data)
Tracker.back!(loss)

ts_new = TrainState(
TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing), obj_fn,
ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step)
@set! ts.cache.first_try = False()
@set! ts.objective_function = obj_fn
@set! ts.states = st

return dparams, loss.data, stats, ts_new
return dps, loss.data, stats, ts
end

function Lux.Training.compute_gradients(
::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
ad::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
ts_new = TrainState(
TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, ts.model,
ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step)
return Lux.Training.compute_gradients(AutoTracker(), obj_fn, data, ts_new)
cache = TrainingBackendCache(ad, True(), grads, nothing)
@set! ts.cache = cache
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end
8 changes: 1 addition & 7 deletions src/helpers/recursive_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,12 @@ function recursive_map(f::F, x::AbstractArray{T}, args...) where {F, T}
(T <: Number || isbitstype(T)) && return f(x, args...) # Not all Number types (BigFloat) are bitstype
return f.(x, args...)
end
function recursive_map(f::F, x::Tuple, args...) where {F}
function recursive_map(f::F, x::Union{NamedTuple, Tuple}, args...) where {F}
map_fn = let f = f
(args_...) -> recursive_map(f, args_...)
end
return map(map_fn, x, args...)
end
function recursive_map(f::F, x::NamedTuple{fields}, args...) where {F, fields}
map_fn = let f = f
(args_...) -> recursive_map(f, args_...)
end
return NamedTuple{fields}(map(map_fn, values(x), values.(args)...))
end
recursive_map(f::F, x, args...) where {F} = fmap(f, x, args...)

@compat(public,
Expand Down
23 changes: 13 additions & 10 deletions src/helpers/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ConcreteStructs: @concrete
using FastClosures: @closure
using Optimisers: Optimisers
using Setfield: @set!
using Static: StaticBool, Static, False, True

using ..Lux: Lux
using LuxCore: LuxCore, AbstractLuxLayer
Expand Down Expand Up @@ -50,13 +51,10 @@ Constructor for [`TrainState`](@ref).
## Arguments
- `rng`: Random Number Generator.
- `ps`: Parameters of the model.
- `st`: States of the model.
- `model`: `Lux` model.
- `optimizer`: Optimizer from `Optimisers.jl`.
- `transform_variables`: Function to transform the variables of the model. Typically used
to transfer variables to GPU / CPU.
## Returns
Expand All @@ -67,12 +65,18 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
end

@concrete struct TrainingBackendCache{backend, first_try}
@concrete struct TrainingBackendCache
backend
first_try <: StaticBool
dparameters
extras
end

training_backend(::TrainingBackendCache{backend}) where {backend} = backend
dparameters(cache::TrainingBackendCache) = dparameters(cache, cache.first_try)
function dparameters(cache::TrainingBackendCache, ::False)
return Lux.recursive_make_zero!!(cache.dparameters)
end
dparameters(cache::TrainingBackendCache, ::True) = cache.dparameters

function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
println(io, "TrainState")
Expand All @@ -83,8 +87,7 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
print(io, " step: ", ts.step)
if ts.cache !== nothing
if ts.cache isa TrainingBackendCache
print(io,
"\n cache: $(nameof(typeof(ts.cache))){$(training_backend(ts.cache))}")
print(io, "\n cache: $(nameof(typeof(ts.cache)))($(ts.cache.backend))")
else
print(io, "\n cache: $(nameof(typeof(ts.cache)))")
end
Expand Down Expand Up @@ -198,7 +201,7 @@ for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
end
end

function generate_wrappers(::F, m, ps, st, data, ::Val{false}) where {F}
function generate_wrappers(::F, m, ps, st, data, ::False) where {F}
@warn "Detected function wrapper generation with function being updated between calls. \
This will generate type-unstable code. A possible reason for this is \
`TrainState` was compiled (first call to `compute_gradients`) with function \
Expand All @@ -208,13 +211,13 @@ function generate_wrappers(::F, m, ps, st, data, ::Val{false}) where {F}
end

# Run the code when trying to compile the function for the first time.
function generate_wrappers(objective_function::F, m, ps, st, data, ::Val{true}) where {F}
function generate_wrappers(objective_function::F, m, ps, st, data, ::True) where {F}
_, stₙ, statsₙ = objective_function(m, ps, st, data)
return Ref{typeof(stₙ)}(stₙ), Ref{typeof(statsₙ)}(statsₙ)
end

function wrap_objective_function(
objective_function::F, m, ps, st, data, first_try::Val) where {F}
objective_function::F, m, ps, st, data, first_try::StaticBool) where {F}
st_updated, stats = generate_wrappers(objective_function, m, ps, st, data, first_try)

wrapped_objective_function = @closure (model, ps, st, data) -> begin
Expand Down
Loading

3 comments on commit f60db4d

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/115175

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text
"Release notes:", and it will be added to the registry PR. If TagBot is installed, it will also be added to the
release that TagBot creates, e.g.:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed; otherwise, it can be done manually through the GitHub interface, or via:

git tag -a v1.0.3 -m "<description of version>" f60db4d929c708f839f72a16611fabe711de7c7c
git push origin v1.0.3

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: f60db4d Previous: 0b51676 Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 414500 ns 411125 ns 1.01
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 322250 ns 322750 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 322708.5 ns 244083 ns 1.32
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 741958 ns 740229 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 44250.5 ns 43576 ns 1.02
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 1327167 ns 1361688 ns 0.97
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 2451688 ns 2448167 ns 1.00
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 14209750 ns 16505500 ns 0.86
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 2193937.5 ns 2198042 ns 1.00
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 207380 ns 207361 ns 1.00
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 1468292 ns 1419479 ns 1.03
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 923959 ns 931729 ns 0.99
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 1598937.5 ns 1582917 ns 1.01
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 2242395.5 ns 2213229 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1762396 ns 1768708 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1028250 ns 1072541.5 ns 0.96
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1537583 ns 1542417 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 2885833.5 ns 3010167 ns 0.96
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 208790 ns 208923 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12117833 ns 12164458 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 8811750 ns 8831167 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9165333.5 ns 9231125 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18605125 ns 18575542 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1497201 ns 1506706 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17314916 ns 17297875 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 13952000 ns 13966709 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14449937 ns 14490229 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21832333 ns 21825958 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250356604.5 ns 250077771 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148503729 ns 148351292 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115663250 ns 116742208 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 452727834 ns 446235042 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5471701 ns 5474148 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1224679334 ns 1226735000 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 932428750 ns 933099541 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 831047479.5 ns 833488083 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1654023458 ns 1628798917 ns 1.02
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 31662494 ns 31247743 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1141591625 ns 1139513458 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1004360417 ns 1004012958 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1322994750 ns 1343460771 ns 0.98
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1741933375 ns 1729098333 ns 1.01
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 1120833.5 ns 1084187.5 ns 1.03
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 1620917 ns 1632875 ns 0.99
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3462083 ns 3807833 ns 0.91
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 779667 ns 781500 ns 1.00
lenet(28, 28, 1, 32)/forward/GPU/CUDA 270336.5 ns 269181 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 2988271 ns 2973917 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 4139875 ns 4123458 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 9659916 ns 11391021 ns 0.85
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3132834 ns 3140229.5 ns 1.00
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1134352.5 ns 1147789 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2338166 ns 2327458.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1437021 ns 1427875 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1669291 ns 1552208 ns 1.08
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 4193000 ns 4203041 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 210459.5 ns 209123 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 19441042 ns 19423562 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 16082770.5 ns 16279416 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 17400416.5 ns 17361812 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 25866000 ns 25815125 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1593435 ns 1606839 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 34177125 ns 34524104 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 30976000 ns 31057875 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 31151000 ns 31105416 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 36261000 ns 36883875 ns 0.98
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4537333 ns 4526208.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2776604 ns 2777083.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2913645.5 ns 2685312.5 ns 1.09
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 8378750 ns 8381562.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 420670 ns 373639 ns 1.13
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 38891374.5 ns 38887521 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 32306292 ns 32509584 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 32384208 ns 32333229 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 51948083 ns 51833125 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2620746.5 ns 2633953 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 88847729 ns 88607687.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 114070333.5 ns 113743125 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 226493250 ns 227726583 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 73885250 ns 74951083 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 268317334 ns 267716166 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 159216084 ns 159256375 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 127078708 ns 123708895.5 ns 1.03
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 492762417 ns 485091625 ns 1.02
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 6963353 ns 7022924 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1469208062.5 ns 1478680979 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 1179701333 ns 1179547083 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1064469187.5 ns 1066054563 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 2018298416.5 ns 2001889209 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 34585385 ns 34822377.5 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1726168042 ns 1724298291 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1532131312.5 ns 1565497271 ns 0.98
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1753217833 ns 1925114250 ns 0.91
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 2220540250 ns 2239111625 ns 0.99
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 2032250 ns 2028500 ns 1.00
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 2850166.5 ns 2967646 ns 0.96
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 7482625 ns 8104667 ns 0.92
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2429979 ns 2308041.5 ns 1.05
lenet(28, 28, 1, 128)/forward/GPU/CUDA 267353.5 ns 272667 ns 0.98
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 9603854 ns 9619395.5 ns 1.00
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 11874437.5 ns 12015166 ns 0.99
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 24867021 ns 26324292 ns 0.94
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 11308542 ns 11677541 ns 0.97
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1173785 ns 1188628.5 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 380634584 ns 383215354.5 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 287745375 ns 284366604.5 ns 1.01
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 243501229 ns 261725395.5 ns 0.93
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 452284375.5 ns 453056042 ns 1.00
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 5016811.5 ns 5009701 ns 1.00
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 1137459875 ns 1160384584 ns 0.98
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 943993333 ns 912166042 ns 1.03
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 898262625 ns 984922208 ns 0.91
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 1411909416 ns 1396092167 ns 1.01
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 18115193 ns 18111984 ns 1.00
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1060437 ns 1053833 ns 1.01
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 2017041.5 ns 1605958 ns 1.26
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 5113542 ns 5411083 ns 0.95
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1366833 ns 1296875 ns 1.05
lenet(28, 28, 1, 64)/forward/GPU/CUDA 265207 ns 265721 ns 1.00
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6505083 ns 6510958 ns 1.00
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 12271187.5 ns 13082584 ns 0.94
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 18806687.5 ns 21760833.5 ns 0.86
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 6078250 ns 5984375 ns 1.02
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1214045 ns 1208949 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70581646 ns 70494333 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43485459 ns 43641125 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39436292 ns 39690584 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132675958 ns 133468354 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1863920 ns 1945255.5 ns 0.96
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 355687833.5 ns 356723479.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 270693083.5 ns 271306709 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 254405500.5 ns 254269771 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 538777458 ns 536238459 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 12367452 ns 12301288 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 396200000 ns 395599834 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 402727854 ns 377440167 ns 1.07
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 668679417 ns 697289229.5 ns 0.96
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 708861625 ns 708495833 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 1187349792 ns 1188885083 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 694829104 ns 692916625 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 629932709 ns 642915416.5 ns 0.98
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 1779143271 ns 1776695937.5 ns 1.00
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 13225818 ns 12306515 ns 1.07
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 3622108083.5 ns 3668882667 ns 0.99
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 2828172709 ns 2834396125 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2724737708 ns 2699395792 ns 1.01
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 5083300000 ns 5050853166 ns 1.01
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 49807086.5 ns 49852240.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3420729.5 ns 3422958 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2074875 ns 2075583 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2525042 ns 2513666 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6011833 ns 6018396 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 315086 ns 317455.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 26295500 ns 26048666 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 18987458 ns 19094062.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19862667 ns 19316000 ns 1.03
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 39218853.5 ns 39190562.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2478386 ns 2466381 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 55626729.5 ns 55369583 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 81917708 ns 82210395.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 172510354 ns 173994812.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 45569417 ns 45354333 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1782395.5 ns 1779187.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1093791.5 ns 1097834 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1586291.5 ns 1568791 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3026979 ns 3021312 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 213440.5 ns 210623 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12557083 ns 12543916 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9205917 ns 9277708.5 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9717709 ns 9594229.5 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18945396 ns 18987604.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1545222 ns 1527868.5 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17667958 ns 17650708 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14312292 ns 14335458 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14670667 ns 14544250 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 22150709 ns 22174250 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70496583.5 ns 70431125 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43541375 ns 43537125 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39470417 ns 39620583 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132760312.5 ns 132531916.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1958343 ns 1888879 ns 1.04
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 358409083 ns 360439083.5 ns 0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 346583313 ns 347132666.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 304589375 ns 304637542 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 725990125 ns 722631792 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13320357 ns 13304668 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 418971104 ns 419234750 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 419729042 ns 421465729 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 662505333 ns 724319500 ns 0.91
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 715138292 ns 714217917 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1450437 ns 1705416 ns 0.85
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1298979 ns 1350333.5 ns 0.96
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1344645.5 ns 1170667 ns 1.15
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2365917 ns 2385333.5 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 590150.5 ns 580442.5 ns 1.02
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 8684833 ns 8948271 ns 0.97
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 12890000 ns 12980437.5 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 30836166.5 ns 32353312.5 ns 0.95
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 9843750 ns 9804417 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1473920 ns 1427987.5 ns 1.03
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 17999292 ns 17962354 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 16546208 ns 17440000 ns 0.95
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 29181291 ns 29738291 ns 0.98
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 14097584 ns 14431937.5 ns 0.98
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 693250 ns 669833.5 ns 1.03
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 521417 ns 529250 ns 0.99
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1040750 ns 1065708.5 ns 0.98
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 724875 ns 725395.5 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 48072 ns 47647 ns 1.01
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 1566292 ns 1549104 ns 1.01
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 1002937.5 ns 1038917 ns 0.97
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1370333.5 ns 1517584 ns 0.90
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 2257250 ns 2269896 ns 0.99
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 238196.5 ns 233022 ns 1.02
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 1571020.5 ns 1582916 ns 0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 1080916 ns 1087854.5 ns 0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1541833 ns 1464166 ns 1.05
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 2236209 ns 2190854 ns 1.02
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3399875 ns 3413625 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2047875 ns 2047083 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2515021 ns 2507333.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6005375 ns 6011813 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 286172.5 ns 284231.5 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 24087042 ns 24149000 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 17224041.5 ns 17330312.5 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17292291 ns 17059271 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 37522062.5 ns 37480499.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2407498 ns 2394265 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 53768270.5 ns 53573937.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 83654187.5 ns 83649500 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 169263021 ns 172928458 ns 0.98
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 44565333.5 ns 44425187.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250492042 ns 249999250 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148428250 ns 148223583 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115397479.5 ns 116384896 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 450610604 ns 447335937.5 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5443833 ns 5449146 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1101924667 ns 1105347792 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 855192187.5 ns 857822708.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 827218333.5 ns 830398396 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1763706625 ns 1762030583 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 29367206 ns 28862807 ns 1.02
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1019223979 ns 1020245354 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 945177042 ns 966178875 ns 0.98
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1303173167 ns 1293466208 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1739257541.5 ns 1724193375.5 ns 1.01
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1211708 ns 1306896.5 ns 0.93
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 981875 ns 984292 ns 1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 948167 ns 778437.5 ns 1.22
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 2062875 ns 1958750 ns 1.05
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 569657 ns 566426 ns 1.01
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 5819083.5 ns 6042375 ns 0.96
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 4699250 ns 6715125 ns 0.70
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 24610750.5 ns 26872708 ns 0.92
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 7096333 ns 6973417 ns 1.02
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1369164.5 ns 1365853 ns 1.00
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 11390750 ns 11215770.5 ns 1.02
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 9112562.5 ns 10033208 ns 0.91
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 17263667 ns 17672208 ns 0.98
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 8694666.5 ns 8568500 ns 1.01
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 384000 ns 399500 ns 0.96
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 364688 ns 399291.5 ns 0.91
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 2302437.5 ns 3544167 ns 0.65
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 89750 ns 88459 ns 1.01
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 27591.5 ns 27618 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 391125 ns 397459 ns 0.98
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 382584 ns 445041.5 ns 0.86
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4380375 ns 4819375 ns 0.91
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 258417 ns 259833 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 220859 ns 219889.5 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 421604 ns 428313 ns 0.98
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 411750 ns 475541 ns 0.87
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4491917 ns 4960437.5 ns 0.91
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 271250 ns 271333 ns 1.00
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 329896 ns 343709 ns 0.96
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 300084 ns 333937.5 ns 0.90
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 750333 ns 769833 ns 0.97
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 54375 ns 53125 ns 1.02
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 27841 ns 28016 ns 0.99
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 355792 ns 362209 ns 0.98
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 247167 ns 342792 ns 0.72
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 868125 ns 897833 ns 0.97
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 151750 ns 152583 ns 0.99
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 205968 ns 205326.5 ns 1.00
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 368375 ns 378500 ns 0.97
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 261709 ns 358042 ns 0.73
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 714208 ns 728708 ns 0.98
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 151125 ns 150833.5 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 601673542 ns 603479208 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 433401687 ns 429058104 ns 1.01
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 378552750 ns 385950542 ns 0.98
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 874120625 ns 872372584 ns 1.00
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7030592 ns 7023071 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 2007087354.5 ns 2010730958 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 1632009874.5 ns 1608264687.5 ns 1.01
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1618542583.5 ns 1653085833 ns 0.98
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 2637429416 ns 2638084625 ns 1.00
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 26054721.5 ns 25932761 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 523500 ns 535250 ns 0.98
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 435895.5 ns 433291.5 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 1828249.5 ns 3023791.5 ns 0.60
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 866354 ns 880791 ns 0.98
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 47636 ns 46986 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1763270.5 ns 1881604 ns 0.94
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 2797458.5 ns 2798729 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 14370145.5 ns 16356750 ns 0.88
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 2769562.5 ns 2759229 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 248789.5 ns 246659.5 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 1945916.5 ns 1962958.5 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 5043500 ns 5070604 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 14572416 ns 16396875 ns 0.89
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 2785979.5 ns 2785625.5 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1374375 ns 1614125 ns 0.85
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 1189542 ns 1235583 ns 0.96
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 1224645.5 ns 1027208 ns 1.19
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2299000 ns 2300875 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 583268.5 ns 587018.5 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 5918791 ns 5921542 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 7147000 ns 5089688 ns 1.40
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 24359584 ns 26372271 ns 0.92
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 7320208 ns 7288250 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1348690.5 ns 1379747.5 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 13093542 ns 13324958 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 12017167 ns 12237645.5 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 20888000 ns 21281499.5 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 10214417 ns 10668750 ns 0.96
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 2375 ns 4417 ns 0.54
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 2500 ns 2583.5 ns 0.97
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 3333.5 ns 2750 ns 1.21
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 2958 ns 2500 ns 1.18
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 24628 ns 24754 ns 0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7291.5 ns 7459 ns 0.98
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 7083 ns 7250 ns 0.98
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7333.5 ns 7333 ns 1.00
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 7083 ns 7083 ns 1
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 209898.5 ns 213008 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8250 ns 8375 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8208 ns 8583 ns 0.96
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8375 ns 8459 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 5958 ns 5834 ns 1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 10458 ns 10625 ns 0.98
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 12937.5 ns 13708 ns 0.94
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 10708 ns 12042 ns 0.89
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 7250 ns 7500 ns 0.97
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 24907 ns 25091.5 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 19875 ns 20250 ns 0.98
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 20104.5 ns 19959 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 20125 ns 20083 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 20000 ns 19875 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 230594 ns 231793 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 23583.5 ns 23625 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 23708 ns 23667 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 23625 ns 23666 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 21333 ns 21084 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 28459 ns 28708 ns 0.99
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28542 ns 29292 ns 0.97
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 28770.5 ns 28375 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 45917 ns 46584 ns 0.99
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 25803 ns 26247 ns 0.98
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 230250 ns 222250 ns 1.04
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 288166 ns 279729.5 ns 1.03
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 4212042 ns 4335396.5 ns 0.97
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 145000 ns 145208 ns 1.00
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 207914 ns 203061 ns 1.02
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 342187.5 ns 333124.5 ns 1.03
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 333166 ns 322500 ns 1.03
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 411895.5 ns 861333 ns 0.48
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 160646 ns 160750 ns 1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 1750 ns 1875 ns 0.93
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 1791 ns 1958 ns 0.91
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2250 ns 2416 ns 0.93
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 1958 ns 1792 ns 1.09
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 23251.5 ns 23061 ns 1.01
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5208 ns 5458 ns 0.95
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 5208 ns 5500 ns 0.95
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5500 ns 5375 ns 1.02
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 5291 ns 5375 ns 0.98
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 245332 ns 243257 ns 1.01
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 11291.5 ns 11333.5 ns 1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 11375 ns 11208 ns 1.01
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 11458 ns 11667 ns 0.98
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 6959 ns 6833 ns 1.02
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 79898667 ns 79834791 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 49104563 ns 49125291 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 44920792 ns 43259375 ns 1.04
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 151542042 ns 151428917 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2713787 ns 2726005 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 665144875 ns 498680292 ns 1.33
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 414328875 ns 414152083 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 399605708 ns 396991709 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 687317792 ns 689086500 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 14579874 ns 14585553 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 718439500 ns 712438146 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 685447833 ns 683887166 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 1000305625 ns 1013847083 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 992652792 ns 999589459 ns 0.99

This comment was automatically generated by a workflow using github-action-benchmark.

Please sign in to comment.