@@ -318,7 +318,7 @@ def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
318318 def _make_forest (self ) -> MCTSForest :
319319 r0 , r1 , r2 , r3 , r4 = self .dummy_rollouts ()
320320 assert r0 .shape
321- forest = MCTSForest (consolidated = True )
321+ forest = MCTSForest ()
322322 forest .extend (r0 )
323323 forest .extend (r1 )
324324 forest .extend (r2 )
@@ -363,10 +363,24 @@ def _make_forest_intersect(self) -> MCTSForest:
363363 forest .extend (rollout5 )
364364 return forest
365365
366+ @staticmethod
367+ def make_labels (tree ):
368+ if tree .rollout is not None :
369+ s = torch .cat (
370+ [
371+ tree .rollout ["observation" ][:1 ],
372+ tree .rollout ["next" , "observation" ],
373+ ]
374+ )
375+ s = s .tolist ()
376+ return f"{ tree .node_id } : { s } "
377+ return f"{ tree .node_id } "
378+
366379 def test_forest_build (self ):
367380 r0 , * _ = self .dummy_rollouts ()
368381 forest = self ._make_forest ()
369382 tree = forest .get_tree (r0 [0 ])
383+ # tree.plot(make_labels=self.make_labels)
370384
371385 def test_forest_vertices (self ):
372386 r0 , * _ = self .dummy_rollouts ()
@@ -436,18 +450,6 @@ def test_forest_intersect(self):
436450 tree = forest .get_tree (state0 )
437451 subtree = forest .get_tree (TensorDict (observation = 19 ))
438452
439- def make_labels (tree ):
440- if tree .rollout is not None :
441- s = torch .cat (
442- [
443- tree .rollout ["observation" ][:1 ],
444- tree .rollout ["next" , "observation" ],
445- ]
446- )
447- s = s .tolist ()
448- return f"{ tree .node_id } : { s } "
449- return f"{ tree .node_id } "
450-
451453 # subtree.plot(make_labels=make_labels)
452454 # tree.plot(make_labels=make_labels)
453455 assert tree .get_vertex_by_id (2 ).num_children == 2
0 commit comments