Skip to content

Commit 81784a9

Browse files
Merge pull request #309 from oscarbenjamin/pr_mpoly_p
Fix typing protocol for multivariate polynomials
2 parents 6c9d799 + c937bdc commit 81784a9

File tree

7 files changed

+293
-184
lines changed

7 files changed

+293
-184
lines changed

src/flint/flint_base/flint_base.pyi

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,22 @@ class flint_poly(flint_elem, Generic[Telem]):
5858
def length(self) -> int: ...
5959
def degree(self) -> int: ...
6060
def coeffs(self) -> list[Telem]: ...
61-
@overload
6261
def __call__(self, other: Telem | int, /) -> Telem: ...
63-
@overload
64-
def __call__(self, other: Self, /) -> Self: ...
6562
def __pos__(self) -> Self: ...
6663
def __neg__(self) -> Self: ...
67-
def __add__(self, other: Telem | int | Self, /) -> Self: ...
64+
def __add__(self, other: Telem | int, /) -> Self: ...
6865
def __radd__(self, other: Telem | int, /) -> Self: ...
69-
def __sub__(self, other: Telem | int | Self, /) -> Self: ...
66+
def __sub__(self, other: Telem | int, /) -> Self: ...
7067
def __rsub__(self, other: Telem | int, /) -> Self: ...
71-
def __mul__(self, other: Telem | int | Self, /) -> Self: ...
68+
def __mul__(self, other: Telem | int, /) -> Self: ...
7269
def __rmul__(self, other: Telem | int, /) -> Self: ...
73-
def __truediv__(self, other: Telem | int | Self, /) -> Self: ...
70+
def __truediv__(self, other: Telem | int, /) -> Self: ...
7471
def __rtruediv__(self, other: Telem | int, /) -> Self: ...
75-
def __floordiv__(self, other: Telem | int | Self, /) -> Self: ...
72+
def __floordiv__(self, other: Telem | int, /) -> Self: ...
7673
def __rfloordiv__(self, other: Telem | int, /) -> Self: ...
77-
def __mod__(self, other: Telem | int | Self, /) -> Self: ...
74+
def __mod__(self, other: Telem | int, /) -> Self: ...
7875
def __rmod__(self, other: Telem | int, /) -> Self: ...
79-
def __divmod__(self, other: Telem | int | Self, /) -> tuple[Self, Self]: ...
76+
def __divmod__(self, other: Telem | int, /) -> tuple[Self, Self]: ...
8077
def __rdivmod__(self, other: Telem | int, /) -> tuple[Self, Self]: ...
8178
def __pow__(self, other: int, /) -> Self: ...
8279
def is_zero(self) -> bool: ...
@@ -130,41 +127,38 @@ class flint_mpoly(flint_elem, Generic[Tctx, Telem, Telem_coerce]):
130127
def coeffs(self) -> list[Telem]: ...
131128
def __pos__(self) -> Self: ...
132129
def __neg__(self) -> Self: ...
133-
def __add__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
130+
def __add__(self, other: Telem | Telem_coerce | int) -> Self: ...
134131
def __radd__(self, other: Telem | Telem_coerce | int) -> Self: ...
135-
def __sub__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
132+
def __sub__(self, other: Telem | Telem_coerce | int) -> Self: ...
136133
def __rsub__(self, other: Telem | Telem_coerce | int) -> Self: ...
137-
def __mul__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
134+
def __mul__(self, other: Telem | Telem_coerce | int) -> Self: ...
138135
def __rmul__(self, other: Telem | Telem_coerce | int) -> Self: ...
139-
def __truediv__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
136+
def __truediv__(self, other: Telem | Telem_coerce | int) -> Self: ...
140137
def __rtruediv__(self, other: Telem | Telem_coerce | int) -> Self: ...
141-
def __floordiv__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
138+
def __floordiv__(self, other: Telem | Telem_coerce | int) -> Self: ...
142139
def __rfloordiv__(self, other: Telem | Telem_coerce | int) -> Self: ...
143-
def __mod__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
140+
def __mod__(self, other: Telem | Telem_coerce | int) -> Self: ...
144141
def __rmod__(self, other: Telem | Telem_coerce | int) -> Self: ...
145142
def __divmod__(
146-
self, other: Self | Telem | Telem_coerce | int
143+
self, other: Telem | Telem_coerce | int
147144
) -> tuple[Self, Self]: ...
148145
def __rdivmod__(self, other: Telem | Telem_coerce | int) -> tuple[Self, Self]: ...
149146
def __pow__(self, other: Telem | Telem_coerce | int) -> Self: ...
150147
def __rpow__(self, other: Telem | Telem_coerce | int) -> Self: ...
151148
def iadd(self, other: Telem | Telem_coerce | int) -> None: ...
152149
def isub(self, other: Telem | Telem_coerce | int) -> None: ...
153150
def imul(self, other: Telem | Telem_coerce | int) -> None: ...
154-
def gcd(self, other: Self) -> Self: ...
155151
def term_content(self) -> Self: ...
156152
def factor(self) -> tuple[Telem, Sequence[tuple[Self, int]]]: ...
157153
def factor_squarefree(self) -> tuple[Telem, Sequence[tuple[Self, int]]]: ...
158154
def sqrt(self) -> Self: ...
159-
def resultant(self, other: Self, var: _str | int) -> Self: ...
160155
def discriminant(self, var: _str | int) -> Self: ...
161156
def deflation_index(self) -> tuple[list[int], list[int]]: ...
162157
def deflation(self) -> tuple[Self, list[int]]: ...
163158
def deflation_monom(self) -> tuple[Self, list[int], Self]: ...
164159
def inflate(self, N: list[int]) -> Self: ...
165160
def deflate(self, N: list[int]) -> Self: ...
166161
def subs(self, mapping: dict[_str | int, Telem | Telem_coerce | int]) -> Self: ...
167-
def compose(self, *args: Self, ctx: Tctx | None = None) -> Self: ...
168162
def __call__(self, *args: Telem | Telem_coerce) -> Telem: ...
169163
def derivative(self, var: _str | int) -> Self: ...
170164
def unused_gens(self) -> tuple[_str, ...]: ...
@@ -193,14 +187,14 @@ class flint_mpoly_context(flint_elem, Generic[Tmpoly, Telem, Telem_coerce]):
193187
@classmethod
194188
def from_context(
195189
cls,
196-
ctx: Sctx,
190+
ctx: flint_mpoly_context,
197191
names: str | Iterable[str | tuple[str, int]] | tuple[str, int] | None = None,
198192
ordering: Ordering | str = Ordering.lex,
199-
) -> Sctx: ...
193+
) -> Self: ...
200194

