
[Algorithm] Added TQC #1631

Open · wants to merge 16 commits into main
13 changes: 11 additions & 2 deletions docs/source/reference/objectives.rst
@@ -138,7 +138,7 @@ CQL
     CQLLoss

 DT
-----
+--

 .. autosummary::
     :toctree: generated/
@@ -148,14 +148,23 @@ DT
     OnlineDTLoss

 TD3
-----
+---

 .. autosummary::
     :toctree: generated/
     :template: rl_template_noinherit.rst

     TD3Loss

+TQC
+---
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    TQCLoss
+
 PPO
 ---

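Once this documentation entry is generated, the new loss should be importable alongside the other objectives documented in this file. A minimal, hypothetical import sketch (the constructor signature is defined by the loss module added in this PR, which is not shown in this excerpt, so no arguments are assumed here):

# Hypothetical usage sketch: assumes TQCLoss is exposed from torchrl.objectives
# like the other losses listed in objectives.rst (e.g. SACLoss, TD3Loss).
from torchrl.objectives import TQCLoss

help(TQCLoss)  # inspect the constructor arguments added by this PR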
52 changes: 52 additions & 0 deletions examples/tqc/config.yaml
@@ -0,0 +1,52 @@
# environment and task
env:
  name: MountainCarContinuous-v0
  task: ""
  exp_name: ${env.name}_TQC
  library: gymnasium
  max_episode_steps: 1_000
  seed: 42

# collector
collector:
  total_frames: 1_000_000
  init_random_frames: 25_000
  frames_per_batch: 1_000
  collector_device: cpu
  env_per_collector: 1
  reset_at_each_iter: False

# replay buffer
replay_buffer:
  size: 1_000_000
  prb: False # use prioritized experience replay
  scratch_dir:

# optim
optim:
  utd_ratio: 1.0
  gamma: 0.99
  lr: 3.0e-4
  weight_decay: 0.0
  batch_size: 256
  target_update_polyak: 0.995
  alpha_init: 1.0
  adam_eps: 1.0e-8

# network
network:
  actor_hidden_sizes: [256, 256]
  critic_hidden_sizes: [512, 512, 512]
  n_quantiles: 25
  n_nets: 5
  top_quantiles_to_drop_per_net: 2
  activation: relu
  default_policy_scale: 1.0
  scale_lb: 0.1
  device: cuda

# logging
logger:
  backend: wandb
  mode: online
  eval_iter: 25_000
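The three quantile settings above control how aggressively TQC truncates the pooled critic distribution: each of the n_nets critics predicts n_quantiles atoms, and top_quantiles_to_drop_per_net atoms per critic are discarded from the pooled, sorted set when forming the target (Kuznetsov et al., 2020). A quick sanity check of that arithmetic with the values in this config (variable names are illustrative only, not part of the code added by this PR):

# Illustrative arithmetic, not part of the code added by this PR.
n_nets = 5                         # network.n_nets
n_quantiles = 25                   # network.n_quantiles
top_quantiles_to_drop_per_net = 2  # network.top_quantiles_to_drop_per_net

total_atoms = n_nets * n_quantiles                      # 125 pooled quantile atoms
dropped_atoms = top_quantiles_to_drop_per_net * n_nets  # 10 largest atoms removed
kept_atoms = total_atoms - dropped_atoms                # 115 atoms used for the target
print(total_atoms, dropped_atoms, kept_atoms)           # 125 10 115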
231 changes: 231 additions & 0 deletions examples/tqc/tqc.py
@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""TQC Example.

This is a simple self-contained example of a TQC training script.

The implementation is based on the SAC example in the examples
directory. TQC was introduced in

"Controlling Overestimation Bias with Truncated Mixture of Continuous
Distributional Quantile Critics" (Arsenii Kuznetsov, Pavel Shvechikov,
Alexander Grishin, Dmitry Vetrov, 2020)

Available from https://proceedings.mlr.press/v119/kuznetsov20a.html.

Where possible, we follow the naming conventions of the original TQC
PyTorch implementation to make comparison with this implementation easier.
The original PyTorch TQC code is available here:

https://github.com/SamsungLabs/tqc_pytorch/tree/master

All hyperparameters are set to the values used in the original
implementation.

Helper functions are defined in the utils.py file that accompanies this script.
"""

import time

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from tensordict import TensorDict
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    log_metrics,
    make_collector,
    make_environment,
    make_loss_module,
    make_replay_buffer,
    make_tqc_agent,
    make_tqc_optimizer,
)


@hydra.main(version_base="1.1", config_path=".", config_name="config")
def main(cfg: "DictConfig"):  # noqa: F821
    device = torch.device(cfg.network.device)

    # Create logger
    exp_name = generate_exp_name("TQC", cfg.env.exp_name)
    logger = None
    # TO-DO: Add logging back in before pushing to git repo
    # if cfg.logger.backend:
    #     logger = get_logger(
    #         logger_type=cfg.logger.backend,
    #         logger_name="tqc_logging/wandb",
    #         experiment_name=exp_name,
    #         wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
    #     )

    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create environments
    train_env, eval_env = make_environment(cfg)

    # Create agent
    model, exploration_policy = make_tqc_agent(cfg, train_env, eval_env, device)

    # Create TQC loss
    loss_module, target_net_updater = make_loss_module(cfg, model)

    # Create off-policy collector
    collector = make_collector(cfg, train_env, exploration_policy)

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
        device=device,
    )

    # Create optimizers
    (
        optimizer_actor,
        optimizer_critic,
        optimizer_alpha,
    ) = make_tqc_optimizer(cfg, loss_module)

    # Main loop
    start_time = time.time()
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        cfg.collector.env_per_collector
        * cfg.collector.frames_per_batch
        * cfg.optim.utd_ratio
    )
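    # With the config defaults (env_per_collector=1, frames_per_batch=1_000,
    # utd_ratio=1.0) this amounts to 1_000 gradient updates per collected batch,
    # i.e. one update per collected frame, once the init_random_frames (25_000)
    # warm-up frames have been gathered.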
    prb = cfg.replay_buffer.prb
    eval_iter = cfg.logger.eval_iter
    frames_per_batch = cfg.collector.frames_per_batch
    eval_rollout_steps = cfg.env.max_episode_steps

    sampling_start = time.time()
    for _, tensordict in enumerate(collector):

        sampling_time = time.time() - sampling_start
        # Update weights of the inference policy
        collector.update_policy_weights_()

        pbar.update(tensordict.numel())

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames

        # Optimization steps
        training_start = time.time()
        if collected_frames >= init_random_frames:
            losses = TensorDict(
                {},
                batch_size=[
                    num_updates,
                ],
            )
            for i in range(num_updates):
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample().clone()

                # Compute loss
                loss_td = loss_module(sampled_tensordict)

                actor_loss = loss_td["loss_actor"]
                q_loss = loss_td["loss_critic"]
                alpha_loss = loss_td["loss_alpha"]

                # Update actor
                optimizer_actor.zero_grad()
                actor_loss.backward()
                optimizer_actor.step()

                # Update critic
                optimizer_critic.zero_grad()
                q_loss.backward()
                optimizer_critic.step()

                # Update alpha
                optimizer_alpha.zero_grad()
                alpha_loss.backward()
                optimizer_alpha.step()

                losses[i] = loss_td.select(
                    "loss_actor", "loss_critic", "loss_alpha"
                ).detach()

                # Update qnet_target params
                target_net_updater.step()
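                # Soft (Polyak) update: with optim.target_update_polyak=0.995 the
                # step above retains roughly 99.5% of the current target weights per
                # gradient step (assuming the SoftUpdate convention configured in
                # utils.make_loss_module, which is not shown in this diff).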

                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)

        training_time = time.time() - training_start
        episode_end = (
            tensordict["next", "done"]
            if tensordict["next", "done"].any()
            else tensordict["next", "truncated"]
        )
        episode_rewards = tensordict["next", "episode_reward"][episode_end]

        # Logging
        metrics_to_log = {}
        if len(episode_rewards) > 0:
            episode_length = tensordict["next", "step_count"][episode_end]
            metrics_to_log["train/reward"] = episode_rewards.mean().item()
            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                episode_length
            )
        if collected_frames >= init_random_frames:
            metrics_to_log["train/critic_loss"] = (
                losses.get("loss_critic").mean().item()
            )
            metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
            metrics_to_log["train/alpha"] = loss_td["alpha"].item()
            metrics_to_log["train/entropy"] = loss_td["entropy"].item()
            metrics_to_log["train/sampling_time"] = sampling_time
            metrics_to_log["train/training_time"] = training_time

        # Evaluation (roughly every eval_iter = 25_000 collected frames)
        if abs(collected_frames % eval_iter) < frames_per_batch:
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_start = time.time()
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model[0],
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
                metrics_to_log["eval/time"] = eval_time
        if logger is not None:
            log_metrics(logger, metrics_to_log, collected_frames)

        sampling_start = time.time()

    collector.shutdown()

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()
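For readers unfamiliar with the algorithm, the sketch below illustrates how TQC forms its critic target: the atoms predicted by all target critics are pooled, sorted, and the largest ones are dropped before the entropy-regularized Bellman backup, which is what controls overestimation bias. This is a conceptual sketch based on the paper and the SamsungLabs reference code; it does not reproduce the TQCLoss implementation added by this PR (that code lives under torchrl/objectives and is not part of this diff), and all names below are illustrative.

import torch


def truncated_quantile_target(
    next_quantiles: torch.Tensor,  # [batch, n_nets, n_quantiles] atoms from target critics
    next_log_prob: torch.Tensor,   # [batch, 1] log-prob of the sampled next action
    reward: torch.Tensor,          # [batch, 1]
    not_done: torch.Tensor,        # [batch, 1], 0.0 where the episode terminated
    alpha: float,                  # entropy temperature
    gamma: float = 0.99,
    top_quantiles_to_drop_per_net: int = 2,
) -> torch.Tensor:
    """Pool, sort and truncate the critic atoms, then bootstrap the target."""
    batch_size, n_nets, n_quantiles = next_quantiles.shape
    # Pool the atoms of every critic and sort them in ascending order.
    sorted_atoms, _ = torch.sort(next_quantiles.reshape(batch_size, -1), dim=-1)
    # Truncate: drop the d * N largest atoms to fight overestimation.
    n_keep = n_nets * n_quantiles - top_quantiles_to_drop_per_net * n_nets
    truncated = sorted_atoms[:, :n_keep]
    # Entropy-regularized distributional Bellman backup (broadcast over atoms).
    return reward + not_done * gamma * (truncated - alpha * next_log_prob)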