Skip to content

Commit 3849087

Browse files
committed
Added method to HomQCQP for extracting the standard form of a problem for use with SCS, Clarabel, etc.
Added a function to run clarabel solver. Added functions to convert between vectorized form of PSD matrices.
1 parent 127a355 commit 3849087

9 files changed

+445
-137
lines changed

_scripts/rot_loop_admm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ def get_junction_tree(self, plot=False):
434434
var2 = G.vs["name"][cliques[i][1]]
435435
clique_obj += [self.get_clique_obj(var1, var2, i)]
436436
for j in range(i + 1, len(cliques)):
437-
# Get seperator set for list
437+
# Get separator set for list
438438
sepset = set(cliques[i]) & set(cliques[j])
439439
if len(sepset) > 0:
440440
junction.add_edge(i, j, weight=-len(sepset), sepset=sepset)

_test/test_homqcqp.py

+60-22
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from poly_matrix import PolyMatrix
77

88
from cert_tools import HomQCQP
9+
from cert_tools.linalg_tools import smat, svec
910
from cert_tools.problems.rot_synch import RotSynchLoopProblem
10-
from cert_tools.sparse_solvers import solve_dsdp
11+
from cert_tools.sdp_solvers import solve_sdp_homqcqp
12+
from cert_tools.sparse_solvers import solve_clarabel, solve_dsdp
1113

1214

1315
def get_chain_rot_prob(N=10):
@@ -20,33 +22,32 @@ def get_loop_rot_prob():
2022

2123
class TestHomQCQP(unittest.TestCase):
2224

23-
def test_solve():
25+
def test_solve(self):
2426
# Create chain of rotations problem:
2527
problem = get_chain_rot_prob()
2628
assert isinstance(problem, HomQCQP), TypeError(
2729
"Problem should be homogenized qcqp object"
2830
)
2931
# Solve SDP via standard method
30-
X, info, time = problem.solve_sdp(verbose=True)
32+
X, info, time = solve_sdp_homqcqp(problem, verbose=True)
3133
# Convert solution
3234
R = problem.convert_sdp_to_rot(X)
3335

34-
def test_get_asg(plot=False):
35-
"""Test retreival of aggregate sparsity graph"""
36+
def test_get_asg(self, plot=False):
37+
"""Test retrieval of aggregate sparsity graph"""
3638
# Test on chain graph
3739
problem = get_chain_rot_prob()
38-
problem.get_asg(rm_homog=True) # Get Graph
39-
problem.triangulate_graph() # Triangulate graph
40+
problem.clique_decomposition()
4041
# No fill in expected
41-
assert any(problem.asg.es["fill_edge"]) is False
42+
assert problem.symb.fill[0] == 0, ValueError("Expected no fill in")
4243
if plot:
4344
problem.plot_asg()
4445

4546
# Test on loop graph
4647
problem = get_loop_rot_prob()
47-
problem.get_asg(rm_homog=True)
48-
problem.triangulate_graph()
49-
assert any(problem.asg.es["fill_edge"]) is True
48+
problem.get_asg(rm_homog=False)
49+
problem.clique_decomposition()
50+
assert problem.symb.fill[0] > 0, ValueError("Expected fill in")
5051
if plot:
5152
problem.plot_asg()
5253

@@ -77,8 +78,8 @@ def test_clique_decomp(self, rm_homog=False, plot=False):
7778
parent = problem.cliques[clique.parent]
7879

