Skip to content

Commit b6aa75a

Browse files
committed
Make _move_sprite() generic
1 parent e83e77b commit b6aa75a

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

arcade/physics_engines.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
# pylint: disable=too-many-arguments, too-many-locals, too-few-public-methods
55

66
import math
7-
from typing import Iterable, List, Optional, Union
7+
from typing import Iterable, List, Optional, Union, cast
88

99
from arcade import (
1010
BasicSprite,
1111
Sprite,
1212
SpriteList,
13+
SpriteType,
1314
check_for_collision,
1415
check_for_collision_with_lists
1516
)
@@ -54,7 +55,7 @@ def _circular_check(player: Sprite, walls: List[SpriteList]):
5455
vary *= 2
5556

5657

57-
def _move_sprite(moving_sprite: Sprite, walls: List[SpriteList], ramp_up: bool) -> List[Union[Sprite, BasicSprite]]:
58+
def _move_sprite(moving_sprite: Sprite, walls: List[SpriteList[SpriteType]], ramp_up: bool) -> List[SpriteType]:
5859

5960
# See if we are starting this turn with a sprite already colliding with us.
6061
if len(check_for_collision_with_lists(moving_sprite, walls)) > 0:
@@ -231,12 +232,12 @@ class PhysicsEngineSimple:
231232
This can be one or multiple spritelists.
232233
"""
233234

234-
def __init__(self, player_sprite: Sprite, walls: Union[SpriteList, Iterable[SpriteList]]):
235+
def __init__(self, player_sprite: Sprite, walls: Union[SpriteList[BasicSprite], Iterable[SpriteList[BasicSprite]]]):
235236
assert isinstance(player_sprite, Sprite)
236237

237238
if walls:
238239
if isinstance(walls, SpriteList):
239-
self.walls = [walls]
240+
self.walls = [cast(SpriteList[BasicSprite], walls)]
240241
else:
241242
self.walls = list(walls)
242243
else:

arcade/sprite_list/collision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_closest_sprite(
5959
return sprite_list[min_pos], min_distance
6060

6161

62-
def check_for_collision(sprite1: SpriteType, sprite2: SpriteType) -> bool:
62+
def check_for_collision(sprite1: BasicSprite, sprite2: BasicSprite) -> bool:
6363
"""
6464
Check for a collision between two sprites.
6565

0 commit comments

Comments
 (0)