The following classes are deprecated and just point to the classes above:
Trees and Forests
-----------------

TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently,
which is particularly useful for Monte Carlo Tree Search (MCTS) algorithms.

TensorDictMap
~~~~~~~~~~~~~

At its core, the MCTS API relies on the :class:`~torchrl.data.TensorDictMap`, which acts as a storage whose indices
can be any numerical object. In traditional storages (e.g., :class:`~torchrl.data.TensorStorage`), only integer
indices are allowed:

>>> storage = TensorStorage(...)
>>> data = storage[3]

:class:`~torchrl.data.TensorDictMap` allows us to make more advanced queries in the storage. The typical example is
when we have a storage containing a set of MDPs and we want to rebuild a trajectory given its initial
(observation, action) pair. In tensor terms, this could be written with the following pseudocode:

>>> next_state = storage[observation, action]
(if there is more than one next state associated with this pair, one could return a stack of ``next_states`` instead).
This API would make sense, but it would be restrictive: allowing observations or actions that are composed of
multiple tensors may be hard to implement. Instead, we provide a tensordict containing these values and let the storage
know which ``in_keys`` to look at to query the next state:

>>> td = TensorDict(observation=observation, action=action)
>>> next_td = storage[td]
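
To make the composite-indexing idea concrete, here is a minimal, self-contained sketch using plain Python
dictionaries and lists in place of tensordicts and tensors. ``PairMap`` is a hypothetical stand-in for
illustration only and does not reflect the actual :class:`~torchrl.data.TensorDictMap` implementation:

```python
# Illustrative sketch (NOT the real TensorDictMap API): a storage whose
# indices are composite (observation, action) queries rather than integers.
class PairMap:
    """Map composite keys, selected via ``in_keys``, to stored values."""

    def __init__(self, in_keys):
        self.in_keys = in_keys  # which entries of the query dict form the key
        self._storage = {}

    def _hash(self, query):
        # Hash only the configured in_keys, so extra entries are ignored.
        return tuple(tuple(query[k]) for k in self.in_keys)

    def __setitem__(self, query, value):
        self._storage[self._hash(query)] = value

    def __getitem__(self, query):
        return self._storage[self._hash(query)]


storage = PairMap(in_keys=("observation", "action"))
td = {"observation": [0.0, 1.0], "action": [1.0]}
storage[td] = {"next_observation": [1.0, 1.0]}
next_td = storage[td]  # query with the same composite key
```

Because only ``in_keys`` participate in the key, the same query dict can carry additional entries (rewards,
done flags, ...) without affecting the lookup.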

Of course, this class also allows us to extend the storage with new data:

>>> storage[td] = next_state

This comes in handy because it allows us to represent complex rollout structures where different actions are undertaken
at a given node (i.e., for a given observation). All ``(observation, action)`` pairs that have been observed may lead us
to a (set of) rollouts that we can use further.
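
The branching structure this enables can be illustrated with a small hypothetical example, where plain tuples
stand in for observation and action tensors:

```python
from collections import defaultdict

# Illustrative sketch: grouping stored transitions by observation, so that a
# single node (one observation) can branch into several (action, next state)
# pairs. The transitions below are made up for the example.
branches = defaultdict(list)

transitions = [
    ((0,), (1,), (0, 1)),      # (observation, action, next_observation)
    ((0,), (2,), (0, 2)),      # same observation, different action
    ((0, 1), (1,), (0, 1, 1)),
]
for obs, act, next_obs in transitions:
    branches[obs].append((act, next_obs))

# The root observation (0,) was seen with two actions -> two branches.
root_children = branches[(0,)]
```

Each observation that was paired with several actions naturally becomes a branching node of the tree.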

MCTSForest
~~~~~~~~~~

Building a tree from an initial observation then becomes just a matter of organizing data efficiently.
The :class:`~torchrl.data.MCTSForest` has two storages at its core: the first links observations to the hashes and
indices of actions encountered in the past in the dataset:

>>> data = TensorDict(observation=observation)
>>> metadata = forest.node_map[data]
>>> index = metadata["_index"]

where ``forest`` is a :class:`~torchrl.data.MCTSForest` instance.
Then, a second storage keeps track of the actions and results associated with the observation:

>>> next_data = forest.data_map[index]

The ``next_data`` entry can have any shape, but it will usually match the shape of ``index`` (since each index
corresponds to one action). Once ``next_data`` is obtained, it can be put together with ``data`` to form a set of
nodes, and the tree can be expanded for each of these. The following figure shows how this is done.
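
The two-storage lookup described above can be sketched with plain dictionaries standing in for ``node_map`` and
``data_map``. This is an illustration of the expansion loop only, not the actual
:class:`~torchrl.data.MCTSForest` implementation, and all the data below is made up:

```python
# Hypothetical stand-ins for the forest's two storages:
# node_map: observation -> indices of transitions starting at that observation
# data_map: index -> (action, next observation)
node_map = {
    (0,): [0, 1],
    (0, 1): [2],
}
data_map = {
    0: ("a", (0, 1)),
    1: ("b", (0, 2)),
    2: ("a", (0, 1, 1)),
}


def expand(root):
    """Breadth-first expansion: query node_map, fetch children from data_map,
    then repeat for every next observation until no node can be expanded."""
    tree = {}
    frontier = [root]
    while frontier:
        obs = frontier.pop(0)
        children = [data_map[i] for i in node_map.get(obs, [])]
        tree[obs] = children
        frontier.extend(next_obs for _, next_obs in children)
    return tree


tree = expand((0,))
```

Here ``expand`` plays the role of the tree-building step: each ``node_map`` query yields indices, each index
yields an ``(action, next observation)`` pair, and each next observation becomes a new node to expand.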

.. figure:: /_static/img/collector-copy.png

   Building a :class:`~torchrl.data.Tree` from a :class:`~torchrl.data.MCTSForest` object.
   The flowchart represents a tree being built from an initial observation ``o``. The
   :meth:`~torchrl.data.MCTSForest.get_tree` method passes the input data structure (the root node) to the
   ``node_map`` :class:`~torchrl.data.TensorDictMap` instance, which returns a set of hashes and indices. These
   indices are then used to query the corresponding tuples of actions, next observations, rewards, etc. that are
   associated with the root node.
   A vertex is created from each of them (possibly with a longer rollout when a compact representation is asked for).
   The stack of vertices is then used to build up the tree further; these vertices are stacked together and make
   up the branches of the tree at the root. This process is repeated to a given depth or until the tree cannot be
   expanded anymore.

.. currentmodule:: torchrl.data