7980
vertices = list(parent.var_sizes.keys()) + list(clique.var_sizes.keys())
80-
assert set(clique.seperator).issubset(vertices), ValueError(
81-
"seperator set should be in set of involved clique vertices"
81+
assert set(clique.separator).issubset(vertices), ValueError(
82+
"separator set should be in set of involved clique vertices"
8283
)
8384

8485
# Check that mapping from variables to cliques is correct
@@ -98,7 +99,6 @@ def test_consistency_constraints(self):
9899
# Test chain topology
99100
nvars = 5
100101
problem = get_chain_rot_prob(N=nvars)
101-
problem.get_asg(rm_homog=False) # Get Graph
102102
problem.clique_decomposition() # Run clique decomposition
103103
eq_list = problem.get_consistency_constraints()
104104

@@ -115,11 +115,11 @@ def test_consistency_constraints(self):
115115
x_list, info = solve_dsdp(problem, verbose=True, tol=1e-8)
116116
# Verify that the clique variables are equal on overlaps
117117
for l, clique_l in enumerate(problem.cliques):
118-
# seperator
119-
sepset = clique_l.seperator
118+
# separator
119+
sepset = clique_l.separator
120120
if len(sepset) == 0: # skip the root clique
121121
continue
122-
# fet parent clique and seperator set
122+
# fet parent clique and separator set
123123
k = clique_l.parent
124124
clique_k = problem.cliques[k]
125125

@@ -135,7 +135,6 @@ def test_decompose_matrix(self):
135135
# setup
136136
nvars = 5
137137
problem = get_chain_rot_prob(N=nvars)
138-
problem.get_asg(rm_homog=False) # Get Graph
139138
problem.clique_decomposition() # get clique decomposition
140139
C = problem.C
141140

@@ -173,10 +172,9 @@ def test_solve_dsdp(self):
173172
# Test chain topology
174173
nvars = 5
175174
problem = get_chain_rot_prob(N=nvars)
176-
problem.get_asg(rm_homog=False) # Get agg sparse graph
177175
problem.clique_decomposition() # get cliques
178176
# Solve non-decomposed problem
179-
X, info, time = problem.solve_sdp(verbose=True)
177+
X, info, time = solve_sdp_homqcqp(problem, verbose=True)
180178
# get cliques from non-decomposed solution
181179
c_list_nd = problem.get_cliques_from_psd_mat(X)
182180
# Solve decomposed problem (Interior Point Version)
@@ -208,12 +206,52 @@ def test_solve_dsdp(self):
208206
err_msg="Completed and non-decomposed solutions differ",
209207
)
210208

209+
def test_standard_form(self):
210+
"""Test that the standard form problem definition is correct"""
211+
nvars = 2
212+
problem = get_chain_rot_prob(N=nvars)
213+
problem.get_asg()
214+
P, q, A, b = problem.get_standard_form()
215+
216+
# get solution from MOSEK
217+
X, info, time = solve_sdp_homqcqp(problem, verbose=True)
218+
x = svec(X)
219+
220+
# Check cost matrix
221+
cost = np.dot(b, x)
222+
np.testing.assert_allclose(
223+
cost, info["cost"], atol=1e-12, err_msg="Cost incorrect"
224+
)
225+
# Check constraints
226+
for i, vec in enumerate(A.T):
227+
a = vec.toarray().squeeze(0)
228+
value = np.dot(a, x)
229+
np.testing.assert_allclose(
230+
value, -q[i], atol=1e-10, err_msg=f"Constraint {i} has violation"
231+
)
232+
233+
def test_clarabel(self):
234+
nvars = 2
235+
problem = get_chain_rot_prob(N=nvars)
236+
problem.get_asg()
237+
X_clarabel = solve_clarabel(problem)
238+
X, info, time = solve_sdp_homqcqp(problem, verbose=True)
239+
240+
np.testing.assert_allclose(
241+
X_clarabel,
242+
X,
243+
atol=1e-9,
244+
err_msg="Clarabel and MOSEK solutions differ",
245+
)
246+
211247

212248
if __name__ == "__main__":
213249
test = TestHomQCQP()
214250
# test.test_solve()
215251
# test.test_get_asg(plot=True)
216252
# test.test_clique_decomp(plot=False)
217-
# test.test_consistency_constraints()
253+
test.test_consistency_constraints()
218254
# test.test_decompose_matrix()
219-
test.test_solve_dsdp()
255+
# test.test_solve_dsdp()
256+
# test.test_standard_form()
257+
# test.test_clarabel()

