Skip to content

Commit 7d9a26d

Browse files
authored
Switch arcade.math.lerp to use a Protocol bound to a TypeVar (#2310)
* Add HasAddSubMul protocol to cover lerping * Add support for arbitary lerpable arguments to arcade.math.lerp * Explain why the protocol and TypeVars are so specific * Convert lerp_2d and lerp_3d to return Vec types * Use appropriate Point2 and Point3 annotations instead of Point or local annotations * Fix pyright by cleaning A* pathing annotations and imports
1 parent 5832ddc commit 7d9a26d

File tree

3 files changed

+78
-40
lines changed

3 files changed

+78
-40
lines changed

arcade/math.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import math
44
import random
5-
from typing import Sequence, Union
5+
from typing import TypeVar
66

7-
from pyglet.math import Vec2
7+
from pyglet.math import Vec2, Vec3
88

9-
from arcade.types import AsFloat, Point, Point2
9+
from arcade.types import HasAddSubMul, Point, Point2
1010
from arcade.types.rect import Rect
1111
from arcade.types.vector_like import Point3
1212

@@ -46,43 +46,64 @@ def clamp(a, low: float, high: float) -> float:
4646
return high if a > high else max(a, low)
4747

4848

49-
V_2D = Union[tuple[AsFloat, AsFloat], Sequence[AsFloat]]
50-
V_3D = Union[tuple[AsFloat, AsFloat, AsFloat], Sequence[AsFloat]]
49+
# This TypeVar helps match v1 and v2 as the same type below in lerp's
50+
# signature. If we used HasAddSubMul, they could be different.
51+
L = TypeVar("L", bound=HasAddSubMul)
5152

5253

53-
def lerp(v1: AsFloat, v2: AsFloat, u: float) -> float:
54-
"""linearly interpolate between two values
54+
def lerp(v1: L, v2: L, u: float) -> L:
55+
"""Linearly interpolate two values which support arithmetic operators.
56+
57+
Both ``v1`` and ``v2`` must be of compatible types and support
58+
the following operators:
59+
60+
* ``+`` (:py:meth:`~object.__add__`)
61+
* ``-`` (:py:meth:`~object.__sub__`)
62+
* ``*`` (:py:meth:`~object.__mul__`)
63+
64+
This means that in certain cases, you may want to use another
65+
function:
66+
67+
* For angles, use :py:func:`lerp_angle`.
68+
* To convert points as arbitary sequences, use:
69+
70+
* :py:func:`lerp_2d`
71+
* :py:func:`lerp_3d`
5572
5673
Args:
57-
v1 (float): The first value
58-
v2 (float): The second value
59-
u (float): The interpolation value `(0.0 to 1.0)`
74+
v1 (HasAddSubMul): The first value
75+
v2 (HasAddSubMul): The second value
76+
u: The interpolation value `(0.0 to 1.0)`
6077
"""
6178
return v1 + ((v2 - v1) * u)
6279

6380

64-
def lerp_2d(v1: V_2D, v2: V_2D, u: float) -> tuple[float, float]:
65-
"""
66-
Linearly interpolate between two 2D points.
81+
def lerp_2d(v1: Point2, v2: Point2, u: float) -> Vec2:
82+
"""Linearly interpolate between two 2D points passed as sequences.
83+
84+
.. tip:: This function returns a :py:class:`Vec2` you can use
85+
with :py:func`lerp` .
6786
6887
Args:
69-
v1 (tuple[float, float]): The first point
70-
v2 (tuple[float, float]): The second point
88+
v1: The first point as a sequence of 2 values.
89+
v2: The second point as a sequence of 2 values.
7190
u (float): The interpolation value `(0.0 to 1.0)`
7291
"""
73-
return (lerp(v1[0], v2[0], u), lerp(v1[1], v2[1], u))
92+
return Vec2(lerp(v1[0], v2[0], u), lerp(v1[1], v2[1], u))
7493

7594

76-
def lerp_3d(v1: V_3D, v2: V_3D, u: float) -> tuple[float, float, float]:
77-
"""
78-
Linearly interpolate between two 3D points.
95+
def lerp_3d(v1: Point3, v2: Point3, u: float) -> Vec3:
96+
"""Linearly interpolate between two 3D points passed as sequences.
97+
98+
.. tip:: This function returns a :py:class:`Vec2` you can use
99+
with :py:func`lerp`.
79100
80101
Args:
81-
v1 (tuple[float, float, float]): The first point
82-
v2 (tuple[float, float, float]): The second point
102+
v1: The first point as a sequence of 3 values.
103+
v2: The second point as a sequence of 3 values.
83104
u (float): The interpolation value `(0.0 to 1.0)`
84105
"""
85-
return (lerp(v1[0], v2[0], u), lerp(v1[1], v2[1], u), lerp(v1[2], v2[2], u))
106+
return Vec3(lerp(v1[0], v2[0], u), lerp(v1[1], v2[1], u), lerp(v1[2], v2[2], u))
86107

87108

88109
def lerp_angle(start_angle: float, end_angle: float, u: float) -> float:

arcade/paths.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
from __future__ import annotations
66

77
import math
8-
from typing import cast
98

109
from arcade import Sprite, SpriteList, check_for_collision_with_list, get_sprites_at_point
1110
from arcade.math import get_distance, lerp_2d
12-
from arcade.types import Point, Point2
11+
from arcade.types import Point2
1312

1413
__all__ = ["AStarBarrierList", "astar_calculate_path", "has_line_of_sight"]
1514

@@ -33,13 +32,13 @@ def _spot_is_blocked(position: Point2, moving_sprite: Sprite, blocking_sprites:
3332
return len(hit_list) > 0
3433

3534

36-
def _heuristic(start: Point, goal: Point) -> float:
35+
def _heuristic(start: Point2, goal: Point2) -> float:
3736
"""
3837
Returns a heuristic value for the passed points.
3938
4039
Args:
41-
start (Point): The 1st point to compare
42-
goal (Point): The 2nd point to compare
40+
start (Point2): The 1st point to compare
41+
goal (Point2): The 2nd point to compare
4342
4443
Returns:
4544
float: The heuristic of the 2 points
@@ -102,7 +101,7 @@ def __init__(
102101
else:
103102
self.movement_directions = (1, 0), (-1, 0), (0, 1), (0, -1) # type: ignore
104103

105-
def get_vertex_neighbours(self, pos: Point) -> list[tuple[float, float]]:
104+
def get_vertex_neighbours(self, pos: Point2) -> list[tuple[float, float]]:
106105
"""
107106
Return neighbors for this point according to ``self.movement_directions``
108107
@@ -123,7 +122,7 @@ def get_vertex_neighbours(self, pos: Point) -> list[tuple[float, float]]:
123122
n.append((x2, y2))
124123
return n
125124

126-
def move_cost(self, a: Point, b: Point) -> float:
125+
def move_cost(self, a: Point2, b: Point2) -> float:
127126
"""
128127
Returns a float of the cost to move
129128
@@ -224,12 +223,12 @@ def _AStarSearch(start: Point2, end: Point2, graph: _AStarGraph) -> list[Point2]
224223
return None
225224

226225

227-
def _collapse(pos: Point, grid_size: float):
226+
def _collapse(pos: Point2, grid_size: float) -> tuple[int, int]:
228227
"""Makes Point pos smaller by grid_size"""
229228
return int(pos[0] // grid_size), int(pos[1] // grid_size)
230229

231230

232-
def _expand(pos: Point, grid_size: float):
231+
def _expand(pos: Point2, grid_size: float) -> tuple[int, int]:
233232
"""Makes Point pos larger by grid_size"""
234233
return int(pos[0] * grid_size), int(pos[1] * grid_size)
235234

@@ -329,11 +328,11 @@ def recalculate(self):
329328

330329

331330
def astar_calculate_path(
332-
start_point: Point,
333-
end_point: Point,
331+
start_point: Point2,
332+
end_point: Point2,
334333
astar_barrier_list: AStarBarrierList,
335334
diagonal_movement: bool = True,
336-
) -> list[Point] | None:
335+
) -> list[Point2] | None:
337336
"""
338337
Calculates the path using AStarSearch Algorithm and returns the path
339338
@@ -371,13 +370,13 @@ def astar_calculate_path(
371370

372371
# Currently 'result' is in grid locations. We need to convert them to pixel
373372
# locations.
374-
revised_result = [_expand(p, grid_size) for p in result]
375-
return cast(list[Point], revised_result)
373+
revised_result: list[Point2] = [_expand(p, grid_size) for p in result]
374+
return revised_result
376375

377376

378377
def has_line_of_sight(
379-
observer: Point,
380-
target: Point,
378+
observer: Point2,
379+
target: Point2,
381380
walls: SpriteList,
382381
max_distance: float = float("inf"),
383382
check_resolution: int = 2,
@@ -429,7 +428,7 @@ def has_line_of_sight(
429428

430429

431430
# NOTE: Rewrite this
432-
# def dda_step(start: Point, end: Point):
431+
# def dda_step(start: Point2, end: Point2):
433432
# """
434433
# Bresenham's line algorithm
435434

arcade/types/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# flake8: noqa: E402
2727
import sys
2828
from pathlib import Path
29-
from typing import NamedTuple, Union, TYPE_CHECKING, TypeVar, Iterable
29+
from typing import NamedTuple, Union, TYPE_CHECKING, TypeVar, Iterable, Protocol
3030

3131
from pytiled_parser import Properties
3232

@@ -124,6 +124,7 @@
124124
"Box",
125125
"LRBTNF",
126126
"XYZWHD",
127+
"HasAddSubMul",
127128
"RGB",
128129
"RGBA",
129130
"RGBOrA",
@@ -206,6 +207,23 @@ def annotated2(argument: OneOrIterableOf[MyType] | None = tuple()):
206207
# --- End potentially obsolete annotations ---
207208

208209

210+
# These are for the argument type + return type. They're separate TypeVars
211+
# to handle cases which take tuple but return Vec2 (e.g. pyglet.math.Vec2).
212+
_T_contra = TypeVar("_T_contra", contravariant=True) # Same or more general than T
213+
_T_co = TypeVar("_T_co", covariant=True) # Same or more specific than T
214+
215+
216+
class HasAddSubMul(Protocol[_T_contra, _T_co]):
217+
"""Matches types which work with :py:func:`arcade.math.lerp`."""
218+
219+
# The / matches float and similar operations to keep pyright
220+
# happy since built-in arithmetic makes them positional only.
221+
# See https://peps.python.org/pep-0570/
222+
def __add__(self, value: _T_contra, /) -> _T_co: ...
223+
def __sub__(self, value: _T_contra, /) -> _T_co: ...
224+
def __mul__(self, value: _T_contra, /) -> _T_co: ...
225+
226+
209227
# Path handling
210228
PathLike = Union[str, Path, bytes]
211229
_POr = TypeVar("_POr") # Allows PathOr[TypeNameHere] syntax

0 commit comments

Comments
 (0)