201195
class flint_mod_mpoly_context(flint_mpoly_context[Tmpoly, Telem, Telem_coerce]):
202196
@abstractmethod
203-
def modulus(self) -> int: ...
197+
def modulus(self): ...
204198

205199
class flint_series(flint_elem, Generic[Telem]):
206200
"""Base class for power series."""

src/flint/test/test_all.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def raises(f, exception) -> bool:
3232
Tscalar = TypeVar('Tscalar', bound=flint_base.flint_scalar)
3333
Tscalar_co = TypeVar('Tscalar_co', bound=flint_base.flint_scalar, covariant=True)
3434
Tmpoly = TypeVar('Tmpoly', bound=flint_base.flint_mpoly)
35+
Tmpoly_p = TypeVar('Tmpoly_p', bound=typ.mpoly_p)
3536
Tmpolyctx_co = TypeVar('Tmpolyctx_co', bound=flint_base.flint_mpoly_context, covariant=True)
37+
Tmpolyctx_p_co = TypeVar('Tmpolyctx_p_co', bound=typ.mpoly_context_p, covariant=True)
3638

3739

3840
_default_ctx_string = """\
@@ -2620,7 +2622,7 @@ def _all_polys() -> list[tuple[Any, Any, bool, flint.fmpz]]:
26202622

26212623

26222624
Tpoly = TypeVar("Tpoly", bound=typ.epoly_p)
2623-
Tc = TypeVar("Tc", bound=flint_base.flint_scalar)
2625+
Tc = TypeVar("Tc", bound=typ.scalar_p)
26242626
TS = Callable[[Tc | int], Tc]
26252627
TP = Callable[[Tpoly | Sequence[Tc | int] | Tc | int], Tpoly]
26262628
_PolyTestCase = tuple[TP[Tpoly,Tc], TS[Tc], bool, flint.fmpz]
@@ -3093,17 +3095,17 @@ def _all_mpolys(): # -> _all_mpolys_type:
30933095
)
30943096

30953097

3096-
class _GetMPolyCtx(Protocol[Tmpolyctx_co]):
3098+
class _GetMPolyCtx(Protocol[Tmpolyctx_p_co]):
30973099
def __call__(self,
30983100
names: Iterable[str | tuple[str, int]] | tuple[str, int],
30993101
ordering: str | flint.Ordering = "lex"
3100-
) -> Tmpolyctx_co:
3102+
) -> Tmpolyctx_p_co:
31013103
...
31023104

31033105

31043106
_MPolyTestCase = tuple[
3105-
type[Tmpoly],
3106-
_GetMPolyCtx['flint_base.flint_mpoly_context[Tmpoly, Tscalar, int]'],
3107+
type[Tmpoly_p],
3108+
_GetMPolyCtx[typ.mpoly_context_p[Tmpoly_p, Tscalar]],
31073109
Callable[[int], Tscalar],
31083110
bool,
31093111
flint.fmpz
@@ -3176,7 +3178,7 @@ def wrapper():
31763178

31773179

31783180
@all_mpolys
3179-
def test_mpolys_constructor(args: _MPolyTestCase[Tmpoly, Tscalar]) -> None:
3181+
def test_mpolys_constructor(args: _MPolyTestCase[Tmpoly_p, Tscalar]) -> None:
31803182
P, get_context, S, _, _ = args
31813183

31823184
ctx = get_context(("x", 2))
@@ -3301,7 +3303,7 @@ def quick_poly():
33013303

33023304

33033305
@all_mpolys
3304-
def test_mpolys_properties(args: _MPolyTestCase[Tmpoly, Tscalar]) -> None:
3306+
def test_mpolys_properties(args: _MPolyTestCase[Tmpoly_p, Tscalar]) -> None:
33053307

33063308
P, get_context, S, is_field, characteristic = args
33073309

src/flint/types/fmpq_mpoly.pyi

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, Mapping
1+
from typing import Iterable, Mapping, Any
22
from ..flint_base.flint_base import flint_mpoly, flint_mpoly_context, Ordering
33
from .fmpz import fmpz
44
from .fmpq import fmpq
@@ -20,15 +20,32 @@ class fmpq_mpoly_ctx(flint_mpoly_context[fmpq_mpoly, fmpq, ifmpq]):
2020

2121
def nvars(self) -> int: ...
2222
def ordering(self) -> Ordering: ...
23-
24-
def gen(self, i: int) -> fmpq_mpoly: ...
25-
def from_dict(self, d: Mapping[tuple[int, ...], ifmpq]) -> fmpq_mpoly: ...
26-
def constant(self, z: ifmpq) -> fmpq_mpoly: ...
27-
23+
def gen(self, i: int, /) -> fmpq_mpoly: ...
24+
def from_dict(self, d: Mapping[tuple[int, ...], ifmpq], /) -> fmpq_mpoly: ...
25+
def constant(self, q: ifmpq, /) -> fmpq_mpoly: ...
26+
def name(self, i: int, /) -> str: ...
27+
def names(self) -> tuple[str]: ...
28+
def gens(self) -> tuple[fmpq_mpoly, ...]: ...
29+
def variable_to_index(self, var: str, /) -> int: ...
30+
def term(
31+
self, coeff: ifmpq | None = None, exp_vec: Iterable[int] | None = None
32+
) -> fmpq_mpoly: ...
33+
def drop_gens(self, gens: Iterable[str | int], /) -> fmpq_mpoly_ctx: ...
34+
def append_gens(self, gens: Iterable[str | int], /) -> fmpq_mpoly_ctx: ...
35+
def infer_generator_mapping(
36+
self, ctx: flint_mpoly_context, /
37+
) -> dict[int, int]: ...
38+
@classmethod
39+
def from_context(
40+
cls,
41+
ctx: flint_mpoly_context[Any, Any, Any],
42+
names: str | Iterable[str | tuple[str, int]] | tuple[str, int] | None = None,
43+
ordering: Ordering | str = Ordering.lex,
44+
) -> fmpq_mpoly_ctx: ...
2845

2946
class fmpq_mpoly(flint_mpoly[fmpq_mpoly_ctx, fmpq, ifmpq]):
3047
def __init__(self,
31-
val: fmpq_mpoly | fmpz_mpoly | ifmpq | dict[tuple[int, ...], ifmpq] | _str = 0,
48+
val: fmpq_mpoly | fmpz_mpoly | ifmpq | Mapping[tuple[int, ...], ifmpq] | _str = 0,
3249
ctx: fmpq_mpoly_ctx | None = None
3350
) -> None: ...
3451

@@ -53,7 +70,7 @@ class fmpq_mpoly(flint_mpoly[fmpq_mpoly_ctx, fmpq, ifmpq]):
5370
def __getitem__(self, index: tuple[int, ...]) -> fmpq: ...
5471
def __setitem__(self, index: tuple[int, ...], coeff: ifmpq) -> None: ...
5572

56-
def subs(self, mapping: dict[_str | int, ifmpq]) -> fmpq_mpoly: ...
73+
def subs(self, mapping: Mapping[_str | int, ifmpq]) -> fmpq_mpoly: ...
5774
def compose(self, *args: fmpq_mpoly, ctx: fmpq_mpoly_ctx | None = None) -> fmpq_mpoly: ...
5875

5976
def __call__(self, *args: ifmpq) -> fmpq: ...

src/flint/types/fmpz_mod_mpoly.pyi

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
from typing import Iterable, Mapping
2-
from ..flint_base.flint_base import flint_mpoly, flint_mod_mpoly_context, Ordering
1+
from typing import Iterable, Mapping, Any
2+
from ..flint_base.flint_base import (
3+
flint_mpoly,
4+
flint_mpoly_context,
5+
flint_mod_mpoly_context,
6+
Ordering,
7+
)
38
from .fmpz import fmpz
49
from .fmpz_mod import fmpz_mod
510
from .fmpz_mpoly import fmpz_mpoly
@@ -17,21 +22,43 @@ class fmpz_mod_mpoly_ctx(flint_mod_mpoly_context[fmpz_mod_mpoly, fmpz_mod, ifmpz
1722
names: _str | Iterable[_str | tuple[_str, int]] | tuple[_str, int],
1823
ordering: Ordering | _str = Ordering.lex,
1924
*,
20-
modulus: ifmpz,
25+
modulus: int,
2126
) -> fmpz_mod_mpoly_ctx:
2227
...
2328

29+
def modulus(self) -> fmpz: ...
30+
2431
def nvars(self) -> int: ...
2532
def ordering(self) -> Ordering: ...
26-
def modulus(self) -> int: ...
27-
def gen(self, i: int, /) -> fmpz_mod_mpoly: ...
28-
def constant(self, z: ifmpz_mod) -> fmpz_mod_mpoly: ...
33+
34+
def gen(self, i: int) -> fmpz_mod_mpoly: ...
2935
def from_dict(self, d: Mapping[tuple[int, ...], ifmpz_mod]) -> fmpz_mod_mpoly: ...
36+
def constant(self, z: ifmpz_mod) -> fmpz_mod_mpoly: ...
37+
38+
def name(self, i: int, /) -> str: ...
39+
def names(self) -> tuple[str]: ...
40+
def gens(self) -> tuple[fmpz_mod_mpoly, ...]: ...
41+
def variable_to_index(self, var: str, /) -> int: ...
42+
def term(
43+
self, coeff: ifmpz_mod | None = None, exp_vec: Iterable[int] | None = None
44+
) -> fmpz_mod_mpoly: ...
45+
def drop_gens(self, gens: Iterable[str | int], /) -> fmpz_mod_mpoly_ctx: ...
46+
def append_gens(self, gens: Iterable[str | int], /) -> fmpz_mod_mpoly_ctx: ...
47+
def infer_generator_mapping(
48+
self, ctx: flint_mpoly_context, /
49+
) -> dict[int, int]: ...
50+
@classmethod
51+
def from_context(
52+
cls,
53+
ctx: flint_mpoly_context[Any, Any, Any],
54+
names: str | Iterable[str | tuple[str, int]] | tuple[str, int] | None = None,
55+
ordering: Ordering | str = Ordering.lex,
56+
) -> fmpz_mod_mpoly_ctx: ...
3057

3158

3259
class fmpz_mod_mpoly(flint_mpoly[fmpz_mod_mpoly_ctx, fmpz_mod, ifmpz_mod]):
3360
def __init__(self,
34-
val: fmpz_mod_mpoly | fmpz_mpoly | ifmpz_mod | dict[tuple[int, ...], ifmpz_mod] | _str = 0,
61+
val: fmpz_mod_mpoly | fmpz_mpoly | ifmpz_mod | Mapping[tuple[int, ...], ifmpz_mod] | _str = 0,
3562
ctx: fmpz_mod_mpoly_ctx | None = None
3663
) -> None: ...
3764

@@ -56,7 +83,7 @@ class fmpz_mod_mpoly(flint_mpoly[fmpz_mod_mpoly_ctx, fmpz_mod, ifmpz_mod]):
5683
def __getitem__(self, index: tuple[int, ...]) -> fmpz_mod: ...
5784
def __setitem__(self, index: tuple[int, ...], coeff: ifmpz_mod) -> None: ...
5885

59-
def subs(self, mapping: dict[_str | int, ifmpz_mod]) -> fmpz_mod_mpoly: ...
86+
def subs(self, mapping: Mapping[_str | int, ifmpz_mod]) -> fmpz_mod_mpoly: ...
6087
def compose(self, *args: fmpz_mod_mpoly, ctx: fmpz_mod_mpoly_ctx | None = None) -> fmpz_mod_mpoly: ...
6188

6289
def __call__(self, *args: ifmpz_mod) -> fmpz_mod: ...

src/flint/types/fmpz_mpoly.pyi

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, Iterator, Mapping
1+
from typing import Iterable, Iterator, Mapping, Any
22

33
from ..flint_base.flint_base import flint_mpoly, flint_mpoly_context, Ordering
44
from .fmpz import fmpz
@@ -18,16 +18,33 @@ class fmpz_mpoly_ctx(flint_mpoly_context[fmpz_mpoly, fmpz, ifmpz]):
1818

1919
def nvars(self) -> int: ...
2020
def ordering(self) -> Ordering: ...
21-
22-
def gen(self, i: int) -> fmpz_mpoly: ...
23-
def from_dict(self, d: Mapping[tuple[int, ...], ifmpz]) -> fmpz_mpoly: ...
24-
def constant(self, z: ifmpz) -> fmpz_mpoly: ...
25-
21+
def gen(self, i: int, /) -> fmpz_mpoly: ...
22+
def from_dict(self, d: Mapping[tuple[int, ...], ifmpz], /) -> fmpz_mpoly: ...
23+
def constant(self, z: ifmpz, /) -> fmpz_mpoly: ...
24+
def name(self, i: int, /) -> str: ...
25+
def names(self) -> tuple[str]: ...
26+
def gens(self) -> tuple[fmpz_mpoly, ...]: ...
27+
def variable_to_index(self, var: str, /) -> int: ...
28+
def term(
29+
self, coeff: ifmpz | None = None, exp_vec: Iterable[int] | None = None
30+
) -> fmpz_mpoly: ...
31+
def drop_gens(self, gens: Iterable[str | int], /) -> fmpz_mpoly_ctx: ...
32+
def append_gens(self, gens: Iterable[str | int], /) -> fmpz_mpoly_ctx: ...
33+
def infer_generator_mapping(
34+
self, ctx: flint_mpoly_context, /
35+
) -> dict[int, int]: ...
36+
@classmethod
37+
def from_context(
38+
cls,
39+
ctx: flint_mpoly_context[Any, Any, Any],
40+
names: str | Iterable[str | tuple[str, int]] | tuple[str, int] | None = None,
41+
ordering: Ordering | str = Ordering.lex,
42+
) -> fmpz_mpoly_ctx: ...
2643

2744
class fmpz_mpoly(flint_mpoly[fmpz_mpoly_ctx, fmpz, ifmpz]):
2845

2946
def __init__(self,
30-
val: fmpz_mpoly | ifmpz | dict[tuple[int, ...], ifmpz] | _str = 0,
47+
val: fmpz_mpoly | ifmpz | Mapping[tuple[int, ...], ifmpz] | _str = 0,
3148
ctx: fmpz_mpoly_ctx | None = None
3249
) -> None: ...
3350

@@ -52,7 +69,7 @@ class fmpz_mpoly(flint_mpoly[fmpz_mpoly_ctx, fmpz, ifmpz]):
5269
def __getitem__(self, index: tuple[int, ...]) -> fmpz: ...
5370
def __setitem__(self, index: tuple[int, ...], coeff: ifmpz) -> None: ...
5471

55-
def subs(self, mapping: dict[_str | int, ifmpz]) -> fmpz_mpoly: ...
72+
def subs(self, mapping: Mapping[_str | int, ifmpz]) -> fmpz_mpoly: ...
5673
def compose(self, *args: fmpz_mpoly, ctx: fmpz_mpoly_ctx | None = None) -> fmpz_mpoly: ...
5774

5875
def __call__(self, *args: ifmpz) -> fmpz: ...

0 commit comments

Comments
 (0)