
Commit 61d31c7

Create rewards.py
This is needed to prevent clutter.
1 parent 491f2b7 commit 61d31c7

File tree

1 file changed: +139 -0


rewards.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
from typing import List, Dict, Any

import numpy as np

from rlgym.api import RewardFunction, AgentID
from rlgym.rocket_league.api import GameState
from rlgym.rocket_league import common_values

class AdvancedTouchReward(RewardFunction[AgentID, GameState, float]):
    """Rewards touching the ball, with an extra term for how much the touch accelerates it"""

    def __init__(self, touch_reward: float = 0.0, acceleration_reward: float = 1.0, use_touch_count: bool = False):
        self.touch_reward = touch_reward
        self.acceleration_reward = acceleration_reward
        self.use_touch_count = use_touch_count

        self.prev_ball = None

    def reset(self, agents: List[AgentID], initial_state: GameState, shared_info: Dict[str, Any]) -> None:
        self.prev_ball = initial_state.ball

    def get_rewards(self, agents: List[AgentID], state: GameState, is_terminated: Dict[AgentID, bool],
                    is_truncated: Dict[AgentID, bool], shared_info: Dict[str, Any]) -> Dict[AgentID, float]:
        rewards = {agent: 0.0 for agent in agents}
        ball = state.ball
        for agent in agents:
            touches = state.cars[agent].ball_touches

            if touches > 0:
                if not self.use_touch_count:
                    touches = 1
                # Change in ball velocity since the previous step, normalized by max ball speed.
                acceleration = np.linalg.norm(ball.linear_velocity - self.prev_ball.linear_velocity) / common_values.BALL_MAX_SPEED
                rewards[agent] += self.touch_reward * touches
                rewards[agent] += acceleration * self.acceleration_reward

        self.prev_ball = ball

        return rewards

class FaceBallReward(RewardFunction[AgentID, GameState, float]):
    """Rewards the agent for facing the ball"""

    def reset(self, agents: List[AgentID], initial_state: GameState, shared_info: Dict[str, Any]) -> None:
        pass

    def get_rewards(self, agents: List[AgentID], state: GameState, is_terminated: Dict[AgentID, bool],
                    is_truncated: Dict[AgentID, bool], shared_info: Dict[str, Any]) -> Dict[AgentID, float]:
        rewards = {}

        for agent in agents:
            car = state.cars[agent]
            ball = state.ball

            direction_to_ball = ball.position - car.physics.position
            norm = np.linalg.norm(direction_to_ball)

            if norm > 0:
                direction_to_ball /= norm

            # Dot product of two unit vectors directly indicates alignment (-1 to 1).
            rewards[agent] = np.dot(car.physics.forward, direction_to_ball)

        return rewards

class SpeedTowardBallReward(RewardFunction[AgentID, GameState, float]):
    """Rewards the agent for moving quickly toward the ball"""

    def reset(self, agents: List[AgentID], initial_state: GameState, shared_info: Dict[str, Any]) -> None:
        pass

    def get_rewards(self, agents: List[AgentID], state: GameState, is_terminated: Dict[AgentID, bool],
                    is_truncated: Dict[AgentID, bool], shared_info: Dict[str, Any]) -> Dict[AgentID, float]:
        rewards = {}
        for agent in agents:
            car = state.cars[agent]
            car_physics = car.physics if car.is_orange else car.inverted_physics
            ball_physics = state.ball if car.is_orange else state.inverted_ball
            player_vel = car_physics.linear_velocity
            pos_diff = ball_physics.position - car_physics.position
            dist_to_ball = np.linalg.norm(pos_diff)
            dir_to_ball = pos_diff / dist_to_ball

            speed_toward_ball = np.dot(player_vel, dir_to_ball)

            # Only reward approach speed; moving away from the ball gives 0, not a penalty.
            rewards[agent] = max(speed_toward_ball / common_values.CAR_MAX_SPEED, 0.0)
        return rewards

class InAirReward(RewardFunction[AgentID, GameState, float]):
    """Rewards the agent for being in the air"""

    def reset(self, agents: List[AgentID], initial_state: GameState, shared_info: Dict[str, Any]) -> None:
        pass

    def get_rewards(self, agents: List[AgentID], state: GameState, is_terminated: Dict[AgentID, bool],
                    is_truncated: Dict[AgentID, bool], shared_info: Dict[str, Any]) -> Dict[AgentID, float]:
        return {agent: float(not state.cars[agent].on_ground) for agent in agents}

class VelocityBallToGoalReward(RewardFunction[AgentID, GameState, float]):
    """Rewards the agent for hitting the ball toward the opponent's goal"""

    def reset(self, agents: List[AgentID], initial_state: GameState, shared_info: Dict[str, Any]) -> None:
        pass

    def get_rewards(self, agents: List[AgentID], state: GameState, is_terminated: Dict[AgentID, bool],
                    is_truncated: Dict[AgentID, bool], shared_info: Dict[str, Any]) -> Dict[AgentID, float]:
        rewards = {}
        for agent in agents:
            car = state.cars[agent]
            ball = state.ball
            # Orange attacks the goal at negative y; blue attacks the goal at positive y.
            if car.is_orange:
                goal_y = -common_values.BACK_NET_Y
            else:
                goal_y = common_values.BACK_NET_Y

            ball_vel = ball.linear_velocity
            pos_diff = np.array([0, goal_y, 0]) - ball.position
            dist = np.linalg.norm(pos_diff)
            dir_to_goal = pos_diff / dist

            vel_toward_goal = np.dot(ball_vel, dir_to_goal)
            rewards[agent] = max(vel_toward_goal / common_values.BALL_MAX_SPEED, 0.0)
        return rewards

class TouchReward(RewardFunction[AgentID, GameState, float]):
    """
    A RewardFunction that gives a reward of 1 if the agent touches the ball, 0 otherwise.
    """

    def reset(self, agents: List[AgentID], initial_state: GameState, shared_info: Dict[str, Any]) -> None:
        pass

    def get_rewards(self, agents: List[AgentID], state: GameState, is_terminated: Dict[AgentID, bool],
                    is_truncated: Dict[AgentID, bool], shared_info: Dict[str, Any]) -> Dict[AgentID, float]:
        return {agent: self._get_reward(agent, state) for agent in agents}

    def _get_reward(self, agent: AgentID, state: GameState) -> float:
        return 1. if state.cars[agent].ball_touches > 0 else 0.
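
Usage sketch (not part of the diff above): reward classes like these are normally blended with per-term weights. A minimal example, assuming RLGym 2.0's CombinedReward from rlgym.rocket_league.reward_functions, which takes (reward, weight) pairs; the weights shown are illustrative placeholders, not tuned values.

from rlgym.rocket_league.reward_functions import CombinedReward, GoalReward

reward_fn = CombinedReward(
    (GoalReward(), 10.0),             # sparse objective dominates
    (AdvancedTouchReward(), 0.5),     # shaped touch quality
    (SpeedTowardBallReward(), 0.05),  # small dense shaping terms
    (FaceBallReward(), 0.01),
    (InAirReward(), 0.002),
)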
