Skip to content

Commit e959e8c

Browse files
committed
improve triangulation and LearnerND typing
1 parent f21f19d commit e959e8c

File tree

2 files changed

+51
-44
lines changed

2 files changed

+51
-44
lines changed

adaptive/learner/learnerND.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
4040
return vol
4141

4242

43-
def orientation(simplex):
43+
def orientation(simplex: np.ndarray):
4444
matrix = np.subtract(simplex[:-1], simplex[-1])
4545
# See https://www.jstor.org/stable/2315353
4646
sign, _logdet = np.linalg.slogdet(matrix)

adaptive/learner/triangulation.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
import collections.abc
12
import math
23
from collections import Counter
3-
from collections.abc import Iterable, Sized
44
from itertools import chain, combinations
55
from math import factorial
6-
from typing import Any, Iterator, List, Optional, Sequence, Set, Tuple, Union
6+
from typing import Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
77

88
import numpy as np
99
import scipy.spatial
1010

11-
Simplex = Tuple[int, ...] # XXX: check if this is correct
11+
SimplexPoints = Union[
12+
List[Tuple[float, ...]], np.ndarray
13+
] # XXX: check if this is correct
14+
Simplex = Tuple[int, ...]
15+
Point = Union[Tuple[float, ...], np.ndarray] # XXX: check if this is correct
1216

1317

1418
def fast_norm(v: Union[Tuple[float, ...], np.ndarray]) -> float:
@@ -21,9 +25,7 @@ def fast_norm(v: Union[Tuple[float, ...], np.ndarray]) -> float:
2125

2226

