Skip to content

Commit 5d168b4

Browse files
authored
Merge pull request #1 from vmoens/tutorials_py2
Bug fixes and syntax changes for: - docs/source/config.py - docs/source/content_generation.py - tutorials/sphinx-tutorials/multi_task.py - tutorials/sphinx-tutorials/tensordict.py - tutorials/sphinx-tutorials/tensordict_module.py - tutorials/sphinx-tutorials/torchrl_demo.py
2 parents 05f61b8 + 7aeee38 commit 5d168b4

File tree

6 files changed

+294
-159
lines changed

6 files changed

+294
-159
lines changed

docs/source/conf.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
# -- Project information -----------------------------------------------------
2424
import os.path
2525
import sys
26-
from pathlib import Path
2726
import warnings
2827

2928
import pytorch_sphinx_theme
@@ -75,8 +74,8 @@
7574
"gallery_dirs": "tutorials", # path to where to save gallery generated output
7675
"backreferences_dir": "gen_modules/backreferences",
7776
"doc_module": ("torchrl",),
78-
"filename_pattern": "reference/generated/tutorials/", # files to parse
79-
"notebook_images": "reference/generated/tutorials/media/", # images to parse
77+
"filename_pattern": "reference/generated/tutorials/", # files to parse
78+
"notebook_images": "reference/generated/tutorials/media/", # images to parse
8079
}
8180

8281
napoleon_use_ivar = True
@@ -162,7 +161,10 @@
162161
# -- Generate knowledge base references -----------------------------------
163162
current_path = os.path.dirname(os.path.realpath(__file__))
164163
sys.path.append(current_path)
165-
from content_generation import generate_knowledge_base_references, generate_tutorial_references
164+
from content_generation import (
165+
generate_knowledge_base_references,
166+
generate_tutorial_references,
167+
)
166168

167169
generate_knowledge_base_references("../../knowledge_base")
168170
generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial")

docs/source/content_generation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
2-
from pathlib import Path
32
import shutil
3+
from pathlib import Path
44
from typing import List
55

66
FILE_DIR = os.path.dirname(__file__)
@@ -78,9 +78,11 @@ def generate_tutorial_references(tutorial_path: str, file_type: str) -> None:
7878
Path(target_path).mkdir(parents=True, exist_ok=True)
7979

8080
# Iterate tutorial files and copy
81-
file_paths = [os.path.join(tutorial_path, f)
82-
for f in os.listdir(tutorial_path)
83-
if f.endswith((".py", ".rst", ".png"))]
81+
file_paths = [
82+
os.path.join(tutorial_path, f)
83+
for f in os.listdir(tutorial_path)
84+
if f.endswith((".py", ".rst", ".png"))
85+
]
8486

8587
for file_path in file_paths:
8688
shutil.copyfile(file_path, os.path.join(target_path, Path(file_path).name))

