@@ -137,32 +137,54 @@ def test_group_statuses():
137137 )
138138
139139 # Make terminal steps for some dead agents
140- mock_decision_steps_2 , mock_terminal_steps_2 = mb .create_mock_steps (
140+ _ , mock_terminal_steps_2 = mb .create_mock_steps (
141141 num_agents = 2 ,
142142 observation_specs = create_observation_specs_with_shapes ([(8 ,)]),
143143 action_spec = ActionSpec .create_continuous (2 ),
144144 done = True ,
145145 grouped = True ,
146+ agent_ids = [2 , 3 ],
147+ )
148+ # Make decision steps continue for other agents
149+ mock_decision_steps_2 , _ = mb .create_mock_steps (
150+ num_agents = 2 ,
151+ observation_specs = create_observation_specs_with_shapes ([(8 ,)]),
152+ action_spec = ActionSpec .create_continuous (2 ),
153+ done = False ,
154+ grouped = True ,
155+ agent_ids = [0 , 1 ],
146156 )
147157
148158 processor .add_experiences (
149159 mock_decision_steps_2 , mock_terminal_steps_2 , 0 , fake_action_info
150160 )
151- fake_action_info = _create_action_info (4 , mock_decision_steps .agent_id )
161+ # Continue to add for remaining live agents
162+ fake_action_info = _create_action_info (4 , mock_decision_steps_2 .agent_id )
152163 for _ in range (3 ):
153164 processor .add_experiences (
154- mock_decision_steps , mock_terminal_steps , 0 , fake_action_info
165+ mock_decision_steps_2 , mock_terminal_steps , 0 , fake_action_info
155166 )
156167
157168 # Assert that four trajectories have been added to the Trainer
158169 assert len (tqueue .put .call_args_list ) == 4
159- # Last trajectory should be the longest
170+
171+ # Get the first trajectory, which should have been agent 2 (one of the killed agents)
160172 trajectory = tqueue .put .call_args_list [0 ][0 ][- 1 ]
173+ assert len (trajectory .steps ) == 3
174+ # Make sure trajectory has the right Groupmate Experiences.
175+ # All three steps should contain all agents
176+ for step in trajectory .steps :
177+ assert len (step .group_status ) == 3
178+
179+ # Last trajectory should be the longest. It should be that of agent 1, one of the surviving agents.
180+ trajectory = tqueue .put .call_args_list [- 1 ][0 ][- 1 ]
181+ assert len (trajectory .steps ) == 5
161182
162- # Make sure trajectory has the right Groupmate Experiences
183+ # Make sure trajectory has the right Groupmate Experiences.
184+ # THe first 3 steps should contain all of the obs (that 3rd step is also the terminal step of 2 of the agents)
163185 for step in trajectory .steps [0 :3 ]:
164186 assert len (step .group_status ) == 3
165- # After 2 agents has died
187+ # After 2 agents has died, there should only be 1 group status.
166188 for step in trajectory .steps [3 :]:
167189 assert len (step .group_status ) == 1
168190
0 commit comments