Skip to content

Array type checking fixes #317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 0 additions & 128 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -22633,14 +22633,6 @@
"lineCount": 1
}
},
{
"code": "reportUnusedFunction",
"range": {
"startColumn": 4,
"endColumn": 8,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
Expand Down Expand Up @@ -24307,30 +24299,6 @@
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 25,
"endColumn": 29,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 25,
"endColumn": 29,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 14,
"endColumn": 31,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
Expand All @@ -24339,70 +24307,6 @@
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 43,
"endColumn": 47,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 43,
"endColumn": 47,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 22,
"endColumn": 36,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 37,
"endColumn": 41,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 42,
"endColumn": 50,
"lineCount": 1
}
},
{
"code": "reportArgumentType",
"range": {
"startColumn": 22,
"endColumn": 37,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 38,
"endColumn": 42,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 43,
"endColumn": 51,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
Expand Down Expand Up @@ -24615,22 +24519,6 @@
"lineCount": 1
}
},
{
"code": "reportUnusedFunction",
"range": {
"startColumn": 4,
"endColumn": 28,
"lineCount": 1
}
},
{
"code": "reportUnusedFunction",
"range": {
"startColumn": 4,
"endColumn": 30,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
Expand Down Expand Up @@ -24663,14 +24551,6 @@
"lineCount": 1
}
},
{
"code": "reportUnusedFunction",
"range": {
"startColumn": 4,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportGeneralTypeIssues",
"range": {
Expand Down Expand Up @@ -24735,14 +24615,6 @@
"lineCount": 1
}
},
{
"code": "reportUnusedFunction",
"range": {
"startColumn": 4,
"endColumn": 26,
"lineCount": 1
}
},
{
"code": "reportMissingParameterType",
"range": {
Expand Down
28 changes: 19 additions & 9 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@

import numpy as np

from pymbolic.typing import Integer

from arraycontext.container import (
ArrayContainer,
NotAnArrayContainerError,
Expand All @@ -91,7 +93,6 @@
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerT,
ArrayT,
ScalarLike,
)

Expand Down Expand Up @@ -400,21 +401,20 @@ def keyed_map_array_container(


def rec_keyed_map_array_container(
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
f: Callable[[tuple[SerializationKey, ...], Array], Array],
ary: ArrayOrContainer) -> ArrayOrContainer:
"""
Works similarly to :func:`rec_map_array_container`, except that *f* also
takes in a traversal path to the leaf array. The traversal path argument is
passed in as a tuple of identifiers of the arrays traversed before reaching
the current array.
"""

def rec(keys: tuple[SerializationKey, ...],
ary_: ArrayOrContainerT) -> ArrayOrContainerT:
ary_: ArrayOrContainer) -> ArrayOrContainer:
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
return cast(ArrayOrContainer, f(keys, cast(Array, ary_)))
else:
return deserialize_container(ary_, [
(key, rec((*keys, key), subary)) for key, subary in iterable
Expand Down Expand Up @@ -777,7 +777,7 @@ def unflatten(
checks are skipped.
"""
# NOTE: https://github.com/python/mypy/issues/7057
offset = 0
offset: int = 0
common_dtype = None

def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
Expand All @@ -790,7 +790,11 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:

# {{{ validate subary

if (offset + template_subary_c.size) > ary.size:
if (
isinstance(offset, Integer)
and isinstance(template_subary_c.size, Integer)
and isinstance(ary.size, Integer)
and (offset + template_subary_c.size) > ary.size):
raise ValueError("'template' and 'ary' sizes do not match: "
"'template' is too large") from None

Expand All @@ -813,6 +817,12 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:

# {{{ reshape

if not isinstance(template_subary_c.size, Integer):
raise NotImplementedError(
"unflatten is not implemented for arrays with array-valued "
"size.") from None

# FIXME: Not sure how to make the slicing part work for Array-valued sizes
flat_subary = ary[offset:offset + template_subary_c.size]
try:
subary = actx.np.reshape(flat_subary,
Expand Down Expand Up @@ -871,15 +881,15 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:


def flat_size_and_dtype(
ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
ary: ArrayOrContainer) -> tuple[Array | Integer, np.dtype[Any] | None]:
"""
:returns: a tuple ``(size, dtype)`` that would be the length and
:class:`numpy.dtype` of the one-dimensional array returned by
:func:`flatten`.
"""
common_dtype = None

def _flat_size(subary: ArrayOrContainer) -> int:
def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
nonlocal common_dtype

try:
Expand Down
29 changes: 15 additions & 14 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
import numpy as np
from typing_extensions import Self

from pymbolic.typing import Integer, Scalar as _Scalar
from pytools import memoize_method
from pytools.tag import ToTagSetConvertible

Expand Down Expand Up @@ -202,11 +203,11 @@ class Array(Protocol):
"""

@property
def shape(self) -> tuple[int, ...]:
def shape(self) -> tuple[Array | Integer, ...]:
...

@property
def size(self) -> int:
def size(self) -> Array | Integer:
...

@property
Expand All @@ -221,21 +222,21 @@ def __getitem__(self, index: Any) -> Array:
...

# some basic arithmetic that's supposed to work
def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...
def __add__(self, other: Self | ScalarLike) -> Self: ...
def __radd__(self, other: Self | ScalarLike) -> Self: ...
def __sub__(self, other: Self | ScalarLike) -> Self: ...
def __rsub__(self, other: Self | ScalarLike) -> Self: ...
def __mul__(self, other: Self | ScalarLike) -> Self: ...
def __rmul__(self, other: Self | ScalarLike) -> Self: ...
def __truediv__(self, other: Self | ScalarLike) -> Self: ...
def __rtruediv__(self, other: Self | ScalarLike) -> Self: ...
def __neg__(self) -> Array: ...
def __abs__(self) -> Array: ...
def __add__(self, other: Self | ScalarLike) -> Array: ...
def __radd__(self, other: Self | ScalarLike) -> Array: ...
def __sub__(self, other: Self | ScalarLike) -> Array: ...
def __rsub__(self, other: Self | ScalarLike) -> Array: ...
def __mul__(self, other: Self | ScalarLike) -> Array: ...
def __rmul__(self, other: Self | ScalarLike) -> Array: ...
def __truediv__(self, other: Self | ScalarLike) -> Array: ...
def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...


# deprecated, use ScalarLike instead
ScalarLike: TypeAlias = int | float | complex | np.generic
Scalar = ScalarLike
Scalar = _Scalar
ScalarLike = Scalar
ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)

# NOTE: I'm kind of not sure about the *Tc versions of these type variables.
Expand Down
3 changes: 3 additions & 0 deletions arraycontext/impl/pyopencl/taggable_cl_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class TaggableCLArray(cla.Array, Taggable):
record application-specific metadata to drive the optimizations in
:meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`.
"""
tags: frozenset[Tag]
axes: tuple[Axis, ...]

def __init__(self, cq, shape, dtype, order="C", allocator=None,
data=None, offset=0, strides=None, events=None, _flags=None,
_fast=False, _size=None, _context=None, _queue=None,
Expand Down
20 changes: 12 additions & 8 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):

.. automethod:: compile
"""
queue: cl.CommandQueue # pyright: ignore[reportUnknownMemberType]

def __init__(
self, queue: cl.CommandQueue, allocator=None, *,
use_memory_pool: bool | None = None,
Expand Down Expand Up @@ -425,7 +427,7 @@ def from_numpy(self, array):

import arraycontext.impl.pyopencl.taggable_cl_array as tga

def _from_numpy(ary):
def _from_numpy(ary: np.ndarray[Any, Any]) -> pt.Array:
return pt.make_data_wrapper(
tga.to_device(self.queue, ary, allocator=self.allocator)
)
Expand Down Expand Up @@ -642,10 +644,11 @@ def thaw(self, array):
import arraycontext.impl.pyopencl.taggable_cl_array as tga
from .utils import get_pt_axes_from_cl_axes

def _thaw(ary):
return pt.make_data_wrapper(ary.with_queue(self.queue),
axes=get_pt_axes_from_cl_axes(ary.axes),
tags=ary.tags)
def _thaw(ary: tga.TaggableCLArray) -> pt.Array:
return pt.make_data_wrapper(
ary.with_queue(self.queue), # pyright: ignore[reportArgumentType, reportUnknownMemberType, reportUnknownArgumentType]
axes=get_pt_axes_from_cl_axes(ary.axes),
tags=ary.tags)

return with_array_context(
self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)),
Expand Down Expand Up @@ -836,7 +839,7 @@ def from_numpy(self, array):
import jax
import pytato as pt

def _from_numpy(ary):
def _from_numpy(ary: np.ndarray[Any, Any]) -> pt.Array:
return pt.make_data_wrapper(jax.device_put(ary))

return with_array_context(
Expand Down Expand Up @@ -892,7 +895,7 @@ def _record_leaf_ary_in_dict(key: tuple[Any, ...],

# }}}

def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray:
def _to_frozen(key: tuple[Any, ...], ary: pt.Array) -> jnp.ndarray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]

Expand All @@ -919,9 +922,10 @@ def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray:
actx=None)

def thaw(self, array):
import jax.numpy as jnp
import pytato as pt

def _thaw(ary):
def _thaw(ary: jnp.ndarray) -> pt.Array:
return pt.make_data_wrapper(ary)

return with_array_context(
Expand Down
Loading
Loading