|
23 | 23 |
|
24 | 24 | function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
|
25 | 25 | max_ = NNlib.fast_maximum(x; dims)
|
26 |
| - zero_num = Reactant.promote_to(TracedRNumber{T}, 0) |
27 |
| - one_num = Reactant.promote_to(TracedRNumber{T}, 1) |
28 |
| - @trace if all(isfinite, max_) |
29 |
| - @. out = exp(x - max_) |
30 |
| - else |
31 |
| - cond = max_ .== Inf |
32 |
| - true_pred = ifelse.(x .== Inf, one_num, zero_num) |
33 |
| - @. out = ifelse(cond, true_pred, exp(x - max_)) |
34 |
| - end |
| 26 | + # XXX: Once reverse mode of if is properly supported, we can make it @trace |
| 27 | + # zero_num = Reactant.promote_to(TracedRNumber{T}, 0) |
| 28 | + # one_num = Reactant.promote_to(TracedRNumber{T}, 1) |
| 29 | + # @trace if all(isfinite, max_) |
| 30 | + @. out = exp(x - max_) |
| 31 | + # else |
| 32 | + # cond = max_ .== Inf |
| 33 | + # true_pred = ifelse.(x .== Inf, one_num, zero_num) |
| 34 | + # @. out = ifelse(cond, true_pred, exp(x - max_)) |
| 35 | + # end |
35 | 36 | tmp = dims isa Colon ? sum(out) : sum!(max_, out)
|
36 | 37 | out ./= tmp
|
37 | 38 | return out
|
38 | 39 | end
|
39 | 40 |
|
40 | 41 | function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
|
41 | 42 | max_ = NNlib.fast_maximum(x; dims)
|
42 |
| - inf_num = Reactant.promote_to(TracedRNumber{T}, Inf) |
43 |
| - zero_num = Reactant.promote_to(TracedRNumber{T}, 0) |
44 |
| - @trace if all(isfinite, max_) |
45 |
| - @. out = x - max_ |
46 |
| - else |
47 |
| - cond = max_ .== Inf |
48 |
| - true_pred = ifelse.(x .== Inf, zero_num, -inf_num) |
49 |
| - @. out = ifelse(cond, true_pred, x - max_) |
50 |
| - end |
| 43 | + # XXX: Once reverse mode of if is properly supported, we can make it @trace |
| 44 | + # inf_num = Reactant.promote_to(TracedRNumber{T}, Inf) |
| 45 | + # zero_num = Reactant.promote_to(TracedRNumber{T}, 0) |
| 46 | + # @trace if all(isfinite, max_) |
| 47 | + @. out = x - max_ |
| 48 | + # else |
| 49 | + # cond = max_ .== Inf |
| 50 | + # true_pred = ifelse.(x .== Inf, zero_num, -inf_num) |
| 51 | + # @. out = ifelse(cond, true_pred, x - max_) |
| 52 | + # end |
51 | 53 | @fastmath log_ = log.(sum(exp, out; dims))
|
52 | 54 | out .-= log_
|
53 | 55 | return out
|
|
0 commit comments