Skip to content

Commit 74409a1

Browse files
committed
Add stubs for major protocols
1 parent 1d187c1 commit 74409a1

File tree

6 files changed

+89
-6
lines changed

6 files changed

+89
-6
lines changed

src/array_api_typing/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
"""Static typing support for the array API standard."""
22

3-
__all__ = ["HasArrayNamespace", "__version__", "__version_tuple__"]
3+
__all__ = [
4+
"Array",
5+
"DType",
6+
"Device",
7+
"HasArrayNamespace",
8+
"Namespace",
9+
"__version__",
10+
"__version_tuple__",
11+
]
412

5-
from ._namespace import HasArrayNamespace
13+
from ._array import Array
14+
from ._device import Device
15+
from ._dtype import DType
16+
from ._namespace import HasArrayNamespace, Namespace
617
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Static typing support for the array API standard."""
2+
3+
from typing import Protocol, final
4+
5+
6+
@final
7+
class Array(Protocol):
8+
pass

src/array_api_typing/_device.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Static typing support for the array API standard."""
2+
3+
from typing import Protocol, final
4+
5+
6+
@final
7+
class Device(Protocol):
8+
pass

src/array_api_typing/_dtype.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Static typing support for the array API standard."""
2+
3+
from typing import Protocol, final
4+
5+
6+
@final
7+
class DType(Protocol):
8+
pass

src/array_api_typing/_namespace.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,47 @@
11
"""Static typing support for the array API standard."""
22

3-
__all__ = ["HasArrayNamespace"]
4-
5-
from types import ModuleType
63
from typing import Protocol, final
74
from typing_extensions import TypeVar
85

9-
T = TypeVar("T", bound=object, default=ModuleType) # PEP 696 default
6+
from ._array import Array
7+
from ._device import Device
8+
from ._dtype import DType
9+
from ._simple import NestedSequence, SupportsBufferProtocol
10+
11+
A = TypeVar("A", bound=Array, default=Array) # PEP 696 default
12+
13+
14+
class Namespace(Protocol[A]):
15+
"""An Array API namespace."""
16+
17+
def asarray(
18+
self,
19+
obj: (
20+
Array
21+
| complex
22+
| NestedSequence[bool | int | float | complex]
23+
| SupportsBufferProtocol
24+
),
25+
/,
26+
*,
27+
dtype: DType | None = None,
28+
device: Device | None = None,
29+
copy: bool | None = None,
30+
**kwargs: object,
31+
) -> A: ...
32+
33+
def astype(
34+
self,
35+
x: A,
36+
dtype: DType,
37+
/,
38+
*,
39+
copy: bool = True,
40+
device: Device | None = None,
41+
) -> A: ...
42+
43+
44+
T = TypeVar("T", bound=Namespace, default=Namespace) # PEP 696 default
1045

1146

1247
@final

src/array_api_typing/_simple.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Static typing support for the array API standard."""
2+
3+
from typing import Any, Protocol, TypeAlias, TypeVar
4+
5+
_T_co = TypeVar("_T_co", covariant=True)
6+
7+
8+
class NestedSequence(Protocol[_T_co]):
9+
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
10+
def __len__(self, /) -> int: ...
11+
12+
13+
SupportsBufferProtocol: TypeAlias = Any

0 commit comments

Comments
 (0)