
make_zero! doesn't apply inside a compiled gradient #1503


Description

@yolhan83

Hello, I found that `Enzyme.make_zero!` seems to have no effect when it runs inside a function compiled with Reactant's `@compile`: the shadow is not reset between calls.
MWE:

using Lux,Reactant,Enzyme,Random,BenchmarkTools

Reactant.set_default_backend("cpu")
dev = reactant_device()
T = f64

rng = Random.default_rng(123)
x = randn(rng,Float64,1,10_000) |> T |> dev
y = copy(x) |> T |> dev

model = Lux.Chain(Lux.Dense(1 => 64,tanh), Lux.Dense(64 => 64,tanh), Lux.Dense(64 => 1))
ps,st = Lux.setup(rng,model) |> T |> dev
dps = Enzyme.make_zero(ps)

loss(m,p,s,x,y) = sum(abs2,first(m(x,p,s)).-y)
# Reset the shadow, then accumulate the fresh gradient into it.
function get_grad!(loss, m, p, dp, s, x, y)
    Enzyme.make_zero!(dp)
    Enzyme.autodiff(Reverse, loss, Const(m), Duplicated(p, dp), Const(s), Const(x), Const(y))
    nothing
end

gg = @compile get_grad!(loss,model,ps,dps,st,x,y)

gg(loss,model,ps,dps,st,x,y)
# dps contains correct gradient
display(cpu_device()(dps))
gg(loss,model,ps,dps,st,x,y)
# incorrect gradient: every entry is exactly twice the first result,
# as if make_zero! never ran and the gradient accumulated
display(cpu_device()(dps))
(layer_1 = (weight = [137.0858704640575; -1021.8445090071898; … ; 746.9016594797259; 179.72655660246477;;], bias = [7.484189012981482, -415.9037422493918, 9.471926343647242, 216.15521090737536, 503.5419781680868, -504.6268818558066, 49.317270824727814, -2.1342143301106926, -139.0897491562625, 212.59612679051307  …  1527.9539957985828, 119.0948372726741, -51.61685419902613, -491.54165964228457, 339.9507691301908, -69.82426021393093, 324.3615174855332, 404.98770950839855, 585.5644808977161, 731.8103712227135]), layer_2 = (weight = [259.37017898922124 -108.52021399697412 … -184.07311532009467 152.27354740599554; 318.2573355289224 -39.648447146359025 … -425.35583881506955 466.08894941449313; … ; 292.48760196218024 -144.06796355192407 … -159.2229753535182 96.77536085154088; -51.28136438802544 20.322924534536813 … 41.281281718339 -38.514000446431254], bias = [232.21741909140164, 7.937352807926769, 315.0543310051971, -42.36576845990219, 397.108024016336, 548.4611619151846, 240.48440978442602, -108.47823497572827, -128.74223790331826, -149.71337592769888  …  -605.5741889872332, 284.28406704843104, -753.9593292935859, 573.0296460485242, 41.191071823372994, 452.12173178015473, -695.4886947321189, -131.16455425947234, 325.4815507498512, -42.27691266606935]), layer_3 = (weight = [-12037.767910215289 -10440.564006231758 … -12724.42611029213 -11577.164681542112], bias = [-4951.698200238455]))
(layer_1 = (weight = [274.171740928115; -2043.6890180143796; … ; 1493.8033189594519; 359.45311320492954;;], bias = [14.968378025962965, -831.8074844987837, 18.943852687294484, 432.3104218147507, 1007.0839563361736, -1009.2537637116131, 98.63454164945563, -4.268428660221385, -278.179498312525, 425.19225358102614  …  3055.9079915971656, 238.1896745453482, -103.23370839805226, -983.0833192845691, 679.9015382603816, -139.64852042786185, 648.7230349710665, 809.9754190167971, 1171.1289617954321, 1463.620742445427]), layer_2 = (weight = [518.7403579784425 -217.04042799394824 … -368.14623064018934 304.5470948119911; 636.5146710578448 -79.29689429271805 … -850.7116776301391 932.1778988289863; … ; 584.9752039243605 -288.13592710384813 … -318.4459507070364 193.55072170308176; -102.56272877605088 40.64584906907363 … 82.562563436678 -77.02800089286251], bias = [464.4348381828033, 15.874705615853538, 630.1086620103943, -84.73153691980438, 794.216048032672, 1096.9223238303691, 480.96881956885204, -216.95646995145654, -257.4844758066365, -299.42675185539775  …  -1211.1483779744665, 568.5681340968621, -1507.9186585871719, 1146.0592920970485, 82.38214364674599, 904.2434635603095, -1390.9773894642378, -262.3291085189447, 650.9631014997024, -84.5538253321387]), layer_3 = (weight = [-24075.535820430578 -20881.128012463516 … -25448.85222058426 -23154.329363084224], bias = [-9903.39640047691]))
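For context, Enzyme accumulates into the shadow of a `Duplicated` argument rather than overwriting it, which is why skipping the reset yields exactly twice the first gradient on the second call. A minimal sketch of that semantics with plain Enzyme (no Reactant involved, hypothetical toy loss):

```julia
using Enzyme

f(x) = sum(abs2, x)   # toy loss; gradient is 2x
x  = [1.0, 2.0]
dx = Enzyme.make_zero(x)

Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
# dx == [2.0, 4.0]
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
# dx == [4.0, 8.0]: accumulated, not overwritten
Enzyme.make_zero!(dx) # eager reset restores a clean shadow
```

So the doubled output above is consistent with the `make_zero!` call being dropped (or traced into a no-op) during compilation, while the `autodiff` accumulation still runs.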

Any idea why?
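One possible workaround sketch, under the untested assumption that `Enzyme.make_zero!` works eagerly on the device arrays when called outside the traced region: reset the shadow before each compiled call instead of inside the compiled function.

```julia
# Hypothetical workaround (same setup as the MWE above):
# zero the shadow outside the compiled function, so the reset
# executes eagerly instead of being traced away by @compile.
Enzyme.make_zero!(dps)            # eager reset, outside the compiled call
gg(loss, model, ps, dps, st, x, y)
display(cpu_device()(dps))
```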
