Skip to content

Commit 6deedec

Browse files
author
Vincent Moens
committed
[Feature] MCTSForest
ghstack-source-id: dd7d939 Pull Request resolved: #2307
1 parent bef503f commit 6deedec

File tree

4 files changed

+457
-0
lines changed

4 files changed

+457
-0
lines changed

test/test_storage_map.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,12 @@ def test_map_rollout(self):
237237
assert contains[: rollout.shape[-1]].all()
238238
assert not contains[rollout.shape[-1] :].any()
239239

240+
class TestMCTSForest:
241+
def test_forest_build(self):
242+
...
243+
def test_forest_extend_and_get(self):
244+
...
245+
240246

241247
if __name__ == "__main__":
242248
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from .map import (
77
BinaryToDecimal,
88
HashToInt,
9+
MCTSChildren,
10+
MCTSForest,
11+
MCTSNode,
912
QueryModule,
1013
RandomProjectionHash,
1114
SipHash,

torchrl/data/map/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
77
from .query import HashToInt, QueryModule
88
from .tdstorage import TensorDictMap, TensorMap
9+
from .tree import MCTSChildren, MCTSForest, MCTSNode

0 commit comments

Comments
 (0)