@@ -19,7 +19,7 @@

 from torchrl.objectives.value.functional import reward2go

-from .mcts_node import MCTSNode
+from torchrl.data import MCTSNode, MCTSChildren


 @dataclass
@@ -64,17 +64,17 @@ def forward(self, node: MCTSNode) -> TensorDictBase:

         if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
             tensordict[self.action_key] = self.explore_action(node)
-        elif exploration_type() == ExplorationType.MODE:
+        elif exploration_type() in (ExplorationType.MODE, ExplorationType.DETERMINISTIC, ExplorationType.MEAN):
             tensordict[self.action_key] = self.get_greedy_action(node)

         return tensordict

     def get_greedy_action(self, node: MCTSNode) -> torch.Tensor:
-        action = torch.argmax(node.children_visits)
+        action = torch.argmax(node.children.visits)
         return action

     def explore_action(self, node: MCTSNode) -> torch.Tensor:
-        action_scores = node.scores
+        action_scores = node.score

         max_value = torch.max(action_scores)
         action = torch.argmax(
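The widened `elif` means greedy selection now triggers under any of the deterministic exploration modes, not just `MODE`. For context, callers toggle the branch with torchrl's `set_exploration_type` context manager; a minimal sketch, where `policy` and `td` stand in for an initialized `MCTSPolicy` and its input tensordict (both assumed for illustration):

```python
from torchrl.envs.utils import ExplorationType, set_exploration_type

# Greedy: MODE, DETERMINISTIC and MEAN all route to get_greedy_action(node),
# i.e. an argmax over child visit counts.
with set_exploration_type(ExplorationType.DETERMINISTIC):
    td = policy(td)

# Exploratory: RANDOM (or an unset type) routes to explore_action(node).
with set_exploration_type(ExplorationType.RANDOM):
    td = policy(td)
```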
@@ -156,9 +156,6 @@ class ExpansionStrategy:
     This policy is used to initialize a node when it gets expanded for the first time.
     """

-    def __init__(self):
-        super().__init__()
-
     def forward(self, node: MCTSNode) -> MCTSNode:
         """The node to be expanded.

@@ -179,7 +176,7 @@ def forward(self, node: MCTSNode) -> MCTSNode:

     @abstractmethod
     def expand(self, node: MCTSNode) -> None:
-        pass
+        ...

     def set_node(self, node: MCTSNode) -> None:
         self.node = node
@@ -189,7 +186,7 @@ class BatchedRootExpansionStrategy(ExpansionStrategy):
     def __init__(
         self,
         policy_module: TensorDictModule,
-        module_action_value_key: str = "action_value",
+        module_action_value_key: NestedKey = "action_value",
     ):
         super().__init__()
         assert module_action_value_key in policy_module.out_keys
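Widening the annotation from `str` to tensordict's `NestedKey` (roughly `str | tuple[str, ...]`, exported from `tensordict` in recent versions) lets the action-value key address an entry inside a sub-tensordict. A small illustration of the two key forms; the nested layout here is invented for the example:

```python
import torch
from tensordict import TensorDict, NestedKey

flat_key: NestedKey = "action_value"
nested_key: NestedKey = ("agent", "action_value")  # addresses a sub-tensordict

td = TensorDict(
    {"agent": TensorDict({"action_value": torch.rand(4)}, batch_size=[])},
    batch_size=[],
)
print(td[nested_key])  # same tensor as td["agent"]["action_value"]
```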
@@ -200,17 +197,15 @@ def expand(self, node: MCTSNode) -> None:
         policy_netword_td = node.state.select(*self.policy_module.in_keys)
         policy_netword_td = self.policy_module(policy_netword_td)
         p_sa = policy_netword_td[self.action_value_key]
-        node.children_priors = p_sa  # prior_action_value
-        node.children_values = torch.zeros_like(p_sa)  # action_value
-        node.children_visits = torch.zeros_like(p_sa)  # action_count
+        node.children = MCTSChildren.init_from_prob(p_sa)
         # setattr(node, "truncated", torch.ones(1, dtype=torch.bool))


 class AlphaZeroExpansionStrategy(ExpansionStrategy):
     def __init__(
         self,
         policy_module: TensorDictModule,
-        module_action_value_key: str = "action_value",
+        module_action_value_key: NestedKey = "action_value",
     ):
         super().__init__()
         assert module_action_value_key in policy_module.out_keys
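`MCTSChildren.init_from_prob` collapses the three parallel child tensors into one container. Its implementation is not shown in this diff; judging by the three lines it replaces, it presumably amounts to the following sketch (the stand-in class and its field names `priors`/`vals`/`visits` are taken from the later hunks, not from the actual `torchrl.data.MCTSChildren` source):

```python
from dataclasses import dataclass

import torch


@dataclass
class Children:  # stand-in for torchrl.data.MCTSChildren
    priors: torch.Tensor  # P(s, a) from the policy network
    vals: torch.Tensor    # running action-value estimates Q(s, a)
    visits: torch.Tensor  # visit counts N(s, a)

    @classmethod
    def init_from_prob(cls, p_sa: torch.Tensor) -> "Children":
        # Mirrors the three removed lines: priors from the net,
        # values and visit counts zero-initialized with the same shape.
        return cls(p_sa, torch.zeros_like(p_sa), torch.zeros_like(p_sa))
```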
@@ -221,9 +216,9 @@ def expand(self, node: MCTSNode) -> None:
         policy_netword_td = node.state.select(*self.policy_module.in_keys)
         policy_netword_td = self.policy_module(policy_netword_td)
         p_sa = policy_netword_td[self.action_value_key]
-        node.children_priors = p_sa  # prior_action_value
-        node.children_values = torch.zeros_like(p_sa)  # action_value
-        node.children_visits = torch.zeros_like(p_sa)  # action_count
+        node.children.priors = p_sa  # prior_action_value
+        node.children.vals = torch.zeros_like(p_sa)  # action_value
+        node.children.visits = torch.zeros_like(p_sa)  # action_count
         # setattr(node, "truncated", torch.ones(1, dtype=torch.bool))


@@ -244,15 +239,15 @@ def __init__(
         self.node: MCTSNode

     def forward(self, node: MCTSNode) -> MCTSNode:
-        n = torch.sum(node.children_visits, dim=-1) + 1
+        n = torch.sum(node.children.visits, dim=-1) + 1
         u_sa = (
             self.cpuct
-            * node.children_priors
+            * node.children.priors
             * torch.sqrt(n)
-            / (1 + node.children_visits)
+            / (1 + node.children.visits)
         )

-        optimism_estimation = node.children_values + u_sa
+        optimism_estimation = node.children.vals + u_sa
         node.scores = optimism_estimation

         return node
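This hunk is the standard AlphaZero PUCT selection rule, score(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a)), with the parent count N(s) taken here as the sum of child visits plus one. A tiny standalone check of the arithmetic, with all numbers invented for illustration:

```python
import torch

cpuct = 1.0
priors = torch.tensor([0.5, 0.3, 0.2])  # P(s, a)
vals = torch.tensor([0.1, 0.4, 0.0])    # Q(s, a)
visits = torch.tensor([3.0, 1.0, 0.0])  # N(s, a)

n = torch.sum(visits, dim=-1) + 1       # parent count, here 5
u_sa = cpuct * priors * torch.sqrt(n) / (1 + visits)
scores = vals + u_sa
print(scores)  # tensor([0.3795, 0.7354, 0.4472])
# The unvisited action (index 2) gets a pure exploration bonus,
# while the well-visited action (index 0) has its bonus damped.
```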
@@ -270,17 +265,17 @@ def __init__(
         self.epsilon = epsilon

     def forward(self, node: MCTSNode) -> MCTSNode:
-        if node.children_priors.device.type == "mps":
-            device = node.children_priors.device
+        if node.children.priors.device.type == "mps":
+            device = node.children.priors.device
             noise = _Dirichlet.apply(
-                self.alpha * torch.ones_like(node.children_priors).cpu()
+                self.alpha * torch.ones_like(node.children.priors).cpu()
             )
             noise = noise.to(device)  # type: ignore
         else:
-            noise = _Dirichlet.apply(self.alpha * torch.ones_like(node.children_priors))
+            noise = _Dirichlet.apply(self.alpha * torch.ones_like(node.children.priors))

-        noisy_priors = (1 - self.epsilon) * node.children_priors + self.epsilon * noise  # type: ignore
-        node.children_priors = noisy_priors
+        noisy_priors = (1 - self.epsilon) * node.children.priors + self.epsilon * noise  # type: ignore
+        node.children.priors = noisy_priors
         return node

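This is the AlphaZero root-noise trick: child priors are blended with a Dirichlet(alpha) sample, P'(s, a) = (1 - eps) * P(s, a) + eps * noise, so every root action keeps some exploration mass; the `mps` branch only works around missing Dirichlet sampling on Apple's MPS backend by sampling on CPU. A plain `torch.distributions` equivalent of the math (`_Dirichlet.apply` in the patch is presumably torch's internal autograd sampling Function; alpha = 0.3, eps = 0.25 below are the chess settings from the AlphaZero paper):

```python
import torch
from torch.distributions import Dirichlet

alpha, epsilon = 0.3, 0.25
priors = torch.tensor([0.5, 0.3, 0.2])

noise = Dirichlet(alpha * torch.ones_like(priors)).sample()
noisy_priors = (1 - epsilon) * priors + epsilon * noise
assert torch.isclose(noisy_priors.sum(), torch.tensor(1.0))  # still normalized
```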
@@ -293,6 +288,8 @@ class MCTSPolicy(TensorDictModuleBase):
         exploration_strategy: a policy to balance exploration and exploitation
     """

+    node: MCTSNode
+
     def __init__(
         self,
         expansion_strategy: AlphaZeroExpansionStrategy,
@@ -313,10 +310,11 @@ def __init__(
         self.expansion_strategy = expansion_strategy
         self.selection_strategy = selection_strategy
         self.exploration_strategy = exploration_strategy
-        self.node: MCTSNode
         self.batch_size = batch_size

     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if not hasattr(self, "node"):
+            raise RuntimeError("The MCTS policy has not been initialized. Please provide a node through policy.set_node().")
         if not self.node.expanded:
             self.node.state = tensordict  # type: ignore
             self.expansion_strategy.forward(self.node)
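The new guard replaces a bare AttributeError with an actionable message: the policy carries mutable search state and must be pointed at a root node before the first call. A usage sketch, assuming `MCTSPolicy` exposes the `set_node` helper the error message refers to (the strategy constructors are abbreviated):

```python
policy = MCTSPolicy(
    expansion_strategy=expansion,    # e.g. an AlphaZeroExpansionStrategy
    selection_strategy=selection,
    exploration_strategy=exploration,
)
policy.set_node(root_node)  # required: forward() now raises RuntimeError otherwise
out = policy(tensordict)    # expands the root on the first call, then selects
```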