Skip to content

Commit ab9c91a

Browse files
authored
Absint typeof ijl_new_structt (EnzymeAD#1206)
* Absint typeof ijl_new_structt * add needsshadow check for jl_nthfield_rev
1 parent 770b064 commit ab9c91a

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/absint.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ
146146
return absint(operands(arg)[1], partial)
147147
end
148148

149+
if nm == "jl_new_structt" || nm == "ijl_new_structt"
150+
return absint(operands(arg)[1], partial)
151+
end
149152

150153
if LLVM.callconv(arg) == 37 || nm == "julia.call"
151154
index = 1
@@ -234,4 +237,4 @@ function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String}
234237
end
235238
end
236239
return (false, "")
237-
end
240+
end

src/rules/typeunstablerules.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,16 @@ function jl_nthfield_rev(B, orig, gutils, tape)
480480
return
481481
end
482482

483+
needsShadowP = Ref{UInt8}(0)
484+
needsPrimalP = Ref{UInt8}(0)
485+
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal)
486+
needsPrimal = needsPrimalP[] != 0
487+
needsShadow = needsShadowP[] != 0
488+
489+
if !needsShadow
490+
return
491+
end
492+
483493
ops = collect(operands(orig))
484494
width = get_width(gutils)
485495

0 commit comments

Comments
 (0)