Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Release Notes
Upcoming Version
----------------

* Fix warning when multiplying variables with a ``pd.Series`` that has a time-zone-aware index
* Fix docs (pick highs solver)
* Add the `sphinx-copybutton` to the documentation
* Add ``auto_mask`` parameter to ``Model`` class that automatically masks variables and constraints where bounds, coefficients, or RHS values contain NaN. This eliminates the need to manually create mask arrays when working with sparse or incomplete data.
Expand All @@ -20,7 +21,6 @@ Version 0.6.0
--------------

**Features**

* Add ``mock_solve`` option to ``Model.solve()`` for quick testing without actual solving
* Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()``
* Add ``simplify`` method to ``LinearExpression`` to combine duplicate terms
Expand Down
70 changes: 67 additions & 3 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from functools import partial, reduce, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypeVar, overload
from warnings import warn

import numpy as np
import pandas as pd
import polars as pl
import xarray as xr
from numpy import arange, signedinteger
from xarray import DataArray, Dataset, apply_ufunc, broadcast
from xarray import align as xr_align
Expand Down Expand Up @@ -45,6 +46,48 @@
from linopy.variables import Variable


class CoordAlignWarning(UserWarning):
    """Warning category emitted when coordinates of combined objects are not aligned."""


class TimezoneAlignError(ValueError):
    """Error raised when timezone information across datetime coordinates is not aligned."""


# Generic placeholders for decorators in this module: ``P`` captures the
# parameters of a wrapped callable and ``R`` its return type.
P = ParamSpec("P")
R = TypeVar("R")


class CatchDatetimeTypeError:
"""Context manager that catches datetime-related TypeErrors and re-raises as TimezoneAlignError."""

def __enter__(self) -> CatchDatetimeTypeError:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> Literal[False]:
if exc_type is TypeError and exc_val is not None:
if "Cannot interpret 'datetime" in str(exc_val):
raise TimezoneAlignError(
"Timezone information across datetime coordinates not aligned."
) from exc_val
return False


def catch_datetime_type_error_and_re_raise(func: Callable[P, R]) -> Callable[P, R]:
    """
    Decorate ``func`` so datetime-related TypeErrors surface as TimezoneAlignError.

    The returned callable executes ``func`` inside ``CatchDatetimeTypeError``
    and passes its result through; name, docstring and other metadata of
    ``func`` are carried over to the wrapper.
    """

    def guarded(*args: P.args, **kwargs: P.kwargs) -> R:
        with CatchDatetimeTypeError():
            return func(*args, **kwargs)

    return wraps(func)(guarded)


def set_int_index(series: pd.Series) -> pd.Series:
"""
Convert string index to int index.
Expand Down Expand Up @@ -128,6 +171,21 @@ def get_from_iterable(lst: DimsLike | None, index: int) -> Any | None:
return lst[index] if 0 <= index < len(lst) else None


def try_to_convert_to_pd_datetime_index(
    coord: xr.DataArray | Sequence | pd.Index | Any,
) -> pd.DatetimeIndex | xr.DataArray | Sequence | pd.Index | Any:
    """
    Best-effort conversion of a coordinate to a ``pd.DatetimeIndex``.

    Parameters
    ----------
    coord : xr.DataArray | Sequence | pd.Index | Any
        Coordinate values to convert.

    Returns
    -------
    pd.DatetimeIndex | xr.DataArray | Sequence | pd.Index | Any
        ``coord`` itself when it already is a ``pd.DatetimeIndex``; the
        converted index when the conversion succeeds; otherwise ``coord``
        unchanged.
    """
    if isinstance(coord, pd.DatetimeIndex):
        return coord
    try:
        if isinstance(coord, xr.DataArray):
            index = coord.to_index()
            # Explicit check instead of ``assert``: assertions are stripped
            # under ``python -O``, which would let a non-datetime index
            # escape instead of falling back to the original coordinate.
            return index if isinstance(index, pd.DatetimeIndex) else coord
        return pd.DatetimeIndex(coord)
    except Exception:
        # Deliberately broad: the conversion is opportunistic, so any failure
        # simply means "leave the coordinate as it is".
        return coord


def pandas_to_dataarray(
arr: pd.DataFrame | pd.Series,
coords: CoordsLike | None = None,
Expand Down Expand Up @@ -168,7 +226,10 @@ def pandas_to_dataarray(
shared_dims = set(pandas_coords.keys()) & set(coords.keys())
non_aligned = []
for dim in shared_dims:
pd_coord = pandas_coords[dim]
coord = coords[dim]
if isinstance(pd_coord, pd.DatetimeIndex):
coord = try_to_convert_to_pd_datetime_index(coord)
if not isinstance(coord, pd.Index):
coord = pd.Index(coord)
if not pandas_coords[dim].equals(coord):
Expand All @@ -178,7 +239,8 @@ def pandas_to_dataarray(
f"coords for dimension(s) {non_aligned} is not aligned with the pandas object. "
"Previously, the indexes of the pandas were ignored and overwritten in "
"these cases. Now, the pandas object's coordinates are taken considered"
" for alignment."
" for alignment.",
CoordAlignWarning,
)

return DataArray(arr, coords=None, dims=dims, **kwargs)
Expand Down Expand Up @@ -468,6 +530,7 @@ def maybe_group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
return df.select(keys + ["coeffs"] + rest)


@catch_datetime_type_error_and_re_raise
def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
"""
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.
Expand All @@ -477,14 +540,15 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
except ValueError:
warn(
"Coordinates across variables not equal. Perform outer join.",
UserWarning,
CoordAlignWarning,
)
arrs = xr_align(*dataarrays, join="outer")
if integer_dtype:
arrs = tuple([ds.fillna(-1).astype(int) for ds in arrs])
return Dataset({ds.name: ds for ds in arrs})


