Skip to content

Commit 99105e1

Browse files
committed
[SymForce] Improve support for 0-size matrices
This fixes: ```py m = sf.Matrix(0, 0) list(m) ``` Which currently throws `ZeroDivisionError`. For `__getitem__`-based iteration, Python calls `__getitem__` in sequence starting from 0, and expects `IndexError` to know when to stop iterating. I fixed a couple other things also; geo_matrix_test now sorta works with empty matrices, but I haven't tried to 100% test everything. Ops don't work, often because they assume creating a (N, 1) Matrix from its storage vector will produce something of the right shape, but for a (0, 1) matrix this isn't true. I've also replaced all the user-reachable assertions (i.e. not SymForce bugs) in matrix.py with exceptions (TypeError and ValueError). Topic: sf-fix-symengine-empty-matrix Reviewers: emil,brad,nathan,ryan-b GitOrigin-RevId: 3a60aabfb5b0e06a0f35418f523582037057b5e9
1 parent 2ba2f49 commit 99105e1

File tree

4 files changed

+147
-73
lines changed

4 files changed

+147
-73
lines changed

symforce/geo/matrix.py

Lines changed: 71 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class Matrix(Storage):
6464
# this class variable as a strong internal consistency check.
6565
SHAPE = (-1, -1)
6666

