@@ -56,9 +56,9 @@ def __init__(self, seed, brain, trainer_parameters):
         self.seed = seed
         self.brain = brain
         self.use_recurrent = trainer_parameters["use_recurrent"]
-        self.memory_dict: Dict[int, np.ndarray] = {}
+        self.memory_dict: Dict[str, np.ndarray] = {}
         self.num_branches = len(self.brain.vector_action_space_size)
-        self.previous_action_dict: Dict[int, np.array] = {}
+        self.previous_action_dict: Dict[str, np.array] = {}
         self.normalize = trainer_parameters.get("normalize", False)
         self.use_continuous_act = brain.vector_action_space_type == "continuous"
         if self.use_continuous_act:
@@ -181,14 +181,14 @@ def make_empty_memory(self, num_agents):
         return np.zeros((num_agents, self.m_size), dtype=np.float)

     def save_memories(
-        self, agent_ids: List[int], memory_matrix: Optional[np.ndarray]
+        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
     ) -> None:
         if memory_matrix is None:
             return
         for index, agent_id in enumerate(agent_ids):
             self.memory_dict[agent_id] = memory_matrix[index, :]

-    def retrieve_memories(self, agent_ids: List[int]) -> np.ndarray:
+    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
         memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float)
         for index, agent_id in enumerate(agent_ids):
             if agent_id in self.memory_dict:
@@ -209,14 +209,14 @@ def make_empty_previous_action(self, num_agents):
         return np.zeros((num_agents, self.num_branches), dtype=np.int)

     def save_previous_action(
-        self, agent_ids: List[int], action_matrix: Optional[np.ndarray]
+        self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
     ) -> None:
         if action_matrix is None:
             return
         for index, agent_id in enumerate(agent_ids):
             self.previous_action_dict[agent_id] = action_matrix[index, :]

-    def retrieve_previous_action(self, agent_ids: List[int]) -> np.ndarray:
+    def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
         action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
         for index, agent_id in enumerate(agent_ids):
             if agent_id in self.previous_action_dict:
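For reference only, a minimal sketch of how the string-keyed API is used after this change. The MemoryStoreSketch class, its m_size argument, and the sample agent ids are hypothetical and not part of this commit; only the Dict[str, np.ndarray] keying and the save/retrieve logic mirror the diff above (the sketch uses np.float32 instead of the deprecated np.float).

from typing import Dict, List, Optional

import numpy as np


class MemoryStoreSketch:
    """Illustrative stand-in for the policy's memory bookkeeping (hypothetical)."""

    def __init__(self, m_size: int):
        self.m_size = m_size
        # Agent ids are strings after this change, hence Dict[str, ...].
        self.memory_dict: Dict[str, np.ndarray] = {}

    def save_memories(
        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
    ) -> None:
        # Store one memory row per agent id; a None matrix means nothing to save.
        if memory_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
        # Agent ids not seen before fall back to zero-filled rows.
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.memory_dict:
                memory_matrix[index, :] = self.memory_dict[agent_id]
        return memory_matrix


# Usage with string agent ids (hypothetical id values):
store = MemoryStoreSketch(m_size=4)
store.save_memories(["agent-0", "agent-1"], np.ones((2, 4), dtype=np.float32))
print(store.retrieve_memories(["agent-1", "agent-2"]))  # second row stays all zeros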