Commit 2e56b5b

Merge pull request #1 from vmoens/refactor_tcq
Refactor TQC
2 parents 6c80564 + ce69631 commit 2e56b5b

13 files changed (+1119 −184 lines)

.github/unittest/linux_libs/scripts_d4rl/environment.yml

Lines changed: 1 addition & 0 deletions

@@ -17,3 +17,4 @@ dependencies:
   - pyyaml
   - scipy
   - hydra-core
+  - cython<3

docs/source/reference/data.rst

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ We also give users the ability to compose a replay buffer using the following co
     Writer
     RoundRobinWriter
     TensorDictRoundRobinWriter
+    TensorDictMaxValueWriter

 Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes.
 :class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery.

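As an aside on the storage advice in the paragraph above, a minimal sketch of a TensorDict-backed buffer using LazyMemmapStorage might look as follows (buffer size, batch size, and the scratch directory are illustrative and not taken from the commit; the example script's make_replay_buffer in examples/tqc/utils.py below follows the same pattern):

from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

# Illustrative sizes and scratch directory; the example script drives these from its config.
buffer_size, batch_size = 1_000_000, 256
replay_buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(
        buffer_size,
        scratch_dir="/tmp/tqc_buffer",  # memmap files live here, which helps node-failure recovery
        device="cpu",
    ),
    batch_size=batch_size,
    prefetch=3,
    pin_memory=False,
)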
docs/source/reference/objectives.rst

Lines changed: 11 additions & 2 deletions

@@ -138,7 +138,7 @@ CQL
     CQLLoss

 DT
-----
+--

 .. autosummary::
     :toctree: generated/
@@ -148,14 +148,23 @@ DT
     OnlineDTLoss

 TD3
-----
+---

 .. autosummary::
     :toctree: generated/
     :template: rl_template_noinherit.rst

     TD3Loss

+TQC
+---
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    TQCLoss
+
 PPO
 ---

examples/tqc/tqc.py

Lines changed: 5 additions & 2 deletions

@@ -29,6 +29,7 @@
 """

 import time
+
 import hydra
 import numpy as np
 import torch
@@ -57,7 +58,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     exp_name = generate_exp_name("SAC", cfg.env.exp_name)
     logger = None
     # TO-DO: Add logging back in before pushing to git repo
-    #if cfg.logger.backend:
+    # if cfg.logger.backend:
     # logger = get_logger(
     # logger_type=cfg.logger.backend,
     # logger_name="sac_logging/wandb",
@@ -190,7 +191,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     episode_length
                 )
             if collected_frames >= init_random_frames:
-                metrics_to_log["train/critic_loss"] = losses.get("loss_critic").mean().item()
+                metrics_to_log["train/critic_loss"] = (
+                    losses.get("loss_critic").mean().item()
+                )
                 metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
                 metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
                 metrics_to_log["train/alpha"] = loss_td["alpha"].item()

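For reference, the keys logged above come from the TensorDict returned by the TQC loss's forward pass (see the removed in-example implementation in examples/tqc/utils.py below, which the library class replaces). A minimal sketch, assuming a loss_module and a sampled batch TensorDict already exist:

# Minimal sketch; `loss_module` and `batch` are assumed to exist already.
loss_td = loss_module(batch)

# Keys packed by the TQC loss forward pass (per the removed implementation below):
metrics = {
    "train/critic_loss": loss_td["loss_critic"].mean().item(),
    "train/actor_loss": loss_td["loss_actor"].mean().item(),
    "train/alpha_loss": loss_td["loss_alpha"].mean().item(),
    "train/alpha": loss_td["alpha"].item(),
}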
examples/tqc/utils.py

Lines changed: 31 additions & 173 deletions

@@ -5,23 +5,27 @@

 import tempfile
 from contextlib import nullcontext
+from typing import Tuple
+
 import torch
-import numpy as np
 from tensordict.nn import InteractionType, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
+from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import nn, optim
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data import (
+    CompositeSpec,
+    TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
+)
 from torchrl.data.replay_buffers.storages import LazyMemmapStorage
 from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
 from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.modules import MLP, ProbabilisticActor, ValueOperator, ActorCriticWrapper
+from torchrl.modules import ActorCriticWrapper, MLP, ProbabilisticActor, ValueOperator
 from torchrl.modules.distributions import TanhNormal
-from torchrl.objectives import SoftUpdate
-from torchrl.data import CompositeSpec
-from torchrl.objectives.common import LossModule
+from torchrl.objectives import SoftUpdate, TQCLoss
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_WARNING,
@@ -30,8 +34,6 @@
     ValueEstimators,
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-from tensordict.tensordict import TensorDict, TensorDictBase
-from typing import Tuple


 # ====================================================================
@@ -100,17 +102,17 @@ def make_collector(cfg, train_env, actor_model_explore):


 def make_replay_buffer(
-        batch_size,
-        prb=False,
-        buffer_size=1_000_000,
-        buffer_scratch_dir=None,
-        device="cpu",
-        prefetch=3,
+    batch_size,
+    prb=False,
+    buffer_size=1_000_000,
+    buffer_scratch_dir=None,
+    device="cpu",
+    prefetch=3,
 ):
     with (
-            tempfile.TemporaryDirectory()
-            if buffer_scratch_dir is None
-            else nullcontext(buffer_scratch_dir)
+        tempfile.TemporaryDirectory()
+        if buffer_scratch_dir is None
+        else nullcontext(buffer_scratch_dir)
     ) as scratch_dir:
         if prb:
             replay_buffer = TensorDictPrioritizedReplayBuffer(
@@ -155,13 +157,15 @@ def __init__(self, cfg):
         }
         for i in range(cfg.network.n_nets):
             net = MLP(**qvalue_net_kwargs)
-            self.add_module(f'critic_net_{i}', net)
+            self.add_module(f"critic_net_{i}", net)
            self.nets.append(net)

     def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
         if len(inputs) > 1:
             inputs = (torch.cat([*inputs], -1),)
-        quantiles = torch.stack(tuple(net(*inputs) for net in self.nets), dim=-2)  # batch x n_nets x n_quantiles
+        quantiles = torch.stack(
+            tuple(net(*inputs) for net in self.nets), dim=-2
+        )  # batch x n_nets x n_quantiles
         return quantiles


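The forward pass in the hunk above stacks each critic net's quantile head along a new dimension, giving the batch x n_nets x n_quantiles layout the loss relies on. A minimal shape sketch with illustrative sizes (not taken from the config):

import torch
from torchrl.modules import MLP

# Illustrative sizes only; the real values come from cfg.network.
batch, obs_dim, act_dim, n_nets, n_quantiles = 4, 8, 2, 5, 25

nets = [MLP(in_features=obs_dim + act_dim, out_features=n_quantiles) for _ in range(n_nets)]
x = torch.cat([torch.randn(batch, obs_dim), torch.randn(batch, act_dim)], dim=-1)

# Same stacking as in the diff: one quantile head per net, stacked on dim=-2.
quantiles = torch.stack([net(x) for net in nets], dim=-2)
print(quantiles.shape)  # torch.Size([4, 5, 25]) -> batch x n_nets x n_quantiles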
@@ -239,172 +243,26 @@ def make_tqc_agent(cfg, train_env, eval_env, device):
     return model, model[0]


-# ====================================================================
-# Quantile Huber Loss
-# -------------------
-
-
-def quantile_huber_loss_f(quantiles, samples):
-    """
-    Quantile Huber loss from the original PyTorch TQC implementation.
-    See: https://github.com/SamsungLabs/tqc_pytorch/blob/master/tqc/functions.py
-
-    quantiles is assumed to be of shape [batch size, n_nets, n_quantiles]
-    samples is assumed to be of shape [batch size, n_samples]
-    Arbitrary batch sizes are allowed.
-    """
-    pairwise_delta = samples[..., None, None, :] - quantiles[..., None]  # batch x n_nets x n_quantiles x n_samples
-    abs_pairwise_delta = torch.abs(pairwise_delta)
-    huber_loss = torch.where(abs_pairwise_delta > 1,
-                             abs_pairwise_delta - 0.5,
-                             pairwise_delta ** 2 * 0.5)
-    n_quantiles = quantiles.shape[-1]
-    tau = torch.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
-    loss = (torch.abs(tau[..., None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
-    return loss
-
-
 # ====================================================================
 # TQC Loss
 # --------

-class TQCLoss(LossModule):
-    def __init__(
-        self,
-        actor_network,
-        qvalue_network,
-        gamma,
-        top_quantiles_to_drop,
-        alpha_init,
-        device
-    ):
-        super().__init__()
-
-        self.convert_to_functional(
-            actor_network,
-            "actor",
-            create_target_params=False,
-            funs_to_decorate=["forward", "get_dist"],
-        )
-
-        self.convert_to_functional(
-            qvalue_network,
-            "critic",
-            create_target_params=True  # Create a target critic network
-        )
-
-        self.device = device
-        self.log_alpha = torch.tensor([np.log(alpha_init)], requires_grad=True, device=self.device)
-        self.gamma = gamma
-        self.top_quantiles_to_drop = top_quantiles_to_drop
-
-        # Compute target entropy
-        action_spec = getattr(self.actor, "spec", None)
-        if action_spec is None:
-            print("Could not deduce action spec from actor network.")
-        if not isinstance(action_spec, CompositeSpec):
-            action_spec = CompositeSpec({"action": action_spec})
-        action_container_len = len(action_spec.shape)
-        self.target_entropy = -float(action_spec["action"].shape[action_container_len:].numel())
-
-    def value_loss(self, tensordict):
-        td_next = tensordict.get("next")
-        reward = td_next.get("reward")
-        not_done = tensordict.get("done").logical_not()
-        alpha = torch.exp(self.log_alpha)
-
-        # Q-loss
-        with torch.no_grad():
-            # get policy action
-            self.actor(td_next, params=self.actor_params)
-            self.critic(td_next, params=self.target_critic_params)
-
-            next_log_pi = td_next.get("sample_log_prob")
-            next_log_pi = torch.unsqueeze(next_log_pi, dim=-1)
-
-            # compute and cut quantiles at the next state
-            next_z = td_next.get("state_action_value")
-            sorted_z, _ = torch.sort(next_z.reshape(*tensordict.batch_size, -1))
-            sorted_z_part = sorted_z[..., :-self.top_quantiles_to_drop]
-
-            # compute target
-            # --- Note ---
-            # This is computed manually here, since the built-in value estimators in the library
-            # currently do not support a critic of a shape different from the reward.
-            # ------------
-            target = reward + not_done * self.gamma * (sorted_z_part - alpha * next_log_pi)
-
-        self.critic(tensordict, params=self.critic_params)
-        cur_z = tensordict.get("state_action_value")
-        critic_loss = quantile_huber_loss_f(cur_z, target)
-        return critic_loss
-
-    def actor_loss(self, tensordict):
-        alpha = torch.exp(self.log_alpha)
-        self.actor(tensordict, params=self.actor_params)
-        self.critic(tensordict, params=self.critic_params)
-        new_log_pi = tensordict.get("sample_log_prob")
-        actor_loss = (alpha * new_log_pi - tensordict.get("state_action_value").mean(-1).mean(-1, keepdim=True)).mean()
-        return actor_loss, new_log_pi
-
-    def alpha_loss(self, log_prob):
-        alpha_loss = -self.log_alpha * (log_prob + self.target_entropy).detach().mean()
-        return alpha_loss
-
-    def entropy(self, tensordict):
-        with set_exploration_type(ExplorationType.RANDOM):
-            dist = self.actor.get_dist(
-                tensordict,
-                params=self.actor_params,
-            )
-            a_reparm = dist.rsample()
-        log_prob = dist.log_prob(a_reparm).detach()
-        entropy = -log_prob.mean()
-        return entropy
-
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        alpha = torch.exp(self.log_alpha)
-        critic_loss = self.value_loss(tensordict)
-        actor_loss, log_prob = self.actor_loss(tensordict)  # Compute actor loss AFTER critic loss
-        alpha_loss = self.alpha_loss(log_prob)
-        entropy = self.entropy(tensordict)
-
-        return TensorDict(
-            {
-                "loss_critic": critic_loss,
-                "loss_actor": actor_loss,
-                "loss_alpha": alpha_loss,
-                "alpha": alpha,
-                "entropy": entropy,
-            },
-            batch_size=[]
-        )
-
-    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
-        """
-        This is a dummy function, which simply checks if the value type is TD0 and raises
-        an error if the value type is different. As of writing of this, the value estimators
-        in the library do not support a critic shape different from the reward state, which
-        is however necessary by construction for TQC. Therefore, this function does not
-        actually construct a value estimator, and the value is estimated "by hand" in the
-        value_loss function above.
-        """
-        if value_type is not ValueEstimators.TD0:
-            raise NotImplementedError(f"Value type {value_type} is not currently implemented.")
-

 def make_loss_module(cfg, model):
     """Make loss module and target network updater."""
     # Create TQC loss
+    top_quantiles_to_drop = (
+        cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets
+    )
     loss_module = TQCLoss(
         actor_network=model[0],
         qvalue_network=model[1],
-        device=cfg.network.device,
-        gamma=cfg.optim.gamma,
-        top_quantiles_to_drop=cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets,
-        alpha_init=cfg.optim.alpha_init
+        top_quantiles_to_drop=top_quantiles_to_drop,
+        alpha_init=cfg.optim.alpha_init,
+    )
+    loss_module.make_value_estimator(
+        value_type=ValueEstimators.TD0, gamma=cfg.optim.gamma
     )
-    loss_module.make_value_estimator(value_type=ValueEstimators.TD0)

     # Define Target Network Updater
     target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)

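In summary, the refactor swaps the in-file loss for the library class and routes gamma through the value estimator rather than the constructor. A condensed sketch of the new wiring in make_loss_module (TQCLoss as imported in this branch; the numeric arguments are illustrative stand-ins for the config entries):

from torch import nn
from torchrl.objectives import SoftUpdate, TQCLoss
from torchrl.objectives.utils import ValueEstimators


def build_tqc_loss(actor: nn.Module, qvalue: nn.Module):
    # Mirrors make_loss_module above; the numbers stand in for cfg values.
    loss_module = TQCLoss(
        actor_network=actor,
        qvalue_network=qvalue,
        top_quantiles_to_drop=2 * 5,  # top_quantiles_to_drop_per_net * n_nets
        alpha_init=1.0,
    )
    # gamma is now passed when building the value estimator, not to the loss constructor.
    loss_module.make_value_estimator(value_type=ValueEstimators.TD0, gamma=0.99)
    target_net_updater = SoftUpdate(loss_module, eps=0.995)  # cfg.optim.target_update_polyak
    return loss_module, target_net_updater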