Skip to content

add typing and structure to base Graph class #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ repository = "https://github.com/funkelab/spatial_graph"
[project.optional-dependencies]
dev = [
"pytest>=8.3.4",
"pre-commit",
"ruff",
"mypy",
]

[tool.hatch.metadata]
Expand Down
134 changes: 74 additions & 60 deletions spatial_graph/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,84 @@
import numpy as np
from __future__ import annotations

import re

# Define valid base types
VALID_BASE_TYPES = {
# Floating-point types (no _t suffix)
"float": "float",
"float32": "float",
"double": "double",
"float64": "double",
# Fixed-width integer types
"int8": "int8_t",
"int16": "int16_t",
"int32": "int32_t",
"int64": "int64_t",
"uint8": "uint8_t",
"uint16": "uint16_t",
"uint32": "uint32_t",
"uint64": "uint64_t",
# Platform-dependent types mapped to fixed-width equivalents
"int": "int32_t", # Map generic int to standard 32-bit
"uint": "uint32_t", # Map generic uint to standard 32-bit
}

# Regex pattern for dtype validation and extraction
DTYPE_PATTERN = r"^({})(?:\[(\d+)\])?$".format("|".join(VALID_BASE_TYPES))
DTYPE_REGEX = re.compile(DTYPE_PATTERN)


class DType:
def __init__(self, dtype_str):
"""A class to represent a data type in C/C++ and Cython/PYX files.

Parameters
----------
dtype_str : str
The data type string in the format "base_type[size]". The base_type must
be one of the valid base types defined in VALID_BASE_TYPES, and size is
optional.
"""

def __init__(self, dtype_str: str) -> None:
self.as_string = dtype_str
self.is_array = self.__is_array(dtype_str)
self.base, self.size = self.__parse_array_dtype(dtype_str)
self.is_array = self.size is not None
self.shape = (self.size,) if self.is_array else ()

if self.is_array:
self.base, self.size = self.__parse_array_dtype(dtype_str)
self.shape = (self.size,)
else:
self.base = dtype_str
self.size = None
self.shape = ()
def __parse_array_dtype(self, dtype_str: str) -> tuple[str, int | None]:
"""Parse the array dtype string into base type and size."""

if not (match := DTYPE_REGEX.match(dtype_str)):
raise ValueError(
f"Invalid dtype string: {dtype_str!r}. Must have base type of "
f"{list(VALID_BASE_TYPES)!r} and optional size in square brackets."
)

def __is_array(self, dtype):
if "[" in dtype:
if "]" not in dtype:
raise RuntimeError(f"invalid array(?) dtype {dtype}")
return True
return False
base = match.group(1)
size = int(match.group(2)) if match.group(2) else None

def __parse_array_dtype(self, dtype):
dtype, size = dtype.split("[")
size = int(size.split("]")[0])
if base not in VALID_BASE_TYPES: # pragma: no cover
raise ValueError(f"Invalid base type: {base}")

return dtype, size
return base, size

@property
def base_c_type(self):
def base_c_type(self) -> str:
"""Convert the base of this DType into the equivalent C/C++ type."""
return VALID_BASE_TYPES[self.base]

if self.base == "float32" or self.base == "float":
return "float"
elif self.base == "float64" or self.base == "double":
return "double"
else:
# this might not work for all of them, this is just a fallback
return np.dtype(self.base).name + "_t"

def to_c_decl(self, name):
"""Convert this dtype to the equivalent C/C++ declaration with the
given name:
def to_c_decl(self, name: str) -> str:
"""Convert this dtype to the equivalent C/C++ declaration with the given name.

"base_c_type name" if not an array
"base_c_type name[size]" if an array type
"base_c_type name" if not an array
"base_c_type name[size]" if an array type
"""
# is this an array type?
if self.is_array:
suffix = f"[{self.size}]"
return f"{self.base_c_type} {name}[{self.size}]"
else:
suffix = ""

return self.base_c_type + " " + name + suffix
return f"{self.base_c_type} {name}"

def to_pyxtype(self, use_memory_view=False, add_dim=False):
def to_pyxtype(self, use_memory_view: bool = False, add_dim: bool = False) -> str:
"""Convert this dtype to the equivalent PYX type.

"base_c_type"
Expand All @@ -76,8 +100,6 @@ def to_pyxtype(self, use_memory_view=False, add_dim=False):
"int32_t" for dtype "int32". If this DType is already an array,
will create a 2D array, e.g., "int32_t[:, ::1]".
"""

# is this an array type?
if self.is_array:
if add_dim:
suffix = "[:, ::1]"
Expand All @@ -91,24 +113,23 @@ def to_pyxtype(self, use_memory_view=False, add_dim=False):

return self.base_c_type + suffix

def to_rvalue(self, name, array_index=None):
"""Convert this dtype into an r-value to be used in PYX files for
assignments.
def to_rvalue(self, name: str, array_index: str | None = None) -> str:
"""Convert this dtype into an r-value to be used in PYX files for assignments.

"name" default
"name[array_index]" if array_index is given
"{name[0], ..., name[size-1]}"
if an array type
"{name[array_index, 0], ..., name[array_index, size-1]}"
if an array type and array_index is given
"name" default
"name[array_index]" if array_index is given
"{name[0], ..., name[size-1]}"
if an array type
"{name[array_index, 0], ..., name[array_index, size-1]}"
if an array type and array_index is given
"""

if self.is_array:
if self.size:
if array_index:
return (
"{"
+ ", ".join(
[name + f"[{array_index}, {i}]" for i in range(self.size)]
[f"{name}[{array_index}, {i}]" for i in range(self.size)]
)
+ "}"
)
Expand All @@ -121,10 +142,3 @@ def to_rvalue(self, name, array_index=None):
return f"{name}[{array_index}]"
else:
return name


def dtypes_to_struct(struct_name, dtypes):
pyx_code = f"cdef struct {struct_name}:\n"
for name, dtype in dtypes.items():
pyx_code += f" {dtype.to_pyxtype()} {name}\n"
return pyx_code
Loading
Loading