67-
def __new__(cls, *args: _T.Any, **kwargs: _T.Any) -> Matrix: # noqa: PLR0915
67+
def __new__(cls, *args: _T.Any, **kwargs: _T.Any) -> Matrix: # noqa: PLR0912, PLR0915
6868
"""
6969
Beast of a method for creating a Matrix. Handles a variety of construction use cases
7070
and *always* returns a fixed size child class of Matrix rather than Matrix itself. The
@@ -74,17 +74,17 @@ def __new__(cls, *args: _T.Any, **kwargs: _T.Any) -> Matrix: # noqa: PLR0915
7474

7575
# 1) Default construction allowed for fixed size.
7676
if len(args) == 0:
77-
assert cls._is_fixed_size(), "Cannot default construct non-fixed matrix."
77+
if not cls._is_fixed_size():
78+
raise TypeError("Cannot default construct non-fixed matrix.")
7879
return cls.zero()
7980

8081
# 2) Construct with another Matrix - this is easy
8182
elif len(args) == 1 and hasattr(args[0], "is_Matrix") and args[0].is_Matrix:
8283
rows, cols = args[0].shape
83-
if cls._is_fixed_size():
84-
assert cls.SHAPE == (
85-
rows,
86-
cols,
87-
), f"Inconsistent shape: expected shape {cls.SHAPE} but found shape {(rows, cols)}"
84+
if cls._is_fixed_size() and cls.SHAPE != (rows, cols):
85+
raise ValueError(
86+
f"Inconsistent shape: expected shape {cls.SHAPE} but found shape {(rows, cols)}"
87+
)
8888
flat_list = list(args[0])
8989

9090
# 3) If there's one argument and it's an array, works for fixed or dynamic size.
@@ -93,18 +93,15 @@ def __new__(cls, *args: _T.Any, **kwargs: _T.Any) -> Matrix: # noqa: PLR0915
9393
# 2D array, shape is known
9494
if len(array) > 0 and isinstance(array[0], (_T.Sequence, np.ndarray)):
9595
# 2D array of scalars
96-
assert not isinstance(array[0][0], Matrix), (
97-
"Use Matrix.block_matrix to construct using matrices"
98-
)
96+
if isinstance(array[0][0], Matrix):
97+
raise TypeError("Use Matrix.block_matrix to construct using matrices")
9998
rows, cols = len(array), len(array[0])
100-
if cls._is_fixed_size():
101-
assert (
102-
rows,
103-
cols,
104-
) == cls.SHAPE, f"{cls} has shape {cls.SHAPE} but arg has shape {(rows, cols)}"
105-
assert all(len(arr) == cols for arr in array), "Inconsistent columns: {}".format(
106-
args
107-
)
99+
if cls._is_fixed_size() and (rows, cols) != cls.SHAPE:
100+
raise ValueError(
101+
f"{cls} has shape {cls.SHAPE} but arg has shape {(rows, cols)}"
102+
)
103+
if not all(len(arr) == cols for arr in array):
104+
raise ValueError(f"Inconsistent columns: {args}")
108105
flat_list = [v for row in array for v in row]
109106

110107
# 1D array - if fixed size this must match data length. If not, assume column vec.
@@ -130,24 +127,31 @@ def __new__(cls, *args: _T.Any, **kwargs: _T.Any) -> Matrix: # noqa: PLR0915
130127
# to an sm.Matrix, do the operation, then convert back.
131128
elif len(args) == 2 and cls.SHAPE == (-1, -1):
132129
rows, cols = args[0], args[1]
133-
assert isinstance(rows, int)
134-
assert isinstance(cols, int)
130+
if not isinstance(rows, int) or rows < 0:
131+
raise ValueError(f"rows must be a nonnegative integer, got {rows}")
132+
if not isinstance(cols, int) or cols < 0:
133+
raise ValueError(f"cols must be a nonnegative integer, got {cols}")
135134
flat_list = [0 for row in range(rows) for col in range(cols)]
136135

137136
# 5) If there are two integer arguments and then a sequence, treat this as a shape and a
138137
# data list directly.
139138
elif len(args) == 3 and isinstance(args[-1], (np.ndarray, _T.Sequence)):
140-
assert isinstance(args[0], int), args
141-
assert isinstance(args[1], int), args
139+
if not isinstance(args[0], int) or args[0] < 0:
140+
raise ValueError(f"rows must be a nonnegative integer, got {args[0]}")
141+
if not isinstance(args[1], int) or args[1] < 0:
142+
raise ValueError(f"cols must be a nonnegative integer, got {args[1]}")
142143
rows, cols = args[0], args[1]
143-
assert len(args[2]) == rows * cols, f"Inconsistent args: {args}"
144+
if len(args[2]) != rows * cols:
145+
raise ValueError(f"Inconsistent args: {args}")
144146
flat_list = list(args[2])
145147

146148
# 6) Two integer arguments plus a callable to initialize values based on (row, col)
147149
# NOTE(hayk): sympy.Symbol is callable, hence the last check.
148150
elif len(args) == 3 and callable(args[-1]) and not hasattr(args[-1], "is_Symbol"):
149-
assert isinstance(args[0], int), args
150-
assert isinstance(args[1], int), args
151+
if not isinstance(args[0], int) or args[0] < 0:
152+
raise ValueError(f"rows must be a nonnegative integer, got {args[0]}")
153+
if not isinstance(args[1], int) or args[1] < 0:
154+
raise ValueError(f"cols must be a nonnegative integer, got {args[1]}")
151155
rows, cols = args[0], args[1]
152156
flat_list = [args[2](row, col) for row in range(rows) for col in range(cols)]
153157

@@ -160,7 +164,7 @@ def __new__(cls, *args: _T.Any, **kwargs: _T.Any) -> Matrix: # noqa: PLR0915
160164

161165
# 8) No match, error out.
162166
else:
163-
raise AssertionError(f"Unknown {cls} constructor for: {args}")
167+
raise ValueError(f"Unknown {cls} constructor for: {args}")
164168

165169
# Get the proper fixed size child class
166170
fixed_size_type = matrix_type_from_shape((rows, cols))
@@ -177,7 +181,9 @@ def __init__(self, *args: _T.Any, **kwargs: _T.Any) -> None:
177181
if _T.TYPE_CHECKING:
178182
self.mat = sf.sympy.Matrix(*args, **kwargs)
179183

180-
assert self.__class__.SHAPE == self.mat.shape, "Inconsistent Matrix"
184+
assert self.__class__.SHAPE == self.mat.shape, (
185+
f"Inconsistent Matrix: {self.__class__.SHAPE} != {self.mat.shape}"
186+
)
181187

182188
@property
183189
def rows(self) -> int:
@@ -207,14 +213,16 @@ def __repr__(self) -> str:
207213

208214
@classmethod
209215
def storage_dim(cls) -> int:
210-
assert cls._is_fixed_size(), f"Type has no size info: {cls}"
216+
if not cls._is_fixed_size():
217+
raise TypeError(f"Type has no size info: {cls}")
211218
return cls.SHAPE[0] * cls.SHAPE[1]
212219

213220
@classmethod
214221
def from_storage(
215222
cls: _T.Type[MatrixT], vec: _T.Union[_T.Sequence[_T.Scalar], Matrix]
216223
) -> MatrixT:
217-
assert cls._is_fixed_size(), f"Type has no size info: {cls}"
224+
if not cls._is_fixed_size():
225+
raise TypeError(f"Type has no size info: {cls}")
218226
if isinstance(vec, Matrix):
219227
vec = list(vec)
220228
rows, cols = cls.SHAPE
@@ -251,7 +259,8 @@ def zero(cls: _T.Type[MatrixT]) -> MatrixT:
251259
"""
252260
Matrix of zeros.
253261
"""
254-
assert cls._is_fixed_size(), f"Type has no size info: {cls}"
262+
if not cls._is_fixed_size():
263+
raise TypeError(f"Type has no size info: {cls}")
255264
return cls.zeros(*cls.SHAPE)
256265

257266
@classmethod
@@ -269,7 +278,8 @@ def one(cls: _T.Type[MatrixT]) -> MatrixT:
269278
"""
270279
Matrix of ones.
271280
"""
272-
assert cls._is_fixed_size(), f"Type has no size info: {cls}"
281+
if not cls._is_fixed_size():
282+
raise TypeError(f"Type has no size info: {cls}")
273283
return cls.ones(*cls.SHAPE)
274284

275285
@classmethod
@@ -359,14 +369,19 @@ def symbolic(cls: _T.Type[MatrixT], name: str, **kwargs: _T.Any) -> MatrixT:
359369
name (str): Name prefix of the symbols
360370
**kwargs (dict): Forwarded to `sf.Symbol`
361371
"""
362-
assert cls._is_fixed_size(), f"Type has no size info: {cls}"
372+
if not cls._is_fixed_size():
373+
raise TypeError(f"Type has no size info: {cls}")
363374
rows, cols = cls.SHAPE
364375

