Skip to content

Commit 7088d48

Browse files
committed
Move rot_synch class to _test/utils.py
1 parent d3a4e8c commit 7088d48

File tree

5 files changed

+250
-608
lines changed

5 files changed

+250
-608
lines changed

_test/test_cliques.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,11 @@ def test_cost(cliques, C_gt):
5252
for var_list, clique in zip(clique_list, problem.cliques):
5353
assert set(clique.var_list).difference(var_list) == set()
5454

55+
# parent of 0 is 1, parent of 1 is 2, parent of 2 is itself (it's the root)
5556
clique_data = {
5657
"cliques": [{"h", "x_1", "x_2"}, {"h", "x_2", "x_3"}, {"h", "x_3", "x_4"}],
5758
"separators": [{"h", "x_2"}, {"h", "x_3"}, {}],
58-
"parents": [
59-
1,
60-
2,
61-
2,
62-
], # parent of 0 is 1, parent of 1 is 2, parent of 2 is itself (it's the root)
59+
"parents": [1, 2, 2],
6360
}
6461
problem.clique_decomposition(clique_data=clique_data)
6562
test_cost(problem.cliques, C_gt)

_test/test_homqcqp.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,15 @@
33

44
import numpy as np
55
from poly_matrix import PolyMatrix
6+
from utils import get_chain_rot_prob, get_loop_rot_prob
67

78
from cert_tools import HomQCQP
89
from cert_tools.hom_qcqp import greedy_cover
910
from cert_tools.linalg_tools import svec
10-
from cert_tools.problems.rot_synch import RotSynchLoopProblem
1111
from cert_tools.sdp_solvers import solve_sdp_homqcqp
1212
from cert_tools.sparse_solvers import solve_clarabel, solve_dsdp
1313

1414

15-
def get_chain_rot_prob(N=10, locked_pose=0):
16-
return RotSynchLoopProblem(N=N, loop_pose=-1, locked_pose=locked_pose)
17-
18-
19-
def get_loop_rot_prob():
20-
return RotSynchLoopProblem()
21-
22-
2315
class TestHomQCQP(unittest.TestCase):
2416

2517
def test_solve(self):

_test/utils.py

+247
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import numpy as np
2+
from poly_matrix import PolyMatrix
3+
from pylgmath import so3op
4+
5+
from cert_tools import HomQCQP
6+
7+
# Global Defaults
8+
ER_MIN = 1e6
9+
10+
11+
class RotSynchLoopProblem(HomQCQP):
12+
"""Class to generate and solve a rotation synchronization problem with loop
13+
constraints. The problem is generated with ground truth rotations and noisy
14+
measurements between them. The goal is to recover the ground truth rotations
15+
using a semidefinite programming approach. The loop constraints are encoded
16+
as O(3) constraints in the SDP.
17+
Problem is vectorized so that we can add a homogenization variable.
18+
19+
Attributes:
20+
N (int): Number of poses in the problem
21+
sigma (float): Standard deviation of noise in measurements
22+
R_gt (np.array): Ground truth rotations
23+
meas_dict (dict): Dictionary of noisy measurements
24+
cost (PolyMatrix): Cost matrix for the SDP
25+
constraints (list): List of O(3) constraints for the SDP
26+
"""
27+
28+
"""Rotation synchronization problem configured in a loop (non-chordal)
29+
"""
30+
31+
def __init__(self, N=10, sigma=1e-3, loop_pose=3, locked_pose=0, seed=0):
32+
super().__init__()
33+
34+
np.random.seed(seed)
35+
# generate ground truth poses
36+
aaxis_ab_rand = np.random.uniform(-np.pi / 2, np.pi / 2, size=(N, 3, 1))
37+
R_gt = so3op.vec2rot(aaxis_ab_rand)
38+
# Associated variable list
39+
self.var_sizes = {"h": 1}
40+
for i in range(N):
41+
self.var_sizes[str(i)] = 9
42+
# Generate Measurements as a dictionary on tuples
43+
self.loop_pose = str(loop_pose) # Loop relinks to chain at this pose
44+
self.locked_pose = str(locked_pose) # Pose locked at this pose
45+
self.meas_dict = {}
46+
for i in range(0, N):
47+
R_pert = so3op.vec2rot(sigma * np.random.randn(3, 1))
48+
if i == N - 1:
49+
if loop_pose > 0:
50+
j = loop_pose
51+
else:
52+
continue
53+
else:
54+
j = i + 1
55+
self.meas_dict[(str(i), str(j))] = R_pert @ R_gt[i] @ R_gt[j].T
56+
# Store data
57+
self.R_gt = R_gt
58+
self.N = N
59+
self.sigma = sigma
60+
# Define obj and constraints
61+
self.C = self.define_objective()
62+
self.As = self.define_constraints()
63+
64+
def define_objective(self) -> PolyMatrix:
65+
"""Get the cost matrix associated with the problem. Assume equal weighting
66+
for all measurments
67+
"""
68+
Q = PolyMatrix()
69+
# Construct matrix from measurements
70+
for i, j in self.meas_dict.keys():
71+
Q += self.get_rel_cost_mat(i, j)
72+
73+
# # Add prior measurement on first pose (tightens the relaxation)
74+
# Q += self.get_prior_cost_mat(self, 1)
75+
76+
return Q
77+
78+
def get_rel_cost_mat(self, i, j) -> PolyMatrix:
79+
"""Get cost representation for relative rotation measurement"""
80+
Q = PolyMatrix()
81+
if (i, j) in self.meas_dict.keys():
82+
meas = self.meas_dict[(i, j)]
83+
else:
84+
meas = self.meas_dict[(j, i)]
85+
Q[i, j] = -np.kron(np.eye(3), meas)
86+
Q[i, i] = 2 * np.eye(9)
87+
Q[j, j] = 2 * np.eye(9)
88+
return Q
89+
90+
def get_prior_cost_mat(self, index, weight=1) -> PolyMatrix:
91+
"""Get cost representation for prior measurement"""
92+
weight = self.N ^ 2
93+
index = 1
94+
Q = PolyMatrix()
95+
Q["h", "h"] += 6 * weight
96+
Q["h", index] += self.R_gt[index].reshape((9, 1), order="F").T * weight
97+
return Q
98+
99+
def define_constraints(self) -> list[PolyMatrix]:
100+
"""Generate all constraints for the problem"""
101+
constraints = []
102+
for key in self.var_sizes.keys():
103+
if key == "h":
104+
continue
105+
else:
106+
# Otherwise add rotation constriants
107+
constraints += self.get_O3_constraints(key)
108+
constraints += self.get_handedness_constraints(key)
109+
constraints += self.get_row_col_constraints(key)
110+
# lock the appropriate pose
111+
if key == self.locked_pose:
112+
constraints += self.get_locking_constraint(key)
113+
114+
return constraints
115+
116+
def get_locking_constraint(self, index):
117+
"""Get constraint that locks a particular pose to its ground truth value
118+
rather than adding a prior cost term. This should remove the gauge freedom
119+
from the problem, giving a rank-1 solution"""
120+
121+
r_gt = self.R_gt[int(index)].reshape((9, 1), order="F")
122+
constraints = []
123+
for k in range(9):
124+
A = PolyMatrix()
125+
e_k = np.zeros((1, 9))
126+
e_k[0, k] = 1
127+
A["h", index] = e_k / 2
128+
A["h", "h"] = -r_gt[k]
129+
constraints.append(A)
130+
return constraints
131+
132+
@staticmethod
133+
def get_O3_constraints(index):
134+
"""Generate O3 constraints for the problem"""
135+
constraints = []
136+
for k in range(3):
137+
for l in range(k, 3):
138+
A = PolyMatrix()
139+
E = np.zeros((3, 3))
140+
if k == l:
141+
E[k, l] = 1
142+
b = 1.0
143+
else:
144+
E[k, l] = 1
145+
E[l, k] = 1
146+
b = 0.0
147+
A[index, index] = np.kron(np.eye(3), E)
148+
A["h", "h"] = -b
149+
constraints.append(A)
150+
return constraints
151+
152+
@staticmethod
153+
def get_handedness_constraints(index):
154+
"""Generate Handedness Constraints - Equivalent to the determinant =1
155+
constraint for rotation matrices. See Tron,R et al:
156+
On the Inclusion of Determinant Constraints in Lagrangian Duality for 3D SLAM"""
157+
constraints = []
158+
i, j, k = 0, 1, 2
159+
for col_ind in range(3):
160+
l, m, n = 0, 1, 2
161+
for row_ind in range(3):
162+
# Define handedness matrix and vector
163+
mat = np.zeros((9, 9))
164+
mat[3 * j + m, 3 * k + n] = 1 / 2
165+
mat[3 * j + n, 3 * k + m] = -1 / 2
166+
mat = mat + mat.T
167+
vec = np.zeros((9, 1))
168+
vec[i * 3 + l] = -1 / 2
169+
# Create constraint
170+
A = PolyMatrix()
171+
A[index, index] = mat
172+
A[index, "h"] = vec
173+
constraints.append(A)
174+
# cycle row indices
175+
l, m, n = m, n, l
176+
# Cycle column indicies
177+
i, j, k = j, k, i
178+
return constraints
179+
180+
@staticmethod
181+
def get_row_col_constraints(index):
182+
"""Generate constraint that every row vector length equal every column vector length"""
183+
constraints = []
184+
for i in range(3):
185+
for j in range(3):
186+
A = PolyMatrix()
187+
c_col = np.zeros(9)
188+
ind = 3 * j + np.array([0, 1, 2])
189+
c_col[ind] = np.ones(3)
190+
c_row = np.zeros(9)
191+
ind = np.array([0, 3, 6]) + i
192+
c_row[ind] = np.ones(3)
193+
A[index, index] = np.diag(c_col - c_row)
194+
constraints.append(A)
195+
return constraints
196+
197+
@staticmethod
198+
def get_homog_constraint():
199+
"""generate homogenizing constraint"""
200+
A = PolyMatrix()
201+
A["h", "h"] = 1
202+
return [(A, 1.0)]
203+
204+
def convert_sdp_to_rot(self, X, er_min=ER_MIN):
205+
"""
206+
Converts a solution matrix to a list of rotations.
207+
208+
Parameters:
209+
- X: numpy.ndarray
210+
The solution matrix.
211+
212+
Returns:
213+
- R: list
214+
A list of rotation matrices.
215+
"""
216+
# Extract via SVD
217+
U, S, V = np.linalg.svd(X)
218+
# Eigenvalue ratio check
219+
assert S[0] / S[1] > er_min, ValueError("SDP is not Rank-1")
220+
x = U[:, 0] * np.sqrt(S[0])
221+
# Convert to list of rotations
222+
# R_vec = x[1:]
223+
R_vec = X[1:, 0]
224+
R_block = R_vec.reshape((3, -1), order="F")
225+
# Check determinant - Since we just take the first column, its possible for
226+
# the entire solution to be flipped
227+
if np.linalg.det(R_block[:, :3]) < 0:
228+
sign = -1
229+
else:
230+
sign = 1
231+
232+
R = {}
233+
cnt = 0
234+
for key in self.var_sizes.keys():
235+
if "h" == key:
236+
continue
237+
R[key] = sign * R_block[:, 3 * cnt : 3 * (cnt + 1)]
238+
cnt += 1
239+
return R
240+
241+
242+
def get_chain_rot_prob(N=10, locked_pose=0):
243+
return RotSynchLoopProblem(N=N, loop_pose=-1, locked_pose=locked_pose)
244+
245+
246+
def get_loop_rot_prob():
247+
return RotSynchLoopProblem()

cert_tools/problems/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)