5 | 5 |
6 | 6 | import tempfile
7 | 7 | from contextlib import nullcontext
| 8 | +from typing import Tuple
| 9 | +
8 | 10 | import torch
9 | 11 | from tensordict.nn import InteractionType, TensorDictModule
10 | 12 | from tensordict.nn.distributions import NormalParamExtractor
| 13 | +from tensordict.tensordict import TensorDict, TensorDictBase
11 | 14 | from torch import nn, optim
12 | 15 | from torchrl.collectors import SyncDataCollector
13 | | -from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
| 16 | +from torchrl.data import (
| 17 | +    CompositeSpec,
| 18 | +    TensorDictPrioritizedReplayBuffer,
| 19 | +    TensorDictReplayBuffer,
| 20 | +)
14 | 21 | from torchrl.data.replay_buffers.storages import LazyMemmapStorage
15 | 22 | from torchrl.envs import Compose, DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv
16 | 23 | from torchrl.envs.libs.gym import GymEnv, set_gym_backend
17 | 24 | from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
18 | 25 | from torchrl.envs.utils import ExplorationType, set_exploration_type
19 | | -from torchrl.modules import MLP, ProbabilisticActor, ValueOperator, ActorCriticWrapper
| 26 | +from torchrl.modules import ActorCriticWrapper, MLP, ProbabilisticActor, ValueOperator
20 | 27 | from torchrl.modules.distributions import TanhNormal
21 | | -from torchrl.objectives import SoftUpdate
22 | | -from torchrl.data import CompositeSpec
23 | | -from torchrl.objectives.common import LossModule
| 28 | +from torchrl.objectives import SoftUpdate, TQCLoss
24 | 29 | from torchrl.objectives.utils import (
25 | 30 |     _cache_values,
26 | 31 |     _GAMMA_LMBDA_DEPREC_WARNING,
29 | 34 |     ValueEstimators,
30 | 35 | )
31 | 36 | from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
32 | | -from tensordict.tensordict import TensorDict, TensorDictBase
33 | | -from typing import Tuple
34 | 37 |
35 | 38 |
36 | 39 | # ====================================================================
37 | 40 | # Environment utils
38 | 41 | # -----------------
39 | | -from torchrl.objectives.tcq import TQCLoss
40 | 42 |
41 | 43 |
42 | 44 | def env_maker(task, device="cpu"):
@@ -100,17 +102,17 @@ def make_collector(cfg, train_env, actor_model_explore):
100 | 102 |
101 | 103 |
102 | 104 | def make_replay_buffer(
103 | | -    batch_size,
104 | | -    prb=False,
105 | | -    buffer_size=1_000_000,
106 | | -    buffer_scratch_dir=None,
107 | | -    device="cpu",
108 | | -    prefetch=3,
| 105 | +    batch_size,
| 106 | +    prb=False,
| 107 | +    buffer_size=1_000_000,
| 108 | +    buffer_scratch_dir=None,
| 109 | +    device="cpu",
| 110 | +    prefetch=3,
109 | 111 | ):
110 | 112 |     with (
111 | | -        tempfile.TemporaryDirectory()
112 | | -        if buffer_scratch_dir is None
113 | | -        else nullcontext(buffer_scratch_dir)
| 113 | +        tempfile.TemporaryDirectory()
| 114 | +        if buffer_scratch_dir is None
| 115 | +        else nullcontext(buffer_scratch_dir)
114 | 116 |     ) as scratch_dir:
115 | 117 |         if prb:
116 | 118 |             replay_buffer = TensorDictPrioritizedReplayBuffer(
@@ -155,13 +157,15 @@ def __init__(self, cfg):
155 | 157 |         }
156 | 158 |         for i in range(cfg.network.n_nets):
157 | 159 |             net = MLP(**qvalue_net_kwargs)
158 | | -            self.add_module(f'critic_net_{i}', net)
| 160 | +            self.add_module(f"critic_net_{i}", net)
159 | 161 |             self.nets.append(net)
160 | 162 |
161 | 163 |     def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
162 | 164 |         if len(inputs) > 1:
163 | 165 |             inputs = (torch.cat([*inputs], -1),)
164 | | -        quantiles = torch.stack(tuple(net(*inputs) for net in self.nets), dim=-2)  # batch x n_nets x n_quantiles
| 166 | +        quantiles = torch.stack(
| 167 | +            tuple(net(*inputs) for net in self.nets), dim=-2
| 168 | +        )  # batch x n_nets x n_quantiles
165 | 169 |         return quantiles
166 | 170 |
167 | 171 |
@@ -239,172 +243,26 @@ def make_tqc_agent(cfg, train_env, eval_env, device):
239 | 243 |     return model, model[0]
240 | 244 |
241 | 245 |
242 | | -# ====================================================================
243 | | -# Quantile Huber Loss
244 | | -# -------------------
245 | | -
246 | | -
247 | | -def quantile_huber_loss_f(quantiles, samples):
248 | | -    """
249 | | -    Quantile Huber loss from the original PyTorch TQC implementation.
250 | | -    See: https://github.com/SamsungLabs/tqc_pytorch/blob/master/tqc/functions.py
251 | | -
252 | | -    quantiles is assumed to be of shape [batch size, n_nets, n_quantiles]
253 | | -    samples is assumed to be of shape [batch size, n_samples]
254 | | -    Arbitrary batch sizes are allowed.
255 | | -    """
256 | | -    pairwise_delta = samples[..., None, None, :] - quantiles[..., None]  # batch x n_nets x n_quantiles x n_samples
257 | | -    abs_pairwise_delta = torch.abs(pairwise_delta)
258 | | -    huber_loss = torch.where(abs_pairwise_delta > 1,
259 | | -                             abs_pairwise_delta - 0.5,
260 | | -                             pairwise_delta ** 2 * 0.5)
261 | | -    n_quantiles = quantiles.shape[-1]
262 | | -    tau = torch.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles
263 | | -    loss = (torch.abs(tau[..., None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
264 | | -    return loss
265 | | -
266 | | -
267 | 246 | # ====================================================================
268 | 247 | # TQC Loss
269 | 248 | # --------
270 | 249 |
271 | | -class TQCLoss(LossModule):
272 | | -    def __init__(
273 | | -        self,
274 | | -        actor_network,
275 | | -        qvalue_network,
276 | | -        gamma,
277 | | -        top_quantiles_to_drop,
278 | | -        alpha_init,
279 | | -        device
280 | | -    ):
281 | | -        super().__init__()
282 | | -
283 | | -        self.convert_to_functional(
284 | | -            actor_network,
285 | | -            "actor",
286 | | -            create_target_params=False,
287 | | -            funs_to_decorate=["forward", "get_dist"],
288 | | -        )
289 | | -
290 | | -        self.convert_to_functional(
291 | | -            qvalue_network,
292 | | -            "critic",
293 | | -            create_target_params=True  # Create a target critic network
294 | | -        )
295 | | -
296 | | -        self.device = device
297 | | -        self.log_alpha = torch.tensor([np.log(alpha_init)], requires_grad=True, device=self.device)
298 | | -        self.gamma = gamma
299 | | -        self.top_quantiles_to_drop = top_quantiles_to_drop
300 | | -
301 | | -        # Compute target entropy
302 | | -        action_spec = getattr(self.actor, "spec", None)
303 | | -        if action_spec is None:
304 | | -            print("Could not deduce action spec from actor network.")
305 | | -        if not isinstance(action_spec, CompositeSpec):
306 | | -            action_spec = CompositeSpec({"action": action_spec})
307 | | -        action_container_len = len(action_spec.shape)
308 | | -        self.target_entropy = -float(action_spec["action"].shape[action_container_len:].numel())
309 | | -
310 | | -    def value_loss(self, tensordict):
311 | | -        td_next = tensordict.get("next")
312 | | -        reward = td_next.get("reward")
313 | | -        not_done = tensordict.get("done").logical_not()
314 | | -        alpha = torch.exp(self.log_alpha)
315 | | -
316 | | -        # Q-loss
317 | | -        with torch.no_grad():
318 | | -            # get policy action
319 | | -            self.actor(td_next, params=self.actor_params)
320 | | -            self.critic(td_next, params=self.target_critic_params)
321 | | -
322 | | -            next_log_pi = td_next.get("sample_log_prob")
323 | | -            next_log_pi = torch.unsqueeze(next_log_pi, dim=-1)
324 | | -
325 | | -            # compute and cut quantiles at the next state
326 | | -            next_z = td_next.get("state_action_value")
327 | | -            sorted_z, _ = torch.sort(next_z.reshape(*tensordict.batch_size, -1))
328 | | -            sorted_z_part = sorted_z[..., :-self.top_quantiles_to_drop]
329 | | -
330 | | -            # compute target
331 | | -            # --- Note ---
332 | | -            # This is computed manually here, since the built-in value estimators in the library
333 | | -            # currently do not support a critic of a shape different from the reward.
334 | | -            # ------------
335 | | -            target = reward + not_done * self.gamma * (sorted_z_part - alpha * next_log_pi)
336 | | -
337 | | -        self.critic(tensordict, params=self.critic_params)
338 | | -        cur_z = tensordict.get("state_action_value")
339 | | -        critic_loss = quantile_huber_loss_f(cur_z, target)
340 | | -        return critic_loss
341 | | -
342 | | -    def actor_loss(self, tensordict):
343 | | -        alpha = torch.exp(self.log_alpha)
344 | | -        self.actor(tensordict, params=self.actor_params)
345 | | -        self.critic(tensordict, params=self.critic_params)
346 | | -        new_log_pi = tensordict.get("sample_log_prob")
347 | | -        actor_loss = (alpha * new_log_pi - tensordict.get("state_action_value").mean(-1).mean(-1, keepdim=True)).mean()
348 | | -        return actor_loss, new_log_pi
349 | | -
350 | | -    def alpha_loss(self, log_prob):
351 | | -        alpha_loss = -self.log_alpha * (log_prob + self.target_entropy).detach().mean()
352 | | -        return alpha_loss
353 | | -
354 | | -    def entropy(self, tensordict):
355 | | -        with set_exploration_type(ExplorationType.RANDOM):
356 | | -            dist = self.actor.get_dist(
357 | | -                tensordict,
358 | | -                params=self.actor_params,
359 | | -            )
360 | | -        a_reparm = dist.rsample()
361 | | -        log_prob = dist.log_prob(a_reparm).detach()
362 | | -        entropy = -log_prob.mean()
363 | | -        return entropy
364 | | -
365 | | -    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
366 | | -        alpha = torch.exp(self.log_alpha)
367 | | -        critic_loss = self.value_loss(tensordict)
368 | | -        actor_loss, log_prob = self.actor_loss(tensordict)  # Compute actor loss AFTER critic loss
369 | | -        alpha_loss = self.alpha_loss(log_prob)
370 | | -        entropy = self.entropy(tensordict)
371 | | -
372 | | -        return TensorDict(
373 | | -            {
374 | | -                "loss_critic": critic_loss,
375 | | -                "loss_actor": actor_loss,
376 | | -                "loss_alpha": alpha_loss,
377 | | -                "alpha": alpha,
378 | | -                "entropy": entropy,
379 | | -            },
380 | | -            batch_size=[]
381 | | -        )
382 | | -
383 | | -    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
384 | | -        """
385 | | -        This is a dummy function, which simply checks if the value type is TD0 and raises
386 | | -        an error if the value type is different. As of writing of this, the value estimators
387 | | -        in the library do not support a critic shape different from the reward state, which
388 | | -        is however necessary by construction for TQC. Therefore, this function does not
389 | | -        actually construct a value estimator, and the value is estimated "by hand" in the
390 | | -        value_loss function above.
391 | | -        """
392 | | -        if value_type is not ValueEstimators.TD0:
393 | | -            raise NotImplementedError(f"Value type {value_type} is not currently implemented.")
394 | | -
395 | 250 |
396 | 251 | def make_loss_module(cfg, model):
397 | 252 |     """Make loss module and target network updater."""
398 | 253 |     # Create TQC loss
| 254 | +    top_quantiles_to_drop = (
| 255 | +        cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets
| 256 | +    )
399 | 257 |     loss_module = TQCLoss(
400 | 258 |         actor_network=model[0],
401 | 259 |         qvalue_network=model[1],
402 | | -        device=cfg.network.device,
403 | | -        gamma=cfg.optim.gamma,
404 | | -        top_quantiles_to_drop=cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets,
405 | | -        alpha_init=cfg.optim.alpha_init
| 260 | +        top_quantiles_to_drop=top_quantiles_to_drop,
| 261 | +        alpha_init=cfg.optim.alpha_init,
| 262 | +    )
| 263 | +    loss_module.make_value_estimator(
| 264 | +        value_type=ValueEstimators.TD0, gamma=cfg.optim.gamma
406 | 265 |     )
407 | | -    loss_module.make_value_estimator(value_type=ValueEstimators.TD0)
408 | 266 |
409 | 267 |     # Define Target Network Updater
|