365376
row_names = [str(r_i) for r_i in range(rows)]
366377
col_names = [str(c_i) for c_i in range(cols)]
367378

368-
assert len(row_names) == rows
369-
assert len(col_names) == cols
379+
if len(row_names) != rows:
380+
raise ValueError(f"Number of row names {len(row_names)} does not match rows {rows}")
381+
if len(col_names) != cols:
382+
raise ValueError(
383+
f"Number of column names {len(col_names)} does not match columns {cols}"
384+
)
370385

371386
if cols == 1:
372387
if ops.StorageOps.use_latex_friendly_symbols():
@@ -377,21 +392,19 @@ def symbolic(cls: _T.Type[MatrixT], name: str, **kwargs: _T.Any) -> MatrixT:
377392
symbols = []
378393
for r_i in range(rows):
379394
_name = format_string.format(name, row_names[r_i])
380-
symbols.append([sf.Symbol(_name, **kwargs)])
395+
symbols.append(sf.Symbol(_name, **kwargs))
381396
else:
382397
if ops.StorageOps.use_latex_friendly_symbols():
383398
format_string = "{}_{{{}, {}}}"
384399
else:
385400
format_string = "{}[{}, {}]"
386401
symbols = []
387402
for r_i in range(rows):
388-
col_symbols = []
389403
for c_i in range(cols):
390404
_name = format_string.format(name, row_names[r_i], col_names[c_i])
391-
col_symbols.append(sf.Symbol(_name, **kwargs))
392-
symbols.append(col_symbols)
405+
symbols.append(sf.Symbol(_name, **kwargs))
393406

