Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
GraphsMatching = "c3af3a8c-b79e-4b01-bf44-c718d7e0e0d6"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
Expand All @@ -20,7 +21,8 @@ MathOptInterface = "1"
julia = "1"

[extras]
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "HiGHS"]
7 changes: 2 additions & 5 deletions src/CombinatorialLinearOracles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ using Graphs
using SparseArrays
using GraphsMatching

include("MatchingLinearOracle.jl")
include("SpanningTreeLinearOracle.jl")



include("matchings.jl")
include("spanning_tree.jl")

end
34 changes: 0 additions & 34 deletions src/MatchingLinearOracle.jl

This file was deleted.

88 changes: 88 additions & 0 deletions src/matchings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@

"""
MatchingLMO{G}(g::Graphs)

Return a vector v corresponding to edges(g), where if v[i] = 1,
the edge i is in the maximum weight matching, and if v[i] = 0, the edge i is not in the matching.
"""
struct MatchingLMO{G} <: FrankWolfe.LinearMinimizationOracle
graph::G
end

function FrankWolfe.compute_extreme_point(
lmo::MatchingLMO,
direction::M;
v=nothing,
kwargs...,
) where {M}
N = length(direction)
v = spzeros(N)
iter = collect(edges(lmo.graph))
g = SimpleGraphFromIterator(iter)
l = nv(g)
add_vertices!(g, l)
w = Dict{typeof(iter[1]),typeof(direction[1])}()
for i in 1:N
add_edge!(g, src(iter[i]) + l, dst(iter[i]) + l)
w[iter[i]] = -direction[i]
w[Edge(src(iter[i]) + l, dst(iter[i]) + l)] = -direction[i]
end

for i in 1:l
add_edge!(g, i, i + l)
w[Edge(i, i + l)] = 0
end

match = GraphsMatching.minimum_weight_perfect_matching(g, w)

K = length(match.mate)
for i in 1:K
for j in 1:N
if (match.mate[i] == src(iter[j]) && dst(iter[j]) == i)
v[j] = 1
end
end
end
return v
end


"""
PerfectMatchingLMO{G}(g::Graphs)

Return a vector v corresponding to edges(g), where if v[i] = 1,
the edge i is in the matching, and if v[i] = 0, the edge i is not in the matching.
If there is not possible perfect matching, all elements of v are set to 0.
"""
struct PerfectMatchingLMO{G} <: FrankWolfe.LinearMinimizationOracle
graph::G
end

function FrankWolfe.compute_extreme_point(
lmo::PerfectMatchingLMO,
direction::M;
v=nothing,
kwargs...,
) where {M}
N = length(direction)
v = spzeros(N)
if (nv(lmo.graph) % 2 != 0)
return v
end
iter = collect(Graphs.edges(lmo.graph))
w = Dict{typeof(iter[1]),typeof(direction[1])}()
for i in 1:N
w[iter[i]] = direction[i]
end

match = GraphsMatching.minimum_weight_perfect_matching(lmo.graph, w)
K = length(match.mate)
for i in 1:K
for j in 1:N
if (match.mate[i] == src(iter[j]) && dst(iter[j]) == i)
v[j] = 1
end
end
end
return v
end
7 changes: 6 additions & 1 deletion src/SpanningTreeLinearOracle.jl → src/spanning_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ struct SpanningTreeLMO{G} <: FrankWolfe.LinearMinimizationOracle
graph::G
end

function compute_extreme_point(lmo::SpanningTreeLMO, direction::M; v=nothing, kwargs...) where {M}
function FrankWolfe.compute_extreme_point(
lmo::SpanningTreeLMO,
direction::M;
v=nothing,
kwargs...,
) where {M}
N = length(direction)
iter = collect(Graphs.edges(lmo.graph))
distmx = spzeros(N, N)
Expand Down
48 changes: 40 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
using .CombinatorialLinearOracles
using CombinatorialLinearOracles
using Test
using Random
using SparseArrays

using Graphs
using GraphsMatching
using HiGHS
import FrankWolfe

@testset "Matching LMO" begin
@testset "Perfect Matching LMO" begin
N = Int(1e3)
Random.seed!(4321)
g = Graphs.complete_graph(N)
iter = collect(Graphs.edges(g))
M = length(iter)
direction = randn(M)
lmo = CombinatorialLinearOracles.MatchingLMO(g)
v = CombinatorialLinearOracles.compute_extreme_point(lmo, direction)
lmo = CombinatorialLinearOracles.PerfectMatchingLMO(g)
v = FrankWolfe.compute_extreme_point(lmo, direction)
tab = zeros(M)
is_matching = true
for i in 1:M
Expand All @@ -24,9 +29,36 @@ using Graphs
tab[dst(iter[i])] = 1
end
end
@test is_matching == true
@test is_matching
end

@testset "Matching LMO" begin
N = 200
Random.seed!(9754)
g = Graphs.complete_graph(N)
iter = collect(Graphs.edges(g))
M = length(iter)
direction = randn(M)
lmo = CombinatorialLinearOracles.MatchingLMO(g)
v = FrankWolfe.compute_extreme_point(lmo, direction)
adj_mat = spzeros(M, M)
for i in 1:M
adj_mat[src(iter[i]), dst(iter[i])] = direction[i]
end
match_result = GraphsMatching.maximum_weight_matching(g, HiGHS.Optimizer, adj_mat)
v_sol = spzeros(M)
K = length(match_result.mate)
for i in 1:K
for j in 1:M
if (match_result.mate[i] == src(iter[j]) && dst(iter[j]) == i)
v_sol[j] = 1
end
end
end
@test v_sol == v
end


@testset "SpanningTreeLMO" begin
N = 500
Random.seed!(1645)
Expand All @@ -35,12 +67,12 @@ end
M = length(iter)
direction = randn(M)
lmo = CombinatorialLinearOracles.SpanningTreeLMO(g)
v = CombinatorialLinearOracles.compute_extreme_point(lmo, direction)
tree = Array{Edge}(undef, (0,))
v = FrankWolfe.compute_extreme_point(lmo, direction)
tree = eltype(iter)[]
for i in 1:M
if (v[i] == 1)
push!(tree, iter[i])
end
end
@test Graphs.is_tree(SimpleGraphFromIterator(tree)) == true
@test Graphs.is_tree(SimpleGraphFromIterator(tree))
end