2327
def fast_2d_point_in_simplex(
24-
point: Tuple[float, ...],
25-
simplex: Union[List[Tuple[float, ...]], np.ndarray],
26-
eps: float = 1e-8,
28+
point: Point, simplex: SimplexPoints, eps: float = 1e-8
2729
) -> Union[bool, np.bool_]:
2830
(p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex
2931
px, py = point
@@ -38,9 +40,7 @@ def fast_2d_point_in_simplex(
3840
return (t >= -eps) and (s + t <= 1 + eps)
3941

4042

41-
def point_in_simplex(
42-
point: Any, simplex: Simplex, eps: float = 1e-8
43-
) -> Union[bool, np.bool_]:
43+
def point_in_simplex(point: Point, simplex: SimplexPoints, eps: float = 1e-8) -> bool:
4444
if len(point) == 2:
4545
return fast_2d_point_in_simplex(point, simplex, eps)
4646

@@ -51,7 +51,7 @@ def point_in_simplex(
5151
return all(alpha > -eps) and sum(alpha) < 1 + eps
5252

5353

54-
def fast_2d_circumcircle(points: np.ndarray,) -> Tuple[Tuple[float, float], float]:
54+
def fast_2d_circumcircle(points: Iterable[Point]) -> Tuple[Tuple[float, float], float]:
5555
"""Compute the center and radius of the circumscribed circle of a triangle
5656
5757
Parameters
@@ -88,7 +88,7 @@ def fast_2d_circumcircle(points: np.ndarray,) -> Tuple[Tuple[float, float], floa
8888

8989

9090
def fast_3d_circumcircle(
91-
points: np.ndarray,
91+
points: Iterable[Point],
9292
) -> Tuple[Tuple[float, float, float], float]:
9393
"""Compute the center and radius of the circumscribed shpere of a simplex.
9494
@@ -140,7 +140,7 @@ def fast_det(matrix: np.ndarray) -> float:
140140
return np.linalg.det(matrix)
141141

142142

143-
def circumsphere(pts: np.ndarray,) -> Tuple[Tuple[float, ...], float]:
143+
def circumsphere(pts: np.ndarray) -> Tuple[Tuple[float, ...], float]:
144144
dim = len(pts) - 1
145145
if dim == 2:
146146
return fast_2d_circumcircle(pts)
@@ -193,10 +193,12 @@ def orientation(face: np.ndarray, origin: np.ndarray) -> int:
193193

194194

195195
def is_iterable_and_sized(obj: Any) -> bool:
196-
return isinstance(obj, Iterable) and isinstance(obj, Sized)
196+
return isinstance(obj, collections.abc.Iterable) and isinstance(
197+
obj, collections.abc.Sized
198+
)
197199

198200

199-
def simplex_volume_in_embedding(vertices: List[Tuple[float, ...]]) -> float:
201+
def simplex_volume_in_embedding(vertices: Iterable[Point]) -> float:
200202
"""Calculate the volume of a simplex in a higher dimensional embedding.
201203
That is: dim > len(vertices) - 1. For example if you would like to know the
202204
surface area of a triangle in a 3d space.
@@ -277,7 +279,7 @@ class Triangulation:
277279
or more simplices in the
278280
"""
279281

280-
def __init__(self, coords: np.ndarray) -> None:
282+
def __init__(self, coords: Iterable[Point]) -> None:
281283
if not is_iterable_and_sized(coords):
282284
raise TypeError("Please provide a 2-dimensional list of points")
283285
coords = list(coords)
@@ -305,10 +307,10 @@ def __init__(self, coords: np.ndarray) -> None:
305307
"(the points are linearly dependent)"
306308
)
307309

308-
self.vertices = list(coords)
309-
self.simplices = set()
310+
self.vertices: List[Point] = list(coords)
311+
self.simplices: Set[Simplex] = set()
310312
# initialise empty set for each vertex
311-
self.vertex_to_simplices = [set() for _ in coords]
313+
self.vertex_to_simplices: List[Set[Simplex]] = [set() for _ in coords]
312314

313315
# find a Delaunay triangulation to start with, then we will throw it
314316
# away and continue with our own algorithm
@@ -328,16 +330,16 @@ def add_simplex(self, simplex: Simplex) -> None:
328330
for vertex in simplex:
329331
self.vertex_to_simplices[vertex].add(simplex)
330332

331-
def get_vertices(self, indices: Sequence[int]) -> Any:
333+
def get_vertices(self, indices: Sequence[int]) -> List[Optional[Point]]:
332334
return [self.get_vertex(i) for i in indices]
333335

334-
def get_vertex(self, index: Optional[int]) -> Any:
336+
def get_vertex(self, index: Optional[int]) -> Optional[Point]:
335337
if index is None:
336338
return None
337339
return self.vertices[index]
338340

339341
def get_reduced_simplex(
340-
self, point: Any, simplex: Simplex, eps: float = 1e-8
342+
self, point: Point, simplex: Simplex, eps: float = 1e-8
341343
) -> list:
342344
"""Check whether vertex lies within a simplex.
343345
@@ -364,12 +366,12 @@ def get_reduced_simplex(
364366
return [simplex[i] for i in result]
365367

366368
def point_in_simplex(
367-
self, point: Any, simplex: Simplex, eps: float = 1e-8
368-
) -> Union[bool, np.bool_]:
369+
self, point: Point, simplex: Simplex, eps: float = 1e-8
370+
) -> bool:
369371
vertices = self.get_vertices(simplex)
370372
return point_in_simplex(point, vertices, eps)
371373

372-
def locate_point(self, point: Any) -> Any:
374+
def locate_point(self, point: Point) -> Simplex:
373375
"""Find to which simplex the point belongs.
374376
375377
Return indices of the simplex containing the point.
@@ -385,8 +387,11 @@ def dim(self) -> int:
385387
return len(self.vertices[0])
386388

387389
def faces(
388-
self, dim: None = None, simplices: Optional[Any] = None, vertices: None = None
389-
) -> Iterator[Any]:
390+
self,
391+
dim: Optional[int] = None,
392+
simplices: Optional[Iterable[Simplex]] = None,
393+
vertices: Optional[Iterable[int]] = None,
394+
) -> Iterator[Tuple[int, ...]]:
390395
"""Iterator over faces of a simplex or vertex sequence."""
391396
if dim is None:
392397
dim = self.dim
@@ -407,11 +412,11 @@ def faces(
407412
else:
408413
return faces
409414

410-
def containing(self, face):
415+
def containing(self, face: Tuple[int, ...]) -> Set[Simplex]:
411416
"""Simplices containing a face."""
412417
return set.intersection(*(self.vertex_to_simplices[i] for i in face))
413418

414-
def _extend_hull(self, new_vertex: Any, eps: float = 1e-8) -> Any:
419+
def _extend_hull(self, new_vertex: Point, eps: float = 1e-8) -> Set[Simplex]:
415420
# count multiplicities in order to get all hull faces
416421
multiplicities = Counter(face for face in self.faces())
417422
hull_faces = [face for face, count in multiplicities.items() if count == 1]
@@ -471,7 +476,7 @@ def circumscribed_circle(
471476

472477
def point_in_cicumcircle(
473478
self, pt_index: int, simplex: Simplex, transform: np.ndarray
474-
) -> np.bool_:
479+
) -> bool:
475480
# return self.fast_point_in_circumcircle(pt_index, simplex, transform)
476481
eps = 1e-8
477482

@@ -487,9 +492,9 @@ def default_transform(self) -> np.ndarray:
487492
def bowyer_watson(
488493
self,
489494
pt_index: int,
490-
containing_simplex: Optional[Any] = None,
495+
containing_simplex: Optional[Simplex] = None,
491496
transform: Optional[np.ndarray] = None,
492-
) -> Any:
497+
) -> Tuple[Set[Simplex], Set[Simplex]]:
493498
"""Modified Bowyer-Watson point adding algorithm.
494499
495500
Create a hole in the triangulation around the new point,
@@ -549,7 +554,7 @@ def bowyer_watson(
549554
new_triangles = self.vertex_to_simplices[pt_index]
550555
return bad_triangles - new_triangles, new_triangles - bad_triangles
551556

552-
def _simplex_is_almost_flat(self, simplex: Simplex) -> np.bool_:
557+
def _simplex_is_almost_flat(self, simplex: Simplex) -> bool:
553558
return self._relative_volume(simplex) < 1e-8
554559

555560
def _relative_volume(self, simplex: Simplex) -> float:
@@ -565,8 +570,8 @@ def _relative_volume(self, simplex: Simplex) -> float:
565570

566571
def add_point(
567572
self,
568-
point: Any,
569-
simplex: Optional[Any] = None,
573+
point: Point,
574+
simplex: Optional[Simplex] = None,
570575
transform: Optional[np.ndarray] = None,
571576
) -> Any:
572577
"""Add a new vertex and create simplices as appropriate.
@@ -575,13 +580,13 @@ def add_point(
575580
----------
576581
point : float vector
577582
Coordinates of the point to be added.
578-
transform : N*N matrix of floats
579-
Multiplication matrix to apply to the point (and neighbouring
580-
simplices) when running the Bowyer Watson method.
581583
simplex : tuple of ints, optional
582584
Simplex containing the point. Empty tuple indicates points outside
583585
the hull. If not provided, the algorithm costs O(N), so this should
584586
be used whenever possible.
587+
transform : N*N matrix of floats
588+
Multiplication matrix to apply to the point (and neighbouring
589+
simplices) when running the Bowyer Watson method.
585590
"""
586591
point = tuple(point)
587592
if simplex is None:
@@ -626,7 +631,7 @@ def volume(self, simplex: Simplex) -> float:
626631
def volumes(self) -> List[float]:
627632
return [self.volume(sim) for sim in self.simplices]
628633

629-
def reference_invariant(self):
634+
def reference_invariant(self) -> bool:
630635
"""vertex_to_simplices and simplices are compatible."""
631636
for vertex in range(len(self.vertices)):
632637
if any(vertex not in tri for tri in self.vertex_to_simplices[vertex]):
@@ -640,26 +645,28 @@ def vertex_invariant(self, vertex):
640645
"""Simplices originating from a vertex don't overlap."""
641646
raise NotImplementedError
642647

643-
def get_neighbors_from_vertices(self, simplex: Simplex) -> Any:
648+
def get_neighbors_from_vertices(self, simplex: Simplex) -> Set[Simplex]:
644649
return set.union(*[self.vertex_to_simplices[p] for p in simplex])
645650

646-
def get_face_sharing_neighbors(self, neighbors: Any, simplex: Simplex) -> Any:
651+
def get_face_sharing_neighbors(
652+
self, neighbors: Set[Simplex], simplex: Simplex
653+
) -> Set[Simplex]:
647654
"""Keep only the simplices sharing a whole face with simplex."""
648655
return {
649656
simpl for simpl in neighbors if len(set(simpl) & set(simplex)) == self.dim
650657
} # they share a face
651658

652-
def get_simplices_attached_to_points(self, indices: Any) -> Any:
659+
def get_simplices_attached_to_points(self, indices: Simplex) -> Set[Simplex]:
653660
# Get all simplices that share at least a point with the simplex
654661
neighbors = self.get_neighbors_from_vertices(indices)
655662
return self.get_face_sharing_neighbors(neighbors, indices)
656663

657-
def get_opposing_vertices(self, simplex: Simplex,) -> Any:
664+
def get_opposing_vertices(self, simplex: Simplex) -> Tuple[int, ...]:
658665
if simplex not in self.simplices:
659666
raise ValueError("Provided simplex is not part of the triangulation")
660667
neighbors = self.get_simplices_attached_to_points(simplex)
661668

662-
def find_opposing_vertex(vertex):
669+
def find_opposing_vertex(vertex: int):
663670
# find the simplex:
664671
simp = next((x for x in neighbors if vertex not in x), None)
665672
if simp is None:

0 commit comments

Comments
 (0)