@@ -328,6 +328,212 @@ def _make_batched_step(
         return (decision_step, terminal_step)
 
 
+class MultiAgentEnvironment(BaseEnv):
+    """
+    The MultiAgentEnvironment maintains one SimpleEnvironment per agent, keyed by
+    behavior name plus agent index. When sending DecisionSteps and TerminalSteps to
+    the trainers, it first batches the decision steps from the individual
+    environments. When setting actions, it slices the batched ActionTuple to obtain
+    the ActionTuple for each individual agent.
+    """
+
+    def __init__(
+        self,
+        brain_names,
+        step_size=STEP_SIZE,
+        num_visual=0,
+        num_vector=1,
+        num_var_len=0,
+        vis_obs_size=VIS_OBS_SIZE,
+        vec_obs_size=OBS_SIZE,
+        var_len_obs_size=VAR_LEN_SIZE,
+        action_sizes=(1, 0),
+        num_agents=2,
+    ):
+        super().__init__()
+        self.envs = {}
+        self.dones = {}
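+        # Tracks envs that terminated on the current step, so get_steps() emits
+        # their TerminalSteps exactly once.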
+        self.just_died = set()
+        self.names = brain_names
+        self.final_rewards: Dict[str, List[float]] = {}
+        for name in brain_names:
+            self.final_rewards[name] = []
+            for i in range(num_agents):
+                name_and_num = name + str(i)
+                self.envs[name_and_num] = SimpleEnvironment(
+                    [name],
+                    step_size,
+                    num_visual,
+                    num_vector,
+                    num_var_len,
+                    vis_obs_size,
+                    vec_obs_size,
+                    var_len_obs_size,
+                    action_sizes,
+                )
+                self.dones[name_and_num] = False
+                self.envs[name_and_num].reset()
+        # All envs have the same behavior spec, so just get the last one.
+        self.behavior_spec = self.envs[name_and_num].behavior_spec
+        self.action_spec = self.envs[name_and_num].action_spec
+        self.num_agents = num_agents
+
+    @property
+    def all_done(self):
+        return all(self.dones.values())
+
+    @property
+    def behavior_specs(self):
+        behavior_dict = {}
+        for n in self.names:
+            behavior_dict[n] = self.behavior_spec
+        return BehaviorMapping(behavior_dict)
+
+    def set_action_for_agent(self, behavior_name, agent_id, action):
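+        # No-op: these tests drive all actions through set_actions.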
+        pass
+
+    def set_actions(self, behavior_name, action):
+        # The ActionTuple contains the actions for all num_agents agents. This
+        # slices the batched ActionTuple into a single-agent ActionTuple for each
+        # environment and sets it. The index j skips agents that have already
+        # reached done.
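+        # For example, with action_sizes=(1, 0) and both agents alive,
+        # action.continuous has shape (2, 1): row 0 goes to agent 0's env and
+        # row 1 to agent 1's. Once agent 0 is done, action.continuous carries a
+        # single row, which j routes to agent 1.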
+        j = 0
+        for i in range(self.num_agents):
+            _act = ActionTuple()
+            name_and_num = behavior_name + str(i)
+            env = self.envs[name_and_num]
+            if not self.dones[name_and_num]:
+                if self.action_spec.continuous_size > 0:
+                    _act.add_continuous(action.continuous[j : j + 1])
+                if self.action_spec.discrete_size > 0:
+                    _disc_list = [action.discrete[j, :]]
+                    _act.add_discrete(np.array(_disc_list))
+                j += 1
+            env.action[behavior_name] = _act
+
+    def get_steps(self, behavior_name):
+        # This gets the individual DecisionSteps and TerminalSteps
+        # from the envs and merges them into a batch to be sent
+        # to the AgentProcessor.
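+        # Batched fields are built up in parallel lists: per-agent scalars are
+        # appended, while observation arrays are concatenated along axis 0.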
+        dec_vec_obs = []
+        dec_reward = []
+        dec_group_reward = []
+        dec_agent_id = []
+        dec_group_id = []
+        ter_vec_obs = []
+        ter_reward = []
+        ter_group_reward = []
+        ter_agent_id = []
+        ter_group_id = []
+        interrupted = []
+
+        action_mask = None
+        terminal_step = TerminalSteps.empty(self.behavior_spec)
+        decision_step = None
+        for i in range(self.num_agents):
+            name_and_num = behavior_name + str(i)
+            env = self.envs[name_and_num]
+            _dec, _term = env.step_result[behavior_name]
+            if not self.dones[name_and_num]:
+                dec_agent_id.append(i)
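+                # All agents share group id 1, so the trainer sees them as a
+                # single group.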
+                dec_group_id.append(1)
+                if len(dec_vec_obs) > 0:
+                    for j, obs in enumerate(_dec.obs):
+                        dec_vec_obs[j] = np.concatenate((dec_vec_obs[j], obs), axis=0)
+                else:
+                    for obs in _dec.obs:
+                        dec_vec_obs.append(obs)
+                dec_reward.append(_dec.reward[0])
+                dec_group_reward.append(_dec.group_reward[0])
+                if _dec.action_mask is not None:
+                    if action_mask is None:
+                        action_mask = []
+                    if len(action_mask) > 0:
+                        action_mask[0] = np.concatenate(
+                            (action_mask[0], _dec.action_mask[0]), axis=0
+                        )
+                    else:
+                        action_mask.append(_dec.action_mask[0])
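+            # Emit a TerminalStep only the first time an env reports done;
+            # just_died guarantees each termination is surfaced exactly once.
+            # These are genuine episode ends, so interrupted stays False.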
+            if len(_term.reward) > 0 and name_and_num in self.just_died:
+                ter_agent_id.append(i)
+                ter_group_id.append(1)
+                if len(ter_vec_obs) > 0:
+                    for j, obs in enumerate(_term.obs):
+                        ter_vec_obs[j] = np.concatenate((ter_vec_obs[j], obs), axis=0)
+                else:
+                    for obs in _term.obs:
+                        ter_vec_obs.append(obs)
+                ter_reward.append(_term.reward[0])
+                ter_group_reward.append(_term.group_reward[0])
+                interrupted.append(False)
+                self.just_died.remove(name_and_num)
+        decision_step = DecisionSteps(
+            dec_vec_obs,
+            dec_reward,
+            dec_agent_id,
+            action_mask,
+            dec_group_id,
+            dec_group_reward,
+        )
+        terminal_step = TerminalSteps(
+            ter_vec_obs,
+            ter_reward,
+            interrupted,
+            ter_agent_id,
+            ter_group_id,
+            ter_group_reward,
+        )
+        return (decision_step, terminal_step)
+
+    def step(self) -> None:
+        # Steps all environments and calls reset if all agents are done.
+        for name in self.names:
+            for i in range(self.num_agents):
+                name_and_num = name + str(i)
+                # Does not step the env if done
+                if not self.dones[name_and_num]:
+                    env = self.envs[name_and_num]
+                    # Reproducing part of env step to intercept Dones
+                    assert all(action is not None for action in env.action.values())
+                    done = env._take_action(name)
+                    reward = env._compute_reward(name, done)
+                    self.dones[name_and_num] = done
+                    if done:
+                        self.just_died.add(name_and_num)
+                    if self.all_done:
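+                        # The final agent to finish passes the env reward through
+                        # as the group reward (the 0.0 here is the individual
+                        # reward, assuming a (name, done, reward, group_reward)
+                        # signature for _make_batched_step).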
+                        env.step_result[name] = env._make_batched_step(
+                            name, done, 0.0, reward
+                        )
+                        self.final_rewards[name].append(reward)
+                        self.reset()
+                    elif done:
+                        # This agent finished while others are still running. Its
+                        # reward is capped at the time penalty: -TIME_PENALTY if it
+                        # succeeded, or its (more negative) env reward if it failed.
+                        ceil_reward = min(-TIME_PENALTY, reward)
+                        env.step_result[name] = env._make_batched_step(
+                            name, done, ceil_reward, 0.0
+                        )
+                        self.final_rewards[name].append(reward)
+                    else:
+                        env.step_result[name] = env._make_batched_step(
+                            name, done, reward, 0.0
+                        )
+
+    def reset(self) -> None:  # type: ignore
+        for name in self.names:
+            for i in range(self.num_agents):
+                name_and_num = name + str(i)
+                self.dones[name_and_num] = False
+
+    @property
+    def reset_parameters(self) -> Dict[str, str]:
+        return {}
+
+    def close(self):
+        pass
+
+
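+# Example usage (a sketch; "test" is a placeholder behavior name):
+#     env = MultiAgentEnvironment(["test"], action_sizes=(1, 0), num_agents=2)
+#     decision_steps, _ = env.get_steps("test")
+#     env.set_actions("test", env.action_spec.random_action(len(decision_steps)))
+#     env.step()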
 class RecordEnvironment(SimpleEnvironment):
     def __init__(
         self,