Skip to content

Commit 34b6e1e

Browse files
committed
fix: temporarily avoid tracing in softmax and logsoftmax
1 parent bcdfd46 commit 34b6e1e

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,33 @@ end
2323

2424
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
2525
max_ = NNlib.fast_maximum(x; dims)
26-
zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
27-
one_num = Reactant.promote_to(TracedRNumber{T}, 1)
28-
@trace if all(isfinite, max_)
29-
@. out = exp(x - max_)
30-
else
31-
cond = max_ .== Inf
32-
true_pred = ifelse.(x .== Inf, one_num, zero_num)
33-
@. out = ifelse(cond, true_pred, exp(x - max_))
34-
end
26+
# XXX: Once reverse mode of if is properly supported, we can make it @trace
27+
# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
28+
# one_num = Reactant.promote_to(TracedRNumber{T}, 1)
29+
# @trace if all(isfinite, max_)
30+
@. out = exp(x - max_)
31+
# else
32+
# cond = max_ .== Inf
33+
# true_pred = ifelse.(x .== Inf, one_num, zero_num)
34+
# @. out = ifelse(cond, true_pred, exp(x - max_))
35+
# end
3536
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
3637
out ./= tmp
3738
return out
3839
end
3940

4041
function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
4142
max_ = NNlib.fast_maximum(x; dims)
42-
inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
43-
zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
44-
@trace if all(isfinite, max_)
45-
@. out = x - max_
46-
else
47-
cond = max_ .== Inf
48-
true_pred = ifelse.(x .== Inf, zero_num, -inf_num)
49-
@. out = ifelse(cond, true_pred, x - max_)
50-
end
43+
# XXX: Once reverse mode of if is properly supported, we can make it @trace
44+
# inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
45+
# zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
46+
# @trace if all(isfinite, max_)
47+
@. out = x - max_
48+
# else
49+
# cond = max_ .== Inf
50+
# true_pred = ifelse.(x .== Inf, zero_num, -inf_num)
51+
# @. out = ifelse(cond, true_pred, x - max_)
52+
# end
5153
@fastmath log_ = log.(sum(exp, out; dims))
5254
out .-= log_
5355
return out

0 commit comments

Comments
 (0)