@catch_datetime_type_error_and_re_raise
def assign_multiindex_safe(ds: Dataset, **fields: Any) -> Dataset:
"""
Assign a field to a xarray Dataset while being safe against warnings about multiindex corruption.
Expand Down
7 changes: 5 additions & 2 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
LocIndexer,
as_dataarray,
assign_multiindex_safe,
catch_datetime_type_error_and_re_raise,
check_common_keys_values,
check_has_nulls,
check_has_nulls_polars,
Expand Down Expand Up @@ -506,6 +507,7 @@ def __neg__(self: GenericExpression) -> GenericExpression:
"""
return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const)

@catch_datetime_type_error_and_re_raise
def _multiply_by_linear_expression(
self, other: LinearExpression | ScalarLinearExpression
) -> LinearExpression | QuadraticExpression:
Expand Down Expand Up @@ -533,6 +535,7 @@ def _multiply_by_linear_expression(
res = res + self.reset_const() * other.const
return res

@catch_datetime_type_error_and_re_raise
def _multiply_by_constant(
self: GenericExpression, other: ConstantLike
) -> GenericExpression:
Expand Down Expand Up @@ -1456,7 +1459,7 @@ def to_polars(self) -> pl.DataFrame:

The resulting DataFrame represents a long table format of the all
non-masked expressions with non-zero coefficients. It contains the
columns `coeffs`, `vars`, `const`. The coeffs and vars columns will be null if the expression is constant.
columns `vars`, `coeffs`, `const`. The coeffs and vars columns will be null if the expression is constant.

Returns
-------
Expand All @@ -1472,7 +1475,7 @@ def to_polars(self) -> pl.DataFrame:
df = filter_nulls_polars(df)
df = maybe_group_terms_polars(df)
check_has_nulls_polars(df, name=self.type)
return df
return df.select(["vars", "coeffs", "const"])

def simplify(self) -> LinearExpression:
"""
Expand Down
2 changes: 2 additions & 0 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
LocIndexer,
as_dataarray,
assign_multiindex_safe,
catch_datetime_type_error_and_re_raise,
check_has_nulls,
check_has_nulls_polars,
filter_nulls_polars,
Expand Down Expand Up @@ -296,6 +297,7 @@ def loc(self) -> LocIndexer:
def to_pandas(self) -> pd.Series:
return self.labels.to_pandas()

@catch_datetime_type_error_and_re_raise
def to_linexpr(
self,
coefficient: ConstantLike = 1,
Expand Down
71 changes: 69 additions & 2 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
@author: fabian
"""

from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np
import pandas as pd
import polars as pl
Expand All @@ -16,6 +19,7 @@

from linopy import LinearExpression, Model, Variable
from linopy.common import (
CoordAlignWarning,
align,
as_dataarray,
assign_multiindex_safe,
Expand All @@ -27,6 +31,8 @@
)
from linopy.testing import assert_linequal, assert_varequal

# UTC timezone object used by the tests in this module to build tz-aware indexes.
UTC = ZoneInfo("UTC")


def test_as_dataarray_with_series_dims_default() -> None:
target_dim = "dim_0"
Expand Down Expand Up @@ -74,6 +80,67 @@ def test_as_dataarray_with_series_dims_priority() -> None:
assert list(da.coords[target_dim].values) == target_index


def test_as_datarray_with_tz_aware_series_index() -> None:
    """A tz-aware Series index is kept by ``as_dataarray`` for each kind of ``coords``."""
    tz_aware_times = pd.date_range(
        datetime(2025, 1, 1), periods=4, freq="15min", tz=UTC, name="time"
    )
    integer_times = pd.Index([0, 1, 2, 3], name="time")
    series = pd.Series(1.0, index=tz_aware_times)

    def assert_time_coord_kept(converted: xr.DataArray) -> None:
        assert tz_aware_times.equals(converted.coords["time"].to_index())

    # coords taken from a DataArray that carries the same tz-aware index
    matching = xr.DataArray([0, 1, 2, 3], coords=[tz_aware_times])
    assert_time_coord_kept(as_dataarray(arr=series, coords=matching.coords))

    # coords taken from a DataArray with a different (integer) index: warns
    mismatching = xr.DataArray([0, 1, 2, 3], coords=[integer_times])
    with pytest.warns(CoordAlignWarning):
        converted = as_dataarray(arr=series, coords=mismatching.coords)
    assert_time_coord_kept(converted)

    # coords given as plain mappings, first matching and then non-matching
    for mapping in ({"time": tz_aware_times}, {"time": [0, 1, 2, 3]}):
        assert_time_coord_kept(as_dataarray(arr=series, coords=mapping))


def test_as_datarray_with_tz_aware_dataframe_columns_index() -> None:
    """Tz-aware DataFrame columns are kept by ``as_dataarray`` for each kind of ``coords``."""
    tz_aware_times = pd.date_range(
        datetime(2025, 1, 1), periods=4, freq="15min", tz=UTC, name="time"
    )
    integer_times = pd.Index([0, 1, 2, 3], name="time")
    row_index = pd.Index([0, 1, 2, 3], name="x")
    frame = pd.DataFrame(1.0, index=row_index, columns=tz_aware_times)

    def assert_time_coord_kept(converted: xr.DataArray) -> None:
        assert tz_aware_times.equals(converted.coords["time"].to_index())

    # coords taken from a DataArray that carries the same tz-aware index
    matching = xr.DataArray([0, 1, 2, 3], coords=[tz_aware_times])
    assert_time_coord_kept(as_dataarray(arr=frame, coords=matching.coords))

    # coords taken from a DataArray with a different (integer) index: warns
    mismatching = xr.DataArray([0, 1, 2, 3], coords=[integer_times])
    with pytest.warns(CoordAlignWarning):
        converted = as_dataarray(arr=frame, coords=mismatching.coords)
    assert_time_coord_kept(converted)

    # coords given as plain mappings, first matching and then non-matching
    for mapping in ({"time": tz_aware_times}, {"time": [0, 1, 2, 3]}):
        assert_time_coord_kept(as_dataarray(arr=frame, coords=mapping))


def test_as_dataarray_with_series_dims_subset() -> None:
target_dim = "dim_0"
target_index = ["a", "b", "c"]
Expand All @@ -100,7 +167,7 @@ def test_as_dataarray_with_series_override_coords() -> None:
target_dim = "dim_0"
target_index = ["a", "b", "c"]
s = pd.Series([1, 2, 3], index=target_index)
with pytest.warns(UserWarning):
with pytest.warns(CoordAlignWarning):
da = as_dataarray(s, coords=[[1, 2, 3]])
assert isinstance(da, DataArray)
assert da.dims == (target_dim,)
Expand Down Expand Up @@ -219,7 +286,7 @@ def test_as_dataarray_dataframe_override_coords() -> None:
target_index = ["a", "b"]
target_columns = ["A", "B"]
df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns)
with pytest.warns(UserWarning):
with pytest.warns(CoordAlignWarning):
da = as_dataarray(df, coords=[[1, 2], [2, 3]])
assert isinstance(da, DataArray)
assert da.dims == target_dims
Expand Down
30 changes: 30 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from __future__ import annotations

