Skip to content

Commit ede48e3

Browse files
add check for valid dimension name before creating new variable (#360)
* fix: unsupported dimension names (#359) Add function to ensure that added variables do not use unvalid dimension names. * Update linopy/model.py More condensed check of dim_names Co-authored-by: Lukas Trippe <lkstrp@pm.me> * Update linopy/model.py remove unnecessary else block Co-authored-by: Lukas Trippe <lkstrp@pm.me> * Update linopy/model.py make check_valid_dim_names private Co-authored-by: Lukas Trippe <lkstrp@pm.me> * Update linopy/model.py make check_valid_dim_names private Co-authored-by: Lukas Trippe <lkstrp@pm.me> --------- Co-authored-by: Lukas Trippe <lkstrp@pm.me>
1 parent 30ef9ea commit ede48e3

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

linopy/model.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,32 @@ def check_force_dim_names(self, ds: DataArray | Dataset) -> None:
367367
else:
368368
return
369369

370+
def _check_valid_dim_names(self, ds: DataArray | Dataset) -> None:
371+
"""
372+
Ensure that the added data does not lead to a naming conflict.
373+
374+
Parameters
375+
----------
376+
model : linopy.Model
377+
ds : xr.DataArray/Variable/LinearExpression
378+
Data that should be added to the model.
379+
380+
Raises
381+
------
382+
ValueError
383+
If broadcasted data leads to unsupported dimension names.
384+
385+
Returns
386+
-------
387+
None.
388+
"""
389+
unsupported_dim_names = ["labels", "coeffs", "vars", "sign", "rhs"]
390+
if any(dim in unsupported_dim_names for dim in ds.dims):
391+
raise ValueError(
392+
"Added data contains unsupported dimension names. "
393+
"Dimensions cannot be named 'labels', 'coeffs', 'vars', 'sign' or 'rhs'."
394+
)
395+
370396
def add_variables(
371397
self,
372398
lower: Any = -inf,
@@ -474,6 +500,7 @@ def add_variables(
474500
)
475501
(data,) = xr.broadcast(data)
476502
self.check_force_dim_names(data)
503+
self._check_valid_dim_names(data)
477504

478505
if mask is not None:
479506
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)

test/test_variable_assignment.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,15 @@ def test_variable_assignment_without_coords_and_dims_names():
131131
assert x.dims == ("i", "j")
132132

133133

134+
def test_variable_assignment_without_coords_and_invalid_dims_names():
135+
# setting bounds without explicit coords
136+
m = Model()
137+
lower = np.zeros((10, 10))
138+
upper = np.ones((10, 10))
139+
with pytest.raises(ValueError):
140+
m.add_variables(lower, upper, name="x", dims=["sign", "j"])
141+
142+
134143
def test_variable_assignment_without_coords_in_bounds():
135144
# setting bounds without explicit coords
136145
m = Model()
@@ -141,6 +150,15 @@ def test_variable_assignment_without_coords_in_bounds():
141150
assert x.dims == ("i", "j")
142151

143152

153+
def test_variable_assignment_without_coords_in_bounds_invalid_dims_names():
154+
# setting bounds without explicit coords
155+
m = Model()
156+
lower = xr.DataArray(np.zeros((10, 10)), dims=["i", "sign"])
157+
upper = xr.DataArray(np.ones((10, 10)), dims=["i", "sign"])
158+
with pytest.raises(ValueError):
159+
m.add_variables(lower, upper, name="x")
160+
161+
144162
def test_variable_assignment_without_coords_pandas_types():
145163
# setting bounds without explicit coords
146164
m = Model()

0 commit comments

Comments
 (0)