Skip to content

Commit

Permalink
Zygote workarounds
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Jun 26, 2020
1 parent 1b708f7 commit 8fc5e74
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
18 changes: 13 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,20 @@ end
normalize x sums to 1
"""
function s1(x, dims=:)
m = minimum(x, dims=dims)
x = copy(x)
if any(<(0), m)
x .-= m
if isderiving()
m = minimum(real(x), dims=dims)
if any(<(0), m)
x = x .- m
end
return x ./ sum(x, dims=dims)
else
m = minimum(x, dims=dims)
x = float.(x)
if any(<(0), m)
x .-= m
end
x ./= sum(x, dims=dims)
end
x ./= sum(x, dims=dims)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion test/test_barycenter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ end


a = ones(k) |> s1
b = [[1,2,3,4] |> s1 for _ in eachindex(Y)]
b = [[1.0,2,3,4] |> s1 for _ in eachindex(Y)]
M = SpectralDistances.distmat_euclidean(X,Y[1])
g1,a1,b1 = ot_jump(M,a,b[1]) .|> r6
g2,a2,b2 = sinkhorn_log(M,a,b[1], β=0.001, iters=50000, printerval=5000, tol=1e-9) .|> r6
Expand Down

0 comments on commit 8fc5e74

Please sign in to comment.