|
| 1 | +"""main executable file for TD3""" |
| 2 | +import os |
| 3 | +import logging |
| 4 | +from itertools import repeat |
| 5 | +import gymnasium as gym |
| 6 | +import torch |
| 7 | +import numpy as np |
| 8 | +from util import generate_gif |
| 9 | +from util.wrappers import TrainMonitor |
| 10 | +from util.buffer import Experience |
| 11 | +from collections import deque |
| 12 | +# pylint: disable=invalid-name |
| 13 | +from TD3.td3 import TD3Agent as TD3_torch |
| 14 | + |
| 15 | +Agent = TD3_torch |
| 16 | +logging.basicConfig(level=logging.INFO) |
| 17 | + |
| 18 | +torch.manual_seed(0) |
| 19 | +np.random.seed(0) |
| 20 | + |
| 21 | +EPSILON_DECAY_STEPS = 100 |
| 22 | + |
| 23 | + |
| 24 | +def main( |
| 25 | + n_episodes=2000, |
| 26 | + max_t=200, |
| 27 | + eps_start=1.0, |
| 28 | + eps_end=0.01, |
| 29 | + eps_decay=0.996, |
| 30 | + score_term_rules=lambda s: False, |
| 31 | + time_interval="25ms" |
| 32 | +): |
| 33 | + # pylint: disable=line-too-long |
| 34 | + """Deep Q-Learning |
| 35 | +
|
| 36 | + Params |
| 37 | + ====== |
| 38 | + n_episodes (int): maximum number of training epsiodes |
| 39 | + max_t (int): maximum number of timesteps per episode |
| 40 | + eps_start (float): starting value of epsilon, for epsilon-greedy action selection |
| 41 | + eps_end (float): minimum value of epsilon |
| 42 | + eps_decay (float): mutiplicative factor (per episode) for decreasing epsilon |
| 43 | +
|
| 44 | + """ |
| 45 | + scores = [] # list containing score from each episode |
| 46 | + scores_window = deque(maxlen=100) # last 100 scores |
| 47 | + eps = eps_start |
| 48 | + |
| 49 | + env = gym.make( |
| 50 | + "Pendulum-v1", |
| 51 | + render_mode="rgb_array", |
| 52 | + ) |
| 53 | + # env = gym.make( |
| 54 | + # "LunarLander-v2", |
| 55 | + # render_mode="rgb_array", |
| 56 | + # continuous=True, |
| 57 | + # ) |
| 58 | + # env = gym.make("MountainCarContinuous-v0", render_mode="rgb_array") |
| 59 | + env = TrainMonitor(env, tensorboard_dir="./logs", tensorboard_write_all=True) |
| 60 | + |
| 61 | + gamma = 0.99 |
| 62 | + batch_size = 64 |
| 63 | + learn_iteration = 16 |
| 64 | + update_tau = 0.5 |
| 65 | + |
| 66 | + lr_actor = 0.0001 |
| 67 | + lr_critic = 0.001 |
| 68 | + |
| 69 | + mu = 0.0 |
| 70 | + theta = 0.15 |
| 71 | + max_sigma = 0.3 |
| 72 | + min_sigma = 0.3 |
| 73 | + decay_period = 100000 |
| 74 | + value_noise_clip = 0.5 |
| 75 | + value_noise_sigma = 0.5 |
| 76 | + |
| 77 | + agent = Agent( |
| 78 | + state_dims=env.observation_space, |
| 79 | + action_space=env.action_space, |
| 80 | + lr_actor=lr_actor, |
| 81 | + lr_critic=lr_critic, |
| 82 | + gamma=gamma, |
| 83 | + batch_size=batch_size, |
| 84 | + forget_experience=False, |
| 85 | + update_tau=update_tau, |
| 86 | + mu=mu, |
| 87 | + theta=theta, |
| 88 | + max_sigma=max_sigma, |
| 89 | + min_sigma=min_sigma, |
| 90 | + decay_period=decay_period, |
| 91 | + value_noise_clip=value_noise_clip, |
| 92 | + value_noise_sigma=value_noise_sigma |
| 93 | + ) |
| 94 | + dump_gif_dir = f"images/{agent.__class__.__name__}/{agent.__class__.__name__}_{{}}.gif" |
| 95 | + for i_episode in range(1, n_episodes + 1): |
| 96 | + state, _ = env.reset() |
| 97 | + score = 0 |
| 98 | + for t, _ in enumerate(repeat(0, max_t)): |
| 99 | + action = agent.take_action(state=state, explore=True, step=t * i_episode) |
| 100 | + next_state, reward, done, _, _ = env.step(action) |
| 101 | + agent.remember(Experience(state, action, reward, next_state, done)) |
| 102 | + agent.learn(learn_iteration) |
| 103 | + |
| 104 | + state = next_state |
| 105 | + score += reward |
| 106 | + |
| 107 | + if done or score_term_rules(score): |
| 108 | + break |
| 109 | + |
| 110 | + scores_window.append(score) ## save the most recent score |
| 111 | + scores.append(score) ## sae the most recent score |
| 112 | + eps = max(eps * eps_decay, eps_end) ## decrease the epsilon |
| 113 | + print(" " * os.get_terminal_size().columns, end="\r") |
| 114 | + print( |
| 115 | + f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}", |
| 116 | + end="\r" |
| 117 | + ) |
| 118 | + |
| 119 | + if i_episode and i_episode % 100 == 0: |
| 120 | + print(" " * os.get_terminal_size().columns, end="\r") |
| 121 | + print( |
| 122 | + f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}" |
| 123 | + ) |
| 124 | + generate_gif( |
| 125 | + env, |
| 126 | + filepath=dump_gif_dir.format(i_episode), |
| 127 | + policy=lambda s: agent.take_action(s, explore=False), |
| 128 | + duration=float(time_interval.split("ms")[0]), |
| 129 | + max_episode_steps=max_t |
| 130 | + ) |
| 131 | + |
| 132 | + return scores |
| 133 | + |
| 134 | + |
| 135 | +if __name__ == "__main__": |
| 136 | + main() |
0 commit comments