tutorials/sphinx-tutorials/multi_task.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
================================================
55
This tutorial details how multi-task policies and batched environments can be used.
66
"""
7+
import torch
8+
from torch import nn
9+
710
##############################################################################
811
# At the end of this tutorial, you will be capable of writing policies that
912
# can compute actions in diverse settings using a distinct set of weights.
@@ -12,8 +15,6 @@
1215
from torchrl.envs import TransformedEnv, CatTensors, Compose, DoubleToFloat, ParallelEnv
1316
from torchrl.envs.libs.dm_control import DMControlEnv
1417
from torchrl.modules import TensorDictModule, TensorDictSequential, MLP
15-
from torch import nn
16-
import torch
1718

1819
###############################################################################
1920
# We design two environments, one humanoid that must complete the stand task
@@ -26,8 +27,11 @@
2627
Compose(
2728
CatTensors(env1_obs_keys, "next_observation_stand", del_keys=False),
2829
CatTensors(env1_obs_keys, "next_observation"),
29-
DoubleToFloat(keys_in=["next_observation_stand", "next_observation"], keys_inv_in=["action"]),
30-
)
30+
DoubleToFloat(
31+
in_keys=["next_observation_stand", "next_observation"],
32+
in_keys_inv=["action"],
33+
),
34+
),
3135
)
3236
env2 = DMControlEnv("humanoid", "walk")
3337
env2_obs_keys = list(env2.observation_spec.keys())
@@ -36,8 +40,11 @@
3640
Compose(
3741
CatTensors(env2_obs_keys, "next_observation_walk", del_keys=False),
3842
CatTensors(env2_obs_keys, "next_observation"),
39-
DoubleToFloat(keys_in=["next_observation_walk", "next_observation"], keys_inv_in=["action"]),
40-
)
43+
DoubleToFloat(
44+
in_keys=["next_observation_walk", "next_observation"],
45+
in_keys_inv=["action"],
46+
),
47+
),
4148
)
4249

4350
###############################################################################
@@ -66,10 +73,22 @@
6673

6774
###############################################################################
6875

69-
policy_common = TensorDictModule(nn.Linear(67, 64), in_keys=["observation"], out_keys=["hidden"])
70-
policy_stand = TensorDictModule(MLP(67 + 64, action_dim, depth=2), in_keys=["observation_stand", "hidden"], out_keys=["action"])
71-
policy_walk = TensorDictModule(MLP(67 + 64, action_dim, depth=2), in_keys=["observation_walk", "hidden"], out_keys=["action"])
72-
seq = TensorDictSequential(policy_common, policy_stand, policy_walk, partial_tolerant=True)
76+
policy_common = TensorDictModule(
77+
nn.Linear(67, 64), in_keys=["observation"], out_keys=["hidden"]
78+
)
79+
policy_stand = TensorDictModule(
80+
MLP(67 + 64, action_dim, depth=2),
81+
in_keys=["observation_stand", "hidden"],
82+
out_keys=["action"],
83+
)
84+
policy_walk = TensorDictModule(
85+
MLP(67 + 64, action_dim, depth=2),
86+
in_keys=["observation_walk", "hidden"],
87+
out_keys=["action"],
88+
)
89+
seq = TensorDictSequential(
90+
policy_common, policy_stand, policy_walk, partial_tolerant=True
91+
)
7392

7493
###############################################################################
7594
# Let's check that our sequence outputs actions for a single env (stand).
@@ -101,22 +120,35 @@
101120
# a single task has to be performed. If a list of functions is provided, then
102121
# it will assume that we are in a multi-task setting.
103122

104-
env1_maker = lambda: TransformedEnv(
105-
DMControlEnv("humanoid", "stand"),
106-
Compose(
107-
CatTensors(env1_obs_keys, "next_observation_stand", del_keys=False),
108-
CatTensors(env1_obs_keys, "next_observation"),
109-
DoubleToFloat(keys_in=["next_observation_stand", "next_observation"], keys_inv_in=["action"]),
123+
124+
def env1_maker():
125+
return TransformedEnv(
126+
DMControlEnv("humanoid", "stand"),
127+
Compose(
128+
CatTensors(env1_obs_keys, "next_observation_stand", del_keys=False),
129+
CatTensors(env1_obs_keys, "next_observation"),
130+
DoubleToFloat(
131+
in_keys=["next_observation_stand", "next_observation"],
132+
in_keys_inv=["action"],
133+
),
134+
),
110135
)
111-
)
112-
env2_maker = lambda: TransformedEnv(
113-
DMControlEnv("humanoid", "walk"),
114-
Compose(
115-
CatTensors(env2_obs_keys, "next_observation_walk", del_keys=False),
116-
CatTensors(env2_obs_keys, "next_observation"),
117-
DoubleToFloat(keys_in=["next_observation_walk", "next_observation"], keys_inv_in=["action"]),
136+
137+
138+
def env2_maker():
139+
return TransformedEnv(
140+
DMControlEnv("humanoid", "walk"),
141+
Compose(
142+
CatTensors(env2_obs_keys, "next_observation_walk", del_keys=False),
143+
CatTensors(env2_obs_keys, "next_observation"),
144+
DoubleToFloat(
145+
in_keys=["next_observation_walk", "next_observation"],
146+
in_keys_inv=["action"],
147+
),
148+
),
118149
)
119-
)
150+
151+
120152
env = ParallelEnv(2, [env1_maker, env2_maker])
121153
assert not env._single_task
122154

@@ -148,8 +180,8 @@
148180

149181
###############################################################################
150182

151-
td_rollout[:, 0] # tensordict of the first step: only the common keys are shown
183+
td_rollout[:, 0] # tensordict of the first step: only the common keys are shown
152184

153185
###############################################################################
154186

155-
td_rollout[0] # tensordict of the first env: the stand obs is present
187+
td_rollout[0] # tensordict of the first env: the stand obs is present

tutorials/sphinx-tutorials/tensordict.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,19 @@
8181
# However to achieve this you would need to write a complicated collate
8282
# function that make sure that every modality is aggregated properly.
8383

84+
8485
def collate_dict_fn(dict_list):
8586
final_dict = {}
8687
for key in dict_list[0].keys():
87-
final_dict[key]= []
88+
final_dict[key] = []
8889
for single_dict in dict_list:
8990
final_dict[key].append(single_dict[key])
9091
final_dict[key] = torch.stack(final_dict[key], dim=0)
9192
return final_dict
9293

94+
95+
import torch
96+
9397
###############################################################################
9498
# dataloader = Dataloader(DictDataset(), collate_fn = collate_dict_fn)
9599
#
@@ -120,11 +124,9 @@ def collate_dict_fn(dict_list):
120124
from torchrl.data import TensorDict
121125
from torchrl.data.tensordict.tensordict import (
122126
UnsqueezedTensorDict,
123-
ViewedTensorDict,
127+
_ViewedTensorDict,
124128
PermutedTensorDict,
125-
LazyStackedTensorDict,
126129
)
127-
import torch
128130

129131
###############################################################################
130132
# TensorDict is a Datastructure indexed by either keys or numerical indices.
@@ -147,7 +149,7 @@ def collate_dict_fn(dict_list):
147149
# does not work
148150
try:
149151
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4, 5])
150-
except:
152+
except RuntimeError:
151153
print("caramba!")
152154

153155
###############################################################################
@@ -158,10 +160,10 @@ def collate_dict_fn(dict_list):
158160
a = torch.zeros(3, 4)
159161
b = TensorDict(
160162
{
161-
"c": torch.zeros(3, 4, 5, dtype=torch.int32),
162-
"d": torch.zeros(3, 4, 5, 6, dtype=torch.float32)
163+
"c": torch.zeros(3, 4, 5, dtype=torch.int32),
164+
"d": torch.zeros(3, 4, 5, 6, dtype=torch.float32),
163165
},
164-
batch_size=[3, 4, 5]
166+
batch_size=[3, 4, 5],
165167
)
166168
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
167169
print(tensordict)
@@ -233,7 +235,7 @@ def collate_dict_fn(dict_list):
233235
# The ``update`` method can be used to update a TensorDict with another one
234236
# (or with a dict):
235237

236-
tensordict.update({"a": torch.ones((3, 4, 5)), "d": 2*torch.ones((3, 4, 2))})
238+
tensordict.update({"a": torch.ones((3, 4, 5)), "d": 2 * torch.ones((3, 4, 2))})
237239
# Also works with tensordict.update(TensorDict({"a":torch.ones((3, 4, 5)),
238240
# "c":torch.ones((3, 4, 2))}, batch_size=[3,4]))
239241
print(f"a is now equal to 1: {(tensordict['a'] == 1).all()}")
@@ -262,7 +264,9 @@ def collate_dict_fn(dict_list):
262264
# but it must be shared across tensors. Indeed, you cannot have items that don't
263265
# share the batch size inside the same TensorDict:
264266

265-
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
267+
tensordict = TensorDict(
268+
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
269+
)
266270
print(f"Our TensorDict is of size {tensordict.shape}")
267271

268272
###############################################################################
@@ -302,8 +306,10 @@ def collate_dict_fn(dict_list):
302306
tensordict = TensorDict({}, [10])
303307
for i in range(2):
304308
tensordict[i] = TensorDict({"a": torch.randn(3, 4)}, [])
305-
assert (tensordict[9]["a"] == torch.zeros((3,4))).all()
306-
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
309+
assert (tensordict[9]["a"] == torch.zeros((3, 4))).all()
310+
tensordict = TensorDict(
311+
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
312+
)
307313

308314
###############################################################################
309315
# Devices
@@ -327,7 +333,9 @@ def collate_dict_fn(dict_list):
327333
# than the original item.
328334

329335
tensordict_clone = tensordict.clone()
330-
print(f"Content is identical ({(tensordict['a'] == tensordict_clone['a']).all()}) but duplicated ({tensordict['a'] is not tensordict_clone['a']})")
336+
print(
337+
f"Content is identical ({(tensordict['a'] == tensordict_clone['a']).all()}) but duplicated ({tensordict['a'] is not tensordict_clone['a']})"
338+
)
331339

332340
###############################################################################
333341
# **Slicing and Indexing**
@@ -356,7 +364,9 @@ def collate_dict_fn(dict_list):
356364
# to the original tensordict as well as the desired index such that tensor
357365
# modifications can be achieved easily.
358366

359-
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
367+
tensordict = TensorDict(
368+
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
369+
)
360370
# a SubTensorDict keeps track of the original one: it does not create a copy in memory of the original data
361371
subtd = tensordict.get_sub_tensordict((slice(None), torch.tensor([1, 3])))
362372
tensordict.fill_("a", -1)
@@ -422,10 +432,10 @@ def collate_dict_fn(dict_list):
422432
###############################################################################
423433
# **View**
424434
#
425-
# Support for the view operation returning a ``ViewedTensorDict``.
435+
# Support for the view operation returning a ``_ViewedTensorDict``.
426436
# Use ``to_tensordict`` to comeback to retrieve TensorDict.
427437

428-
assert type(tensordict.view(-1)) == ViewedTensorDict
438+
assert type(tensordict.view(-1)) == _ViewedTensorDict
429439
assert tensordict.view(-1).shape[0] == 12
430440

431441
###############################################################################
@@ -434,8 +444,8 @@ def collate_dict_fn(dict_list):
434444
# We can permute the dims of ``TensorDict``. Permute is a Lazy operation that
435445
# returns PermutedTensorDict. Use ``to_tensordict`` to convert to ``TensorDict``.
436446

437-
assert type(tensordict.permute(1,0)) == PermutedTensorDict
438-
assert tensordict.permute(1,0).batch_size == torch.Size([4, 3])
447+
assert type(tensordict.permute(1, 0)) == PermutedTensorDict
448+
assert tensordict.permute(1, 0).batch_size == torch.Size([4, 3])
439449

440450
###############################################################################
441451
# **Reshape**

0 commit comments

Comments (0)