Skip to content

Commit 3f4c392

Browse files
author
Vincent Moens
committed
amend
1 parent 30ec91c commit 3f4c392

File tree

6 files changed

+151
-90
lines changed

6 files changed

+151
-90
lines changed

sota-implementations/MCTS/AlphaZero/mcts_node.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

sota-implementations/MCTS/AlphaZero/mcts_policy.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from torchrl.objectives.value.functional import reward2go
2121

22-
from .mcts_node import MCTSNode
22+
from torchrl.data import MCTSNode, MCTSChildren
2323

2424

2525
@dataclass
@@ -64,17 +64,17 @@ def forward(self, node: MCTSNode) -> TensorDictBase:
6464

6565
if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
6666
tensordict[self.action_key] = self.explore_action(node)
67-
elif exploration_type() == ExplorationType.MODE:
67+
elif exploration_type() in (ExplorationType.MODE, ExplorationType.DETERMINISTIC, ExplorationType.MEAN):
6868
tensordict[self.action_key] = self.get_greedy_action(node)
6969

7070
return tensordict
7171

7272
def get_greedy_action(self, node: MCTSNode) -> torch.Tensor:
    """Pick the child of ``node`` that has been visited the most.

    Args:
        node: node whose ``children.visits`` tensor is inspected.

    Returns:
        The result of ``torch.argmax`` over ``node.children.visits``.
    """
    return torch.argmax(node.children.visits)
7575

7676
def explore_action(self, node: MCTSNode) -> torch.Tensor:
77-
action_scores = node.scores
77+
action_scores = node.score
7878

