Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compatibility with numpy 2 #382

Merged
merged 4 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
## Changelog

### 3.0.2 - 27/Jul/2024

- Adds compatibility with numpy v2 by replacing deprecated types (thanks [@MartinCapraro](https://github.com/MartinCapraro))

### 3.0.1 - 23/Jul/2024

- Temporarily require numpy<2
Expand Down
8 changes: 4 additions & 4 deletions cxroots/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def __init__(self, segments: list[ComplexPathType]):
def __call__(self, t: float) -> complex: ...

@overload
def __call__(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
def __call__(self, t: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]: ...

def __call__(
self, t: float | npt.NDArray[np.float_]
) -> complex | npt.NDArray[np.complex_]:
self, t: float | npt.NDArray[np.float64]
) -> complex | npt.NDArray[np.complex128]:
r"""
The point on the contour corresponding the value of the
parameter t.
Expand All @@ -75,7 +75,7 @@ def __call__(
>>> c(0) == c(1)
True
"""
t = np.array(t, dtype=np.float_)
t = np.array(t, dtype=np.float64)
num_segments = len(self.segments)
segment_index = np.array(num_segments * t, dtype=int)
segment_index = np.mod(segment_index, num_segments)
Expand Down
2 changes: 1 addition & 1 deletion cxroots/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def central_diff(

@overload
def df(
z: npt.NDArray[np.complex_] | npt.NDArray[np.float_],
z: npt.NDArray[np.complex128] | npt.NDArray[np.float64],
) -> ComplexScalarOrArray: ...

@overload
Expand Down
36 changes: 18 additions & 18 deletions cxroots/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def __init__(self):
def __call__(self, t: float) -> complex: ...

@overload
def __call__(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
def __call__(self, t: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]: ...

def __call__(
self, t: float | npt.NDArray[np.float_]
) -> complex | npt.NDArray[np.complex_]:
self, t: float | npt.NDArray[np.float64]
) -> complex | npt.NDArray[np.complex128]:
r"""
The parameterization of the path in the varaible :math:`t\in[0,1]`.

Expand All @@ -51,11 +51,11 @@ def __call__(
def dzdt(self, t: float) -> complex: ...

@overload
def dzdt(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
def dzdt(self, t: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]: ...

def dzdt(
self, t: float | npt.NDArray[np.float_]
) -> complex | npt.NDArray[np.complex_]:
self, t: float | npt.NDArray[np.float64]
) -> complex | npt.NDArray[np.complex128]:
"""
The derivative of the parameterised curve in the complex plane, z, with
respect to the parameterization parameter, t.
Expand Down Expand Up @@ -341,11 +341,11 @@ def __str__(self):
def __call__(self, t: float) -> complex: ...

@overload
def __call__(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
def __call__(self, t: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]: ...

def __call__(
self, t: float | npt.NDArray[np.float_]
) -> complex | npt.NDArray[np.complex_]:
self, t: float | npt.NDArray[np.float64]
) -> complex | npt.NDArray[np.complex128]:
r"""
The function :math:`z(t) = a + (b-a)t`.

Expand All @@ -365,11 +365,11 @@ def __call__(
def dzdt(self, t: float) -> complex: ...

@overload
def dzdt(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
def dzdt(self, t: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]: ...

def dzdt(
self, t: float | npt.NDArray[np.float_]
) -> complex | npt.NDArray[np.complex_]:
self, t: float | npt.NDArray[np.float64]
) -> complex | npt.NDArray[np.complex128]:
"""
The derivative of the parameterised curve in the complex plane, z, with
respect to the parameterization parameter, t.
Expand Down Expand Up @@ -434,11 +434,11 @@ def __str__(self):
def __call__(self, t: float) -> complex: ...

@overload
def __call__(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
def __call__(self, t: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]: ...

def __call__(
self, t: float | npt.NDArray[np.float_]
) -> complex | npt.NDArray[np.complex_]:
self, t: float | npt.NDArray[np.float64]
) -> complex | npt.NDArray[np.complex128]:
r"""
The function :math:`z(t) = R e^{i(t_0 + t dt)} + z_0`.

Expand All @@ -458,11 +458,11 @@ def __call__(
def dzdt(self, t: float) -> complex: ...

@overload
def dzdt(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
def dzdt(self, t: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]: ...

def dzdt(
self, t: float | npt.NDArray[np.float_]
) -> complex | npt.NDArray[np.complex_]:
self, t: float | npt.NDArray[np.float64]
) -> complex | npt.NDArray[np.complex128]:
"""
The derivative of the parameterised curve in the complex plane, z, with
respect to the parameterization parameter, t.
Expand Down
4 changes: 2 additions & 2 deletions cxroots/root_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def func(z: complex | float) -> complex: ...

@overload
def func(
z: npt.NDArray[np.complex_] | npt.NDArray[np.float_],
) -> npt.NDArray[np.complex_] | complex: ...
z: npt.NDArray[np.complex128] | npt.NDArray[np.float64],
) -> npt.NDArray[np.complex128] | complex: ...

def func(z: ScalarOrArray) -> ComplexScalarOrArray:
return phi(1)(z) * phi(i)(z)
Expand Down
4 changes: 2 additions & 2 deletions cxroots/root_counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def integrand_func(z: complex | float) -> complex: ...

@overload
def integrand_func(
z: npt.NDArray[np.complex_] | npt.NDArray[np.float_],
) -> npt.NDArray[np.complex_] | complex: ...
z: npt.NDArray[np.complex128] | npt.NDArray[np.float64],
) -> npt.NDArray[np.complex128] | complex: ...

def integrand_func(z: ScalarOrArray) -> ComplexScalarOrArray:
return phi(z) * psi(z) * (df(z) / f(z)) / (2j * pi)
Expand Down
8 changes: 5 additions & 3 deletions cxroots/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

IntegrationMethod = Literal["quad", "romb"]
Color = Union[str, tuple[float, float, float], tuple[float, float, float, float]] # noqa: UP007
ScalarOrArray = Union[complex, float, npt.NDArray[np.complex_], npt.NDArray[np.float_]] # noqa: UP007
ComplexScalarOrArray = Union[complex, npt.NDArray[np.complex_]] # noqa: UP007
ScalarOrArray = Union[ # noqa: UP007
complex, float, npt.NDArray[np.complex128], npt.NDArray[np.float64]
]
ComplexScalarOrArray = Union[complex, npt.NDArray[np.complex128]] # noqa: UP007


class AnalyticFunc(Protocol):
Expand All @@ -15,7 +17,7 @@ def __call__(self, z: complex | float) -> complex: ...

@overload
def __call__(
self, z: npt.NDArray[np.complex_] | npt.NDArray[np.float_]
self, z: npt.NDArray[np.complex128] | npt.NDArray[np.float64]
) -> ComplexScalarOrArray:
# Note that the function may return a scalar in this case if, for example,
# it's a constant function
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#

scipy==1.14.0
numpy==1.26.4
numpy==2.0.1
numpydoc==1.7.0
mpmath==1.3.0
rich==13.7.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
python_requires=">=3.10",
setup_requires=["pytest-runner"],
install_requires=[
"numpy<2",
"numpy",
"scipy",
"numpydoc",
"mpmath",
Expand Down
Loading