_test/test_linalg_tools.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import unittest
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import scipy.sparse as sp
6+
7+
from cert_tools.linalg_tools import smat, svec
8+
9+
10+
class TestLinAlg(unittest.TestCase):
11+
12+
def get_psd_mat(self, n=3):
13+
# get random symmetric matrix
14+
A = np.random.random((n, n))
15+
U, S, V = np.linalg.svd(A)
16+
S = np.abs(S)
17+
return (U * S) @ U.T
18+
19+
def test_svec(self):
20+
# fix seed
21+
np.random.seed(0)
22+
# generate matrices
23+
S1 = self.get_psd_mat()
24+
S2 = self.get_psd_mat()
25+
# Vectorize
26+
s1 = svec(S1)
27+
s2 = svec(S2)
28+
# test mapping
29+
np.testing.assert_almost_equal(smat(s1), S1)
30+
np.testing.assert_almost_equal(smat(s2), S2)
31+
# products should be equal
32+
prod_mat = np.trace(S1 @ S2)
33+
prod_vec = np.dot(s1, s2)
34+
assert prod_mat == prod_vec, "PSD Inner product not equal"
35+
36+
def test_svec_sparse(self):
37+
# fix seed
38+
np.random.seed(0)
39+
# generate matrices
40+
S1_dense = self.get_psd_mat()
41+
S2_dense = self.get_psd_mat()
42+
# Remove element to make sure still works when not dense
43+
# S1_dense[4, 5] = 0.0
44+
# S1_dense[5, 4] = 0.0
45+
S1 = sp.csc_matrix(S1_dense)
46+
S2 = sp.csc_matrix(S2_dense)
47+
S1.eliminate_zeros()
48+
S2.eliminate_zeros()
49+
# Vectorize
50+
s1 = svec(S1)
51+
s2 = svec(S2)
52+
s1_dense = svec(S1_dense)
53+
s2_dense = svec(S2_dense)
54+
np.testing.assert_almost_equal(s1_dense, s1.toarray().squeeze(0))
55+
np.testing.assert_almost_equal(s2_dense, s2.toarray().squeeze(0))
56+
# test mapping
57+
np.testing.assert_almost_equal(smat(s1), S1.toarray())
58+
np.testing.assert_almost_equal(smat(s2), S2.toarray())
59+
# products should be equal
60+
prod_mat = np.trace(S1.toarray() @ S2.toarray())
61+
prod_vec = (s1 @ s2.T).toarray()
62+
assert prod_mat == prod_vec, "PSD Inner product not equal"
63+
64+
65+
if __name__ == "__main__":
66+
test = TestLinAlg()
67+
test.test_svec()
68+
test.test_svec_sparse()

cert_tools/base_clique.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,25 @@ def __init__(
2727
index,
2828
var_sizes,
2929
parent,
30-
seperator,
30+
separator,
3131
Q: PolyMatrix = PolyMatrix(),
3232
A_list=[],
3333
b_list=[],
3434
):
3535
self.index = index
36-
if var_sizes is not None:
37-
assert "h" in var_sizes, f"Each clique must have a homogenizing variable"
36+
# if var_sizes is not None:
37+
# assert "h" in var_sizes, f"Each clique must have a homogenizing variable"
3838
self.var_sizes = var_sizes
3939
self.var_list = list(self.var_sizes.keys())
4040
self.var_inds, self.size = self._get_start_indices()
4141
# Store clique tree information
42-
for key in seperator:
42+
for key in separator:
4343
assert (
4444
key in self.var_sizes.keys()
45-
), f"seperator element {key} not contained in clique"
46-
self.seperator = seperator # seperator set between this clique and its parent
45+
), f"separator element {key} not contained in clique"
46+
self.separator = separator # separator set between this clique and its parent
4747
self.residual = self.var_list.copy()
48-
for varname in seperator:
48+
for varname in separator:
4949
self.residual.remove(varname)
5050
self.parent = parent # index of the parent clique
5151
self.children = set() # set of children of this clique in the clique tree
@@ -81,7 +81,7 @@ def _get_indices(self, var_list):
8181
_type_: _description_
8282
"""
8383
if type(var_list) is not list:
84-
var_list = list(var_list)
84+
var_list = [var_list]
8585
# Get index slices for the rows
8686
slices = []
8787
for varname in var_list:
@@ -96,10 +96,10 @@ def get_slices(self, mat, var_list_row, var_list_col=[]):
9696
If one list provided then slices are assumed to be symmetric. If two lists are provided, they are interpreted as the row and column lists, respectively.
9797
"""
9898
# Get index slices for the rows
99-
inds1 = self._get_indices(var_list_col)
99+
inds1 = self._get_indices(var_list_row)
100100
# Get index slices for the columns
101101
if len(var_list_col) > 0:
102-
inds2 = self._get_indices(var_list_row)
102+
inds2 = self._get_indices(var_list_col)
103103
else:
104104
# If not defined use the same list as rows
105105
inds2 = inds1

0 commit comments

Comments
 (0)