7979
max_value = torch.max(action_scores)
8080
action = torch.argmax(
@@ -156,9 +156,6 @@ class ExpansionStrategy:
156156
This policy is used to initialize a node when it is expanded for the first time.
157157
"""
158158

159-
def __init__(self):
160-
super().__init__()
161-
162159
def forward(self, node: MCTSNode) -> MCTSNode:
163160
"""The node to be expanded.
164161
@@ -179,7 +176,7 @@ def forward(self, node: MCTSNode) -> MCTSNode:
179176

180177
@abstractmethod
181178
def expand(self, node: MCTSNode) -> None:
182-
pass
179+
...
183180

184181
def set_node(self, node: MCTSNode) -> None:
185182
self.node = node
@@ -189,7 +186,7 @@ class BatchedRootExpansionStrategy(ExpansionStrategy):
189186
def __init__(
190187
self,
191188
policy_module: TensorDictModule,
192-
module_action_value_key: str = "action_value",
189+
module_action_value_key: NestedKey = "action_value",
193190
):
194191
super().__init__()
195192
assert module_action_value_key in policy_module.out_keys
@@ -200,17 +197,15 @@ def expand(self, node: MCTSNode) -> None:
200197
policy_netword_td = node.state.select(*self.policy_module.in_keys)
201198
policy_netword_td = self.policy_module(policy_netword_td)
202199
p_sa = policy_netword_td[self.action_value_key]
203-
node.children_priors = p_sa # prior_action_value
204-
node.children_values = torch.zeros_like(p_sa) # action_value
205-
node.children_visits = torch.zeros_like(p_sa) # action_count
200+
node.children = MCTSChildren.init_from_prob(p_sa)
206201
# setattr(node, "truncated", torch.ones(1, dtype=torch.bool))
207202

208203

209204
class AlphaZeroExpansionStrategy(ExpansionStrategy):
210205
def __init__(
211206
self,
212207
policy_module: TensorDictModule,
213-
module_action_value_key: str = "action_value",
208+
module_action_value_key: NestedKey = "action_value",
214209
):
215210
super().__init__()
216211
assert module_action_value_key in policy_module.out_keys
@@ -221,9 +216,9 @@ def expand(self, node: MCTSNode) -> None:
221216
policy_netword_td = node.state.select(*self.policy_module.in_keys)
222217
policy_netword_td = self.policy_module(policy_netword_td)
223218
p_sa = policy_netword_td[self.action_value_key]
224-
node.children_priors = p_sa # prior_action_value
225-
node.children_values = torch.zeros_like(p_sa) # action_value
226-
node.children_visits = torch.zeros_like(p_sa) # action_count
219+
node.children.priors = p_sa # prior_action_value
220+
node.children.vals = torch.zeros_like(p_sa) # action_value
221+
node.children.visits = torch.zeros_like(p_sa) # action_count
227222
# setattr(node, "truncated", torch.ones(1, dtype=torch.bool))
228223

229224

@@ -244,15 +239,15 @@ def __init__(
244239
self.node: MCTSNode
245240

246241
def forward(self, node: MCTSNode) -> MCTSNode:
    """Compute an optimistic score for every child of ``node``.

    The score of a child is its current value estimate plus an exploration
    bonus weighted by ``self.cpuct``: the bonus grows with the child's prior
    and with the total number of visits of the node, and shrinks as the
    child itself gets visited.

    Args:
        node: node whose ``children`` statistics (``vals``, ``priors`` and
            ``visits``) are read.

    Returns:
        The same ``node``, with the per-child scores written to
        ``node.score``.
    """
    children = node.children
    visits = children.visits

    # Total visit count (+1 so the bonus is non-zero before any visit).
    # ``keepdim=True`` keeps the reduced axis so the result broadcasts
    # against per-child tensors that carry leading batch dimensions.
    n = torch.sum(visits, dim=-1, keepdim=True) + 1
    u_sa = self.cpuct * children.priors * torch.sqrt(n) / (1 + visits)

    # The field is named ``score`` (singular); that is the attribute the
    # action-selection code reads back. Writing to ``scores`` would leave
    # ``score`` unset.
    node.score = children.vals + u_sa
    return node
@@ -270,17 +265,17 @@ def __init__(
270265
self.epsilon = epsilon
271266

272267
def forward(self, node: MCTSNode) -> MCTSNode:
    """Mix Dirichlet noise into the priors of ``node``'s children.

    The new priors are ``(1 - epsilon) * priors + epsilon * noise`` where
    ``noise`` is sampled with a concentration of ``alpha`` for every child.

    Args:
        node: node whose ``children.priors`` are perturbed in place.

    Returns:
        The same ``node`` with updated ``children.priors``.
    """
    priors = node.children.priors
    ones = torch.ones_like(priors)

    if priors.device.type == "mps":
        # Sample on CPU and move the result back to the original device
        # (presumably the sampler is not available on MPS -- TODO confirm).
        noise = _Dirichlet.apply(self.alpha * ones.cpu()).to(priors.device)  # type: ignore
    else:
        noise = _Dirichlet.apply(self.alpha * ones)

    node.children.priors = (1 - self.epsilon) * priors + self.epsilon * noise  # type: ignore
    return node
285280

286281

@@ -293,6 +288,8 @@ class MCTSPolicy(TensorDictModuleBase):
293288
exploration_strategy: a policy that trades off exploration against exploitation
294289
"""
295290

291+
node: MCTSNode
292+
296293
def __init__(
297294
self,
298295
expansion_strategy: AlphaZeroExpansionStrategy,
@@ -313,10 +310,11 @@ def __init__(
313310
self.expansion_strategy = expansion_strategy
314311
self.selection_strategy = selection_strategy
315312
self.exploration_strategy = exploration_strategy
316-
self.node: MCTSNode
317313
self.batch_size = batch_size
318314

319315
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
316+
if not hasattr(self, "node"):
317+
raise RuntimeError("the MCTS policy has not been initialized. Please provide a node through policy.set_node().")
320318
if not self.node.expanded:
321319
self.node.state = tensordict # type: ignore
322320
self.expansion_strategy.forward(self.node)

torchrl/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,4 @@
7474
UnboundedDiscreteTensorSpec,
7575
)
7676
from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
77+
from .mcts import MCTSNode, MCTSChildren

torchrl/data/mcts/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from .nodes import MCTSNode, MCTSChildren

torchrl/data/mcts/nodes.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import torch
8+
from tensordict import tensorclass, TensorDict
9+
10+
@tensorclass(autocast=True)
class MCTSChildren:
    """Statistics of the children of an MCTS node, stored as stacked tensors."""

    # Value estimate of each child.
    vals: torch.Tensor
    # Prior assigned to each child.
    priors: torch.Tensor
    # Visit count of each child.
    visits: torch.Tensor
    # Identifier of each child; compared against actions by MCTSNode.get_child.
    ids: torch.Tensor | None = None
    # Child nodes; indexed by MCTSNode.get_child.
    nodes: MCTSNode | None = None

    @classmethod
    def init_from_prob(cls, probs):
        """Build the statistics of freshly created children from their priors.

        Args:
            probs: prior of each child.

        Returns:
            An instance whose ``priors`` are ``probs`` and whose ``vals`` and
            ``visits`` are zeros shaped like ``probs``. ``ids`` and ``nodes``
            keep their ``None`` default.
        """
        return cls(
            priors=probs,
            vals=torch.zeros_like(probs),
            visits=torch.zeros_like(probs),
        )
23+
24+
25+
@tensorclass(autocast=True)
class MCTSNode:
    """A node of a Monte-Carlo search tree.

    A node records the action that led to it (``prior_action``), a link to
    its ``parent`` and the statistics of its own children. Its own visit
    count and value are not stored on the node itself: they are read from and
    written to the parent's children statistics at index ``prior_action``.
    """

    # Action that leads from the parent to this node; also the index of this
    # node inside the parent's children statistics.
    prior_action: torch.Tensor
    # Backing storage of the ``children`` property.
    _children: MCTSChildren | None = None
    # Per-child selection score.
    score: torch.Tensor | None = None
    # Data attached to the node (the policy stores its input tensordict here).
    state: TensorDict | None = None
    terminated: torch.Tensor | None = None
    parent: MCTSNode | None = None

    @classmethod
    def from_action(
        cls,
        action: torch.Tensor,
        parent: MCTSNode | None,
    ):
        """Create the node reached from ``parent`` by taking ``action``."""
        return cls(prior_action=action, parent=parent)

    @property
    def children(self) -> MCTSChildren:
        """Statistics of this node's children.

        When no children have been set yet, an empty placeholder (tensors
        with zero elements) is created *and stored on the node*. Storing it
        is what lets ``node.children.priors = ...`` take effect; a throwaway
        placeholder would silently drop such writes.
        NOTE(review): this relies on the stored children sharing storage with
        the object returned by later reads -- confirm against tensordict.
        """
        children = self._children
        if children is None:
            children = MCTSChildren(
                *[torch.zeros((0,), device=self.device) for _ in range(4)]
            )
            self._children = children
        return children

    @children.setter
    def children(self, value):
        self._children = value

    @property
    def visits(self) -> torch.Tensor:
        """Visit count of this node, as stored in its parent's statistics."""
        assert self.parent is not None, "visits is only defined for nodes with a parent"
        return self.parent.children.visits[self.prior_action]

    @visits.setter
    def visits(self, x) -> None:
        assert self.parent is not None, "visits is only defined for nodes with a parent"
        self.parent.children.visits[self.prior_action] = x

    @property
    def value(self) -> torch.Tensor:
        """Value of this node, as stored in its parent's statistics."""
        assert self.parent is not None, "value is only defined for nodes with a parent"
        return self.parent.children.vals[self.prior_action]

    @value.setter
    def value(self, x) -> None:
        assert self.parent is not None, "value is only defined for nodes with a parent"
        self.parent.children.vals[self.prior_action] = x

    @property
    def expanded(self) -> bool:
        """Whether children statistics have been populated for this node.

        The check uses ``priors`` rather than ``ids``: ``priors`` is a
        required field of the children, whereas ``ids`` stays ``None`` when
        the children are built with ``MCTSChildren.init_from_prob``.
        """
        children = self._children
        if children is None:
            return False
        priors = children.priors
        return priors is not None and priors.numel() > 0

    def get_child(self, action: torch.Tensor) -> MCTSNode:
        """Return the child node(s) whose id matches ``action``.

        The match is computed over the last dimension of ``children.ids``,
        and the resulting boolean mask indexes ``children.nodes``. Both
        fields must have been populated.
        """
        idx = (self.children.ids == action).all(-1)
        return self.children.nodes[idx]  # type: ignore

    @classmethod
    def root(cls) -> MCTSNode:
        """Create a parent-less node whose ``prior_action`` is ``-1``."""
        # ``torch.tensor(-1)`` builds a scalar holding -1; ``torch.Tensor(-1)``
        # would instead request a tensor of size -1 and raise.
        return cls(prior_action=torch.tensor(-1), parent=None)

    @classmethod
    def dummy(cls, num_actions: int = 4) -> MCTSNode:
        """Creates a 'dummy' MCTSNode that can be used to explore TorchRL's MCTS API.

        Args:
            num_actions: number of children of the dummy node. Must be
                positive.

        Returns:
            A parent-less node with ``num_actions`` children that have
            uniform priors, zero values and zero visits, and ids
            ``0 .. num_actions - 1``. ``children.nodes`` is left unset.

        Raises:
            ValueError: if ``num_actions`` is smaller than 1.
        """
        if num_actions < 1:
            raise ValueError(f"num_actions must be positive, got {num_actions}.")
        children = MCTSChildren(
            vals=torch.zeros(num_actions),
            priors=torch.full((num_actions,), 1.0 / num_actions),
            visits=torch.zeros(num_actions),
            # Trailing dimension of size 1: get_child reduces ids over their
            # last dimension.
            ids=torch.arange(num_actions).unsqueeze(-1),
        )
        node = cls(
            prior_action=torch.tensor(-1),
            score=torch.zeros(num_actions),
            state=TensorDict({}, batch_size=[]),
            terminated=torch.zeros(1, dtype=torch.bool),
            parent=None,
        )
        # Go through the property setter: the field is named ``_children``.
        node.children = children
        return node
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+

0 commit comments

Comments
 (0)