394-
return cls(sf.sympy.Matrix(symbols))
407+
return cls(sf.sympy.Matrix(rows, cols, symbols))
395408

396409
def row_join(self, right: Matrix) -> Matrix:
397410
"""
@@ -426,17 +439,19 @@ def block_matrix(cls, array: _T.Sequence[_T.Sequence[Matrix]]) -> Matrix:
426439
block_rows = mat_row[0].shape[0]
427440
block_cols = 0
428441
for mat in mat_row:
429-
assert mat.shape[0] == block_rows, (
430-
"Inconsistent row number accross block: expected {} got {}".format(
431-
block_rows, mat.shape[0]
442+
if mat.shape[0] != block_rows:
443+
raise ValueError(
444+
"Inconsistent row number accross block: expected {}, got {}".format(
445+
block_rows, mat.shape[0]
446+
)
432447
)
433-
)
434448
block_cols += mat.shape[1]
435-
assert block_cols == cols, (
436-
"Inconsistent column number accross block: expected {} got {}".format(
437-
cols, block_cols
449+
if block_cols != cols:
450+
raise ValueError(
451+
"Inconsistent column number accross block: expected {}, got {}".format(
452+
cols, block_cols
453+
)
438454
)
439-
)
440455

441456
# Fill the new matrix data vector
442457
flat_list = []
@@ -625,7 +640,8 @@ def multiply_elementwise(self: MatrixT, rhs: MatrixT) -> MatrixT:
625640
Do the elementwise multiplication between self and rhs, and return the result as a new
626641
:class:`Matrix`
627642
"""
628-
assert self.shape == rhs.shape
643+
if self.shape != rhs.shape:
644+
raise TypeError(f"Cannot multiply elementwise: shapes {self.shape} and {rhs.shape}")
629645
return self.__class__(self.mat.multiply_elementwise(rhs.mat))
630646

631647
def applyfunc(self: MatrixT, func: _T.Callable) -> MatrixT:
@@ -898,7 +914,8 @@ def to_flat_list(self) -> _T.List[_T.Scalar]:
898914

899915
@classmethod
900916
def from_flat_list(cls, vec: _T.Sequence[_T.Scalar]) -> Matrix:
901-
assert cls._is_fixed_size(), f"Type has no size info: {cls}"
917+
if not cls._is_fixed_size():
918+
raise TypeError(f"Type has no size info: {cls}")
902919
return cls(vec)
903920

904921
def to_numpy(self, scalar_type: type = np.float64) -> np.ndarray:
@@ -920,16 +937,17 @@ def column_stack(cls, *columns: Matrix) -> Matrix:
920937

921938
for col in columns:
922939
# assert that each column is a vector
923-
assert col.shape == columns[0].shape
924-
assert sum(dim > 1 for dim in col.shape) <= 1
940+
if col.shape != columns[0].shape or sum(dim > 1 for dim in col.shape) > 1:
941+
raise TypeError(f"Column has shape {col.shape}, should be a vector (N, 1)")
925942

926943
return cls([col.to_flat_list() for col in columns]).T
927944

928945
def is_vector(self) -> bool:
929946
return (self.shape[0] == 1) or (self.shape[1] == 1)
930947

931948
def _assert_is_vector(self) -> None:
932-
assert self.is_vector(), "Not a vector."
949+
if not self.is_vector():
950+
raise TypeError(f"Not a vector, shape {self.shape}")
933951

