
Commit d25348c

Ervin T authored
[bug-fix] Fix stats reporting for reward signals in SAC (#3606)
1 parent 200ab7b commit d25348c
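
In short: reward signals now live on the SAC trainer's optimizer, but the trainer was still iterating the policy's `reward_signals` dict, which `TFPolicy.__init__` created as an empty `{}` and which, judging by this fix, nothing ever populated, so per-signal reward stats were silently reported at their default values. The diffs below delete the stale dict from `TFPolicy`, point the trainer at `self.optimizer.reward_signals`, and extend the test to catch the regression.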

File tree

3 files changed, +5 -5 lines changed


ml-agents/mlagents/trainers/policy/tf_policy.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -56,7 +56,6 @@ def __init__(self, seed, brain, trainer_parameters, load=False):
 
         self.use_recurrent = trainer_parameters["use_recurrent"]
         self.memory_dict: Dict[str, np.ndarray] = {}
-        self.reward_signals: Dict[str, "RewardSignal"] = {}
         self.num_branches = len(self.brain.vector_action_space_size)
         self.previous_action_dict: Dict[str, np.array] = {}
         self.normalize = trainer_parameters.get("normalize", False)
```
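
With this deletion the policy no longer carries a reward-signal registry of its own; as the trainer diff below shows, the optimizer owns it. A minimal sketch of the interface the trainer relies on, using hypothetical simplified stand-ins (`RewardSignal` and `RewardSignalResult` here are illustrative only, not the real ml-agents implementations):

```python
from typing import Dict, NamedTuple
import numpy as np

class RewardSignalResult(NamedTuple):
    # Illustrative stand-in: the trainer loop only touches .scaled_reward.
    scaled_reward: np.ndarray

class RewardSignal:
    """Hypothetical minimal reward signal; real signals (extrinsic,
    curiosity, GAIL, ...) compute rewards from the trajectory buffer."""

    def evaluate_batch(self, agent_buffer_trajectory) -> RewardSignalResult:
        # Zeros are placeholders for whatever the concrete signal computes.
        n = len(agent_buffer_trajectory["environment_rewards"])
        return RewardSignalResult(scaled_reward=np.zeros(n, dtype=np.float32))

# After this commit the registry lives on the optimizer, not the policy:
# optimizer.reward_signals: Dict[str, RewardSignal]
```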

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 1 addition & 4 deletions

```diff
@@ -156,7 +156,7 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
         self.collected_rewards["environment"][agent_id] += np.sum(
             agent_buffer_trajectory["environment_rewards"]
         )
-        for name, reward_signal in self.policy.reward_signals.items():
+        for name, reward_signal in self.optimizer.reward_signals.items():
             evaluate_result = reward_signal.evaluate_batch(
                 agent_buffer_trajectory
             ).scaled_reward
@@ -223,9 +223,6 @@ def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
             reparameterize=True,
             create_tf_graph=False,
         )
-        for _reward_signal in policy.reward_signals.keys():
-            self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
-
         # Load the replay buffer if load
         if self.load and self.checkpoint_replay_buffer:
             try:
```
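
Taken together, the two hunks mean `_process_trajectory` now pulls every configured reward signal from the optimizer's registry, and `create_policy` no longer pre-seeds `collected_rewards` from the (now empty) policy dict. A condensed sketch of the resulting accumulation flow, with lazy initialization added only to keep the sketch self-contained (the real trainer presumably initializes these dicts when the optimizer is set up):

```python
from collections import defaultdict
import numpy as np

def accumulate_signal_rewards(trainer, agent_buffer_trajectory, agent_id):
    # Sketch of the post-fix flow in SACTrainer._process_trajectory.
    # Environment rewards come straight from the trajectory buffer...
    trainer.collected_rewards["environment"][agent_id] += np.sum(
        agent_buffer_trajectory["environment_rewards"]
    )
    # ...and per-signal rewards now come from the optimizer, which is the
    # component that actually builds and owns the reward signals.
    for name, reward_signal in trainer.optimizer.reward_signals.items():
        scaled = reward_signal.evaluate_batch(agent_buffer_trajectory).scaled_reward
        # Lazy init is an assumption of this sketch, not the real code.
        trainer.collected_rewards.setdefault(name, defaultdict(lambda: 0))
        trainer.collected_rewards[name][agent_id] += np.sum(scaled)
```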

ml-agents/mlagents/trainers/tests/test_sac.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -221,6 +221,10 @@ def test_process_trajectory(dummy_config):
     for agent in reward.values():
         assert agent == 0
     assert trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").num > 0
+    # Assert we're not just using the default values
+    assert (
+        trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").mean > 0
+    )
 
 
 if __name__ == "__main__":
```
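
The existing `.num > 0` assertion only proves that *something* was reported; the bug this commit fixes reported default (zero) values, which would pass that check. Asserting on the mean catches it. A toy illustration with a hypothetical simplified summary type (not the real `StatsReporter` API):

```python
from typing import List, NamedTuple

class StatsSummary(NamedTuple):
    # Simplified stand-in for the summary object a stats reporter returns.
    mean: float
    num: int

def summarize(values: List[float]) -> StatsSummary:
    return StatsSummary(mean=sum(values) / len(values) if values else 0.0, num=len(values))

buggy = summarize([0.0, 0.0, 0.0])  # stats reported, but all default zeros
fixed = summarize([0.3, 0.7, 1.1])  # real evaluated rewards

assert buggy.num > 0 and fixed.num > 0  # the old check passes either way
assert fixed.mean > 0                   # only the fixed path passes the new check
assert not (buggy.mean > 0)             # the buggy path fails it
```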
