-
Notifications
You must be signed in to change notification settings - Fork 25
Closed
Description
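Hello, I found that `Enzyme.make_zero!` seems to have no effect after compilation with Reactant: the gradient buffer is not reset between calls, so the second call returns doubled gradients.
MWE: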
using Lux,Reactant,Enzyme,Random,BenchmarkTools
Reactant.set_default_backend("cpu")
dev = reactant_device()
T = f64  # converter applied to data and parameters before moving them to `dev`

# `Random.default_rng(123)` does not seed the generator (the integer argument is
# not a seed), so seed explicitly to make this reproduction deterministic.
rng = Random.default_rng()
Random.seed!(rng, 123)

# Inputs and targets: a 1 x 10_000 batch; the target is a copy of the input.
x = randn(rng, Float64, 1, 10_000) |> T |> dev
y = copy(x) |> T |> dev

model = Lux.Chain(
    Lux.Dense(1 => 64, tanh),
    Lux.Dense(64 => 64, tanh),
    Lux.Dense(64 => 1),
)
ps, st = Lux.setup(rng, model) |> T |> dev

# Shadow storage for the gradient, built from `ps` by `Enzyme.make_zero`.
dps = Enzyme.make_zero(ps)

# Sum of squared differences between the first output of `m(x, p, s)` and `y`.
loss(m, p, s, x, y) = sum(abs2, first(m(x, p, s)) .- y)
# Reset the shadow `dp` in place, then run a reverse-mode `Enzyme.autodiff` pass
# over `loss` with `p`/`dp` as the only `Duplicated` argument; every other
# argument is passed as `Const`. Always returns `nothing`; the result is expected
# to be read from `dp` by the caller.
function get_grad!(loss, m, p, dp, s, x, y)
    Enzyme.make_zero!(dp)
    # Pair the parameters with their shadow only after the shadow has been reset.
    params = Duplicated(p, dp)
    Enzyme.autodiff(Reverse, loss, Const(m), params, Const(s), Const(x), Const(y))
    return nothing
end
# Compile `get_grad!` once for these arguments, then invoke the compiled function twice.
gg = @compile get_grad!(loss, model, ps, dps, st, x, y)
to_cpu = cpu_device()

# First call: `dps` contains the correct gradient.
gg(loss, model, ps, dps, st, x, y)
dps |> to_cpu |> display

# Second call with the same inputs: the gradient in `dps` is incorrect.
gg(loss, model, ps, dps, st, x, y)
dps |> to_cpu |> display
(layer_1 = (weight = [137.0858704640575; -1021.8445090071898; … ; 746.9016594797259; 179.72655660246477;;], bias = [7.484189012981482, -415.9037422493918, 9.471926343647242, 216.15521090737536, 503.5419781680868, -504.6268818558066, 49.317270824727814, -2.1342143301106926, -139.0897491562625, 212.59612679051307 … 1527.9539957985828, 119.0948372726741, -51.61685419902613, -491.54165964228457, 339.9507691301908, -69.82426021393093, 324.3615174855332, 404.98770950839855, 585.5644808977161, 731.8103712227135]), layer_2 = (weight = [259.37017898922124 -108.52021399697412 … -184.07311532009467 152.27354740599554; 318.2573355289224 -39.648447146359025 … -425.35583881506955 466.08894941449313; … ; 292.48760196218024 -144.06796355192407 … -159.2229753535182 96.77536085154088; -51.28136438802544 20.322924534536813 … 41.281281718339 -38.514000446431254], bias = [232.21741909140164, 7.937352807926769, 315.0543310051971, -42.36576845990219, 397.108024016336, 548.4611619151846, 240.48440978442602, -108.47823497572827, -128.74223790331826, -149.71337592769888 … -605.5741889872332, 284.28406704843104, -753.9593292935859, 573.0296460485242, 41.191071823372994, 452.12173178015473, -695.4886947321189, -131.16455425947234, 325.4815507498512, -42.27691266606935]), layer_3 = (weight = [-12037.767910215289 -10440.564006231758 … -12724.42611029213 -11577.164681542112], bias = [-4951.698200238455]))
(layer_1 = (weight = [274.171740928115; -2043.6890180143796; … ; 1493.8033189594519; 359.45311320492954;;], bias = [14.968378025962965, -831.8074844987837, 18.943852687294484, 432.3104218147507, 1007.0839563361736, -1009.2537637116131, 98.63454164945563, -4.268428660221385, -278.179498312525, 425.19225358102614 … 3055.9079915971656, 238.1896745453482, -103.23370839805226, -983.0833192845691, 679.9015382603816, -139.64852042786185, 648.7230349710665, 809.9754190167971, 1171.1289617954321, 1463.620742445427]), layer_2 = (weight = [518.7403579784425 -217.04042799394824 … -368.14623064018934 304.5470948119911; 636.5146710578448 -79.29689429271805 … -850.7116776301391 932.1778988289863; … ; 584.9752039243605 -288.13592710384813 … -318.4459507070364 193.55072170308176; -102.56272877605088 40.64584906907363 … 82.562563436678 -77.02800089286251], bias = [464.4348381828033, 15.874705615853538, 630.1086620103943, -84.73153691980438, 794.216048032672, 1096.9223238303691, 480.96881956885204, -216.95646995145654, -257.4844758066365, -299.42675185539775 … -1211.1483779744665, 568.5681340968621, -1507.9186585871719, 1146.0592920970485, 82.38214364674599, 904.2434635603095, -1390.9773894642378, -262.3291085189447, 650.9631014997024, -84.5538253321387]), layer_3 = (weight = [-24075.535820430578 -20881.128012463516 … -25448.85222058426 -23154.329363084224], bias = [-9903.39640047691]))
Any idea why?
Metadata
Metadata
Assignees
Labels
No labels