from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np
import pandas as pd
import polars as pl
Expand All @@ -15,11 +18,14 @@
from xarray.testing import assert_equal

from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge
from linopy.common import TimezoneAlignError
from linopy.constants import HELPER_DIMS, TERM_DIM
from linopy.expressions import ScalarLinearExpression
from linopy.testing import assert_linequal, assert_quadequal
from linopy.variables import ScalarVariable

# UTC timezone object used by the tests in this module to build tz-aware indexes.
UTC = ZoneInfo("UTC")


@pytest.fixture
def m() -> Model:
Expand Down Expand Up @@ -1230,6 +1236,30 @@ def test_cumsum(m: Model, multiple: float) -> None:
cumsum.nterm == 2


def test_timezone_alignment_failure() -> None:
    """Multiplying a tz-aware expression by a tz-naive Series raises TimezoneAlignError."""

    def quarter_hours(tz: ZoneInfo | None) -> pd.DatetimeIndex:
        return pd.date_range(
            start=datetime(2025, 1, 1),
            freq="15min",
            periods=4,
            tz=tz,
            name="time",
        )

    aware_index = quarter_hours(UTC)
    naive_index = quarter_hours(None)

    expression = Model().add_variables(coords=[aware_index], name="var1") * 1.0
    naive_series = pd.Series(1.0, index=naive_index)

    # A dedicated TimezoneAlignError is expected here rather than a
    # "not implemented" error falsely claiming that these types cannot be
    # multiplied together.
    with pytest.raises(TimezoneAlignError):
        _ = expression * naive_series


def test_simplify_basic(x: Variable) -> None:
"""Test basic simplification with duplicate terms."""
expr = 2 * x + 3 * x + 1 * x
Expand Down
Loading