
Commit ce69631

amend
1 parent 3b4b369 commit ce69631

7 files changed: +915 -343 lines changed


docs/source/reference/objectives.rst

Lines changed: 11 additions & 2 deletions
@@ -138,7 +138,7 @@ CQL
     CQLLoss
 
 DT
----
+--
 
 .. autosummary::
     :toctree: generated/
@@ -148,14 +148,23 @@ DT
     OnlineDTLoss
 
 TD3
----
+---
 
 .. autosummary::
     :toctree: generated/
     :template: rl_template_noinherit.rst
 
     TD3Loss
 
+TQC
+---
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    TQCLoss
+
 PPO
 ---
 
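Since the reference now lists TQCLoss alongside the other objective modules, a quick sanity check is possible (assuming a TorchRL build that contains this commit, which re-exports the class from torchrl.objectives as shown in examples/tqc/utils.py below):

    # Assumes TorchRL at or after this commit; TQCLoss is re-exported from the
    # public torchrl.objectives namespace documented above.
    from torchrl.objectives import TQCLoss

    print(TQCLoss.__module__, TQCLoss.__qualname__)
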
examples/tqc/tqc.py

Lines changed: 5 additions & 2 deletions
@@ -29,6 +29,7 @@
 """
 
 import time
+
 import hydra
 import numpy as np
 import torch
@@ -57,7 +58,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     exp_name = generate_exp_name("SAC", cfg.env.exp_name)
     logger = None
     # TO-DO: Add logging back in before pushing to git repo
-    #if cfg.logger.backend:
+    # if cfg.logger.backend:
     #     logger = get_logger(
     #         logger_type=cfg.logger.backend,
     #         logger_name="sac_logging/wandb",
@@ -190,7 +191,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 episode_length
             )
         if collected_frames >= init_random_frames:
-            metrics_to_log["train/critic_loss"] = losses.get("loss_critic").mean().item()
+            metrics_to_log["train/critic_loss"] = (
+                losses.get("loss_critic").mean().item()
+            )
             metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
             metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
             metrics_to_log["train/alpha"] = loss_td["alpha"].item()

examples/tqc/utils.py

Lines changed: 31 additions & 173 deletions
@@ -5,22 +5,27 @@
 
 import tempfile
 from contextlib import nullcontext
+from typing import Tuple
+
 import torch
 from tensordict.nn import InteractionType, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
+from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import nn, optim
 from torchrl.collectors import SyncDataCollector
-from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data import (
+    CompositeSpec,
+    TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
+)
 from torchrl.data.replay_buffers.storages import LazyMemmapStorage
 from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
 from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
 from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.modules import MLP, ProbabilisticActor, ValueOperator, ActorCriticWrapper
+from torchrl.modules import ActorCriticWrapper, MLP, ProbabilisticActor, ValueOperator
 from torchrl.modules.distributions import TanhNormal
-from torchrl.objectives import SoftUpdate
-from torchrl.data import CompositeSpec
-from torchrl.objectives.common import LossModule
+from torchrl.objectives import SoftUpdate, TQCLoss
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_WARNING,
@@ -29,14 +34,11 @@
     ValueEstimators,
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-from tensordict.tensordict import TensorDict, TensorDictBase
-from typing import Tuple
 
 
 # ====================================================================
 # Environment utils
 # -----------------
-from torchrl.objectives.tcq import TQCLoss
 
 
 def env_maker(task, device="cpu"):
@@ -100,17 +102,17 @@ def make_collector(cfg, train_env, actor_model_explore):
 
 
 def make_replay_buffer(
-        batch_size,
-        prb=False,
-        buffer_size=1_000_000,
-        buffer_scratch_dir=None,
-        device="cpu",
-        prefetch=3,
+    batch_size,
+    prb=False,
+    buffer_size=1_000_000,
+    buffer_scratch_dir=None,
+    device="cpu",
+    prefetch=3,
 ):
     with (
-            tempfile.TemporaryDirectory()
-            if buffer_scratch_dir is None
-            else nullcontext(buffer_scratch_dir)
+        tempfile.TemporaryDirectory()
+        if buffer_scratch_dir is None
+        else nullcontext(buffer_scratch_dir)
     ) as scratch_dir:
         if prb:
             replay_buffer = TensorDictPrioritizedReplayBuffer(
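
The hunk above only re-indents make_replay_buffer, but it preserves a small idiom worth noting: when no buffer_scratch_dir is configured, the memmap storage writes into a fresh temporary directory, otherwise into the configured path. A standard-library-only sketch of that pattern (open_scratch_dir is a hypothetical helper introduced here for illustration):

    import tempfile
    from contextlib import nullcontext

    def open_scratch_dir(buffer_scratch_dir=None):
        # TemporaryDirectory() yields a fresh path and cleans it up on exit;
        # nullcontext(path) yields the configured path unchanged.
        return (
            tempfile.TemporaryDirectory()
            if buffer_scratch_dir is None
            else nullcontext(buffer_scratch_dir)
        )

    with open_scratch_dir() as scratch_dir:
        print("memmap buffer would be stored under:", scratch_dir)
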
@@ -155,13 +157,15 @@ def __init__(self, cfg):
         }
         for i in range(cfg.network.n_nets):
             net = MLP(**qvalue_net_kwargs)
-            self.add_module(f'critic_net_{i}', net)
+            self.add_module(f"critic_net_{i}", net)
             self.nets.append(net)
 
     def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
         if len(inputs) > 1:
             inputs = (torch.cat([*inputs], -1),)
-        quantiles = torch.stack(tuple(net(*inputs) for net in self.nets), dim=-2)  # batch x n_nets x n_quantiles
+        quantiles = torch.stack(
+            tuple(net(*inputs) for net in self.nets), dim=-2
+        )  # batch x n_nets x n_quantiles
         return quantiles
 
 
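The forward method above stacks the outputs of the n_nets critic MLPs on dim=-2, giving the [batch, n_nets, n_quantiles] layout noted in the inline comment. A minimal sketch of that shape contract, with plain torch.nn.Linear heads standing in for torchrl's MLP and all sizes chosen arbitrarily:

    import torch
    from torch import nn

    batch, obs_dim, act_dim, n_nets, n_quantiles = 4, 8, 2, 5, 25
    # One linear head per critic net, each predicting n_quantiles values.
    nets = nn.ModuleList(
        nn.Linear(obs_dim + act_dim, n_quantiles) for _ in range(n_nets)
    )

    obs, act = torch.randn(batch, obs_dim), torch.randn(batch, act_dim)
    x = torch.cat([obs, act], dim=-1)

    # Stack per-net outputs on dim=-2: shape [batch, n_nets, n_quantiles].
    quantiles = torch.stack(tuple(net(x) for net in nets), dim=-2)
    print(quantiles.shape)  # torch.Size([4, 5, 25])
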
@@ -239,172 +243,26 @@ def make_tqc_agent(cfg, train_env, eval_env, device):
     return model, model[0]
 
 
-# ====================================================================
-# Quantile Huber Loss
-# -------------------
-
-
-def quantile_huber_loss_f(quantiles, samples):
-    """
-    Quantile Huber loss from the original PyTorch TQC implementation.
-    See: https://github.com/SamsungLabs/tqc_pytorch/blob/master/tqc/functions.py
-
-    quantiles is assumed to be of shape [batch size, n_nets, n_quantiles]
-    samples is assumed to be of shape [batch size, n_samples]
-    Arbitrary batch sizes are allowed.
-    """
-    pairwise_delta = samples[..., None, None, :] - quantiles[..., None]  # batch x n_nets x n_quantiles x n_samples
-    abs_pairwise_delta = torch.abs(pairwise_delta)
-    huber_loss = torch.where(abs_pairwise_delta > 1,
-                             abs_pairwise_delta - 0.5,
-                             pairwise_delta ** 2 * 0.5)
-    n_quantiles = quantiles.shape[-1]
-    tau = torch.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
-    loss = (torch.abs(tau[..., None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
-    return loss
-
-
 # ====================================================================
 # TQC Loss
 # --------
 
-class TQCLoss(LossModule):
-    def __init__(
-        self,
-        actor_network,
-        qvalue_network,
-        gamma,
-        top_quantiles_to_drop,
-        alpha_init,
-        device
-    ):
-        super().__init__()
-
-        self.convert_to_functional(
-            actor_network,
-            "actor",
-            create_target_params=False,
-            funs_to_decorate=["forward", "get_dist"],
-        )
-
-        self.convert_to_functional(
-            qvalue_network,
-            "critic",
-            create_target_params=True  # Create a target critic network
-        )
-
-        self.device = device
-        self.log_alpha = torch.tensor([np.log(alpha_init)], requires_grad=True, device=self.device)
-        self.gamma = gamma
-        self.top_quantiles_to_drop = top_quantiles_to_drop
-
-        # Compute target entropy
-        action_spec = getattr(self.actor, "spec", None)
-        if action_spec is None:
-            print("Could not deduce action spec from actor network.")
-        if not isinstance(action_spec, CompositeSpec):
-            action_spec = CompositeSpec({"action": action_spec})
-        action_container_len = len(action_spec.shape)
-        self.target_entropy = -float(action_spec["action"].shape[action_container_len:].numel())
-
-    def value_loss(self, tensordict):
-        td_next = tensordict.get("next")
-        reward = td_next.get("reward")
-        not_done = tensordict.get("done").logical_not()
-        alpha = torch.exp(self.log_alpha)
-
-        # Q-loss
-        with torch.no_grad():
-            # get policy action
-            self.actor(td_next, params=self.actor_params)
-            self.critic(td_next, params=self.target_critic_params)
-
-            next_log_pi = td_next.get("sample_log_prob")
-            next_log_pi = torch.unsqueeze(next_log_pi, dim=-1)
-
-            # compute and cut quantiles at the next state
-            next_z = td_next.get("state_action_value")
-            sorted_z, _ = torch.sort(next_z.reshape(*tensordict.batch_size, -1))
-            sorted_z_part = sorted_z[..., :-self.top_quantiles_to_drop]
-
-            # compute target
-            # --- Note ---
-            # This is computed manually here, since the built-in value estimators in the library
-            # currently do not support a critic of a shape different from the reward.
-            # ------------
-            target = reward + not_done * self.gamma * (sorted_z_part - alpha * next_log_pi)
-
-        self.critic(tensordict, params=self.critic_params)
-        cur_z = tensordict.get("state_action_value")
-        critic_loss = quantile_huber_loss_f(cur_z, target)
-        return critic_loss
-
-    def actor_loss(self, tensordict):
-        alpha = torch.exp(self.log_alpha)
-        self.actor(tensordict, params=self.actor_params)
-        self.critic(tensordict, params=self.critic_params)
-        new_log_pi = tensordict.get("sample_log_prob")
-        actor_loss = (alpha * new_log_pi - tensordict.get("state_action_value").mean(-1).mean(-1, keepdim=True)).mean()
-        return actor_loss, new_log_pi
-
-    def alpha_loss(self, log_prob):
-        alpha_loss = -self.log_alpha * (log_prob + self.target_entropy).detach().mean()
-        return alpha_loss
-
-    def entropy(self, tensordict):
-        with set_exploration_type(ExplorationType.RANDOM):
-            dist = self.actor.get_dist(
-                tensordict,
-                params=self.actor_params,
-            )
-            a_reparm = dist.rsample()
-            log_prob = dist.log_prob(a_reparm).detach()
-        entropy = -log_prob.mean()
-        return entropy
-
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        alpha = torch.exp(self.log_alpha)
-        critic_loss = self.value_loss(tensordict)
-        actor_loss, log_prob = self.actor_loss(tensordict)  # Compute actor loss AFTER critic loss
-        alpha_loss = self.alpha_loss(log_prob)
-        entropy = self.entropy(tensordict)
-
-        return TensorDict(
-            {
-                "loss_critic": critic_loss,
-                "loss_actor": actor_loss,
-                "loss_alpha": alpha_loss,
-                "alpha": alpha,
-                "entropy": entropy,
-            },
-            batch_size=[]
-        )
-
-    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
-        """
-        This is a dummy function, which simply checks if the value type is TD0 and raises
-        an error if the value type is different. As of writing of this, the value estimators
-        in the library do not support a critic shape different from the reward state, which
-        is however necessary by construction for TQC. Therefore, this function does not
-        actually construct a value estimator, and the value is estimated "by hand" in the
-        value_loss function above.
-        """
-        if value_type is not ValueEstimators.TD0:
-            raise NotImplementedError(f"Value type {value_type} is not currently implemented.")
-
 
 def make_loss_module(cfg, model):
     """Make loss module and target network updater."""
     # Create TQC loss
+    top_quantiles_to_drop = (
+        cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets
+    )
     loss_module = TQCLoss(
         actor_network=model[0],
         qvalue_network=model[1],
-        device=cfg.network.device,
-        gamma=cfg.optim.gamma,
-        top_quantiles_to_drop=cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets,
-        alpha_init=cfg.optim.alpha_init
+        top_quantiles_to_drop=top_quantiles_to_drop,
+        alpha_init=cfg.optim.alpha_init,
+    )
+    loss_module.make_value_estimator(
+        value_type=ValueEstimators.TD0, gamma=cfg.optim.gamma
     )
-    loss_module.make_value_estimator(value_type=ValueEstimators.TD0)
 
     # Define Target Network Updater
     target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
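
For reference, the quantile Huber loss computed by the removed quantile_huber_loss_f above (and now expected to be handled inside the library-side TQCLoss) can be written out as follows, with B the batch size, N the number of critic nets, M the number of quantiles per net, J the number of truncated target samples, z_{b,n,m} the predicted quantiles, and y_{b,j} the target samples:

    \tau_m = \frac{2m - 1}{2M}, \qquad
    L_H(\delta) =
    \begin{cases}
        \tfrac{1}{2}\,\delta^2 & |\delta| \le 1 \\
        |\delta| - \tfrac{1}{2} & \text{otherwise}
    \end{cases}

    \mathcal{L} = \frac{1}{B N M J} \sum_{b,n,m,j}
        \bigl| \tau_m - \mathbf{1}\{\delta_{b,n,m,j} < 0\} \bigr| \,
        L_H(\delta_{b,n,m,j}),
    \qquad \delta_{b,n,m,j} = y_{b,j} - z_{b,n,m}

The removed code's tau = torch.arange(M) / M + 1 / (2M) matches the midpoint fractions tau_m above, the Huber threshold is 1, and the unconditional .mean() corresponds to the 1/(BNMJ) normalisation.
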