934952
def _assert_sanity(self) -> None:
935953
assert self.shape == self.SHAPE, "Inconsistent Matrix!. shape={}, SHAPE={}".format(
@@ -945,7 +963,7 @@ def _is_fixed_size(cls) -> bool:
945963
Return ``True`` if this is a type with fixed dimensions set, e.g. :class:`Matrix31` instead
946964
of :class:`Matrix`.
947965
"""
948-
return cls.SHAPE[0] > 0 and cls.SHAPE[1] > 0
966+
return cls.SHAPE[0] > -1 and cls.SHAPE[1] > -1
949967

950968
def _ipython_display_(self) -> None: # noqa: PLW3201
951969
"""

test/geo_matrix_test.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ def test_construction(self) -> None:
3636
# 2) Matrix(sf.sympy.Matrix([[1, 2], [3, 4]])) # Matrix22 with [1, 2, 3, 4] data
3737
self.assertIsInstance(sf.M(sf.sympy.Matrix([[1, 2], [3, 4]])), sf.M22)
3838
self.assertEqual(sf.M(sf.sympy.Matrix([[1, 2], [3, 4]])), sf.M([[1, 2], [3, 4]]))
39-
self.assertRaises(AssertionError, lambda: sf.V3(sf.V2()))
39+
self.assertRaises(ValueError, lambda: sf.V3(sf.V2()))
4040

4141
# 3A) Matrix([[1, 2], [3, 4]]) # Matrix22 with [1, 2, 3, 4] data
4242
self.assertIsInstance(sf.M([[1, 2], [3, 4]]), sf.M22)
4343
self.assertEqual(sf.M([[1, 2], [3, 4]]), sf.M22([1, 2, 3, 4]))
44-
self.assertRaises(AssertionError, lambda: sf.M([[1, 2], [3, 4, 5]]))
45-
self.assertRaises(AssertionError, lambda: sf.M([[sf.M22(), sf.M23()]]))
46-
self.assertRaises(AssertionError, lambda: sf.M11([[1, 2]]))
44+
self.assertRaises(ValueError, lambda: sf.M([[1, 2], [3, 4, 5]]))
45+
self.assertRaises(TypeError, lambda: sf.M([[sf.M22(), sf.M23()]]))
46+
self.assertRaises(ValueError, lambda: sf.M11([[1, 2]]))
4747

4848
# 3B) Matrix22([1, 2, 3, 4]) # Matrix22 with [1, 2, 3, 4] data (must matched fixed shape)
4949
self.assertIsInstance(sf.M22([1, 2, 3, 4]), sf.M22)
@@ -73,8 +73,8 @@ def test_construction(self) -> None:
7373
self.assertEqual(sf.M21(4, 3), sf.M([[4], [3]]))
7474
self.assertEqual(sf.M12(4, 3), sf.M([[4, 3]]))
7575
self.assertEqual(sf.M(4, 3), sf.M.zeros(4, 3))
76-
self.assertRaises(AssertionError, lambda: sf.M22(1, 2, 3))
77-
self.assertRaises(AssertionError, lambda: sf.M22(1, 2, 3, 4, 5))
76+
self.assertRaises(ValueError, lambda: sf.M22(1, 2, 3))
77+
self.assertRaises(ValueError, lambda: sf.M22(1, 2, 3, 4, 5))
7878

7979
# Test large size (not statically defined)
8080
self.assertEqual(type(sf.M(12, 4)).__name__, "Matrix12_4")
@@ -190,8 +190,8 @@ def test_symbolic_operations(self) -> None:
190190
self.assertStorageNear(numpy_mat, geo_mat.to_numpy())
191191

192192
# Make sure we assert when calling a method that expects fixed size on sf.M
193-
self.assertRaises(AssertionError, lambda: sf.M.symbolic("C"))
194-
self.assertRaises(AssertionError, lambda: sf.M.from_storage([1, 2, 3]))
193+
self.assertRaises(TypeError, lambda: sf.M.symbolic("C"))
194+
self.assertRaises(TypeError, lambda: sf.M.from_storage([1, 2, 3]))
195195

196196
def test_constructor_helpers(self) -> None:
197197
"""
@@ -208,7 +208,7 @@ def test_constructor_helpers(self) -> None:
208208

209209
rand_vec_long = np.random.rand(i + 2)
210210
self.assertRaises(ValueError, vec, rand_vec_long)
211-
self.assertRaises(AssertionError, vec, *rand_vec_long)
211+
self.assertRaises(ValueError, vec, *rand_vec_long)
212212

213213
eye_matrix_constructors = [sf.I1, sf.I2, sf.I3, sf.I4, sf.I5, sf.I6]
214214
for i, mat in enumerate(eye_matrix_constructors):
@@ -268,8 +268,8 @@ def test_block_matrix(self) -> None:
268268
self.assertEqual(
269269
sf.M.block_matrix([[M22, M21], [M13]]), sf.M([[1, 1, 5], [1, 1, 5], [6, 6, 6]])
270270
)
271-
self.assertRaises(AssertionError, lambda: sf.M.block_matrix([[M22, M23], [M11, sf.M15()]]))
272-
self.assertRaises(AssertionError, lambda: sf.M.block_matrix([[M22, sf.M33()], [M11, M14]]))
271+
self.assertRaises(ValueError, lambda: sf.M.block_matrix([[M22, M23], [M11, sf.M15()]]))
272+
self.assertRaises(ValueError, lambda: sf.M.block_matrix([[M22, sf.M33()], [M11, M14]]))
273273

274274
def test_transpose(self) -> None:
275275
"""
@@ -381,7 +381,7 @@ def test_multiply_elementwise(self) -> None:
381381

382382
self.assertEqual(a.multiply_elementwise(b), expected_result)
383383

384-
with self.assertRaises(AssertionError):
384+
with self.assertRaises(TypeError):
385385
# This should fail mypy, since it's actually wrong
386386
a.multiply_elementwise(sf.M43()) # type: ignore
387387

@@ -489,6 +489,39 @@ def test_vector_methods(self) -> None:
489489
self.assertEqual(sf.V3.unit_z().y, 0)
490490
self.assertEqual(sf.V3.unit_z().z, 1)
491491

492+
def test_empty_matrix(self) -> None:
493+
"""
494+
Tests some basic operations on empty matrices
495+
496+
TODO(aaron): Test more operations
497+
"""
498+
element = sf.Matrix(0, 0)
499+
500+
dims = element.SHAPE
501+
self.assertEqual(element.zero(), sf.Matrix.zeros(dims[0], dims[1]))
502+
self.assertEqual(element.one(), sf.Matrix.ones(dims[0], dims[1]))
503+
504+
self.assertEqual(sf.Matrix(0, 1).shape, (0, 1))
505+
self.assertEqual(sf.Matrix(1, 0).shape, (1, 0))
506+
self.assertEqual(sf.Matrix(0, 0).shape, (0, 0))
507+
508+
self.assertEqual(sf.Matrix(0, 1).symbolic("x"), sf.Matrix(0, 1))
509+
self.assertEqual(sf.Matrix(1, 0).symbolic("x"), sf.Matrix(1, 0))
510+
self.assertEqual(sf.Matrix(0, 0).symbolic("x"), sf.Matrix(0, 0))
511+
512+
self.assertEqual(list(sf.Matrix(0, 1)), [])
513+
self.assertEqual(list(sf.Matrix(1, 0)), [])
514+
self.assertEqual(list(sf.Matrix(0, 0)), [])
515+
516+
if symforce.get_symbolic_api() == "sympy":
517+
self.assertEqual(str(sf.Matrix(1, 0)), "Matrix(1, 0, [])")
518+
self.assertEqual(str(sf.Matrix(0, 1)), "Matrix(0, 1, [])")
519+
self.assertEqual(str(sf.Matrix(0, 0)), "Matrix(0, 0, [])")
520+
else:
521+
self.assertEqual(str(sf.Matrix(1, 0)), "[]\n")
522+
self.assertEqual(str(sf.Matrix(0, 1)), "")
523+
self.assertEqual(str(sf.Matrix(0, 0)), "")
524+
492525

493526
if __name__ == "__main__":
494527
TestCase.main()

0 commit comments

Comments
 (0)