Skip to content

Commit 31c657a

Browse files
Add TD3 main executable file (#19)
1 parent fab64a2 commit 31c657a

10 files changed

Lines changed: 1125 additions & 1 deletion

File tree

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ For example (DDPG):
102102

103103
![AC](AC/A2CAgent_600.gif)
104104

105-
- improve `AWR`, `DDPG` with Gumbel Distribution Regression from [`XQL`](https://div99.github.io/XQL):
105+
- [x] [TD3](https://arxiv.org/pdf/1802.09477.pdf)
106+
107+
![TD3](TD3/TD3Agent_100.gif)
108+
109+
- improve `AWR`, `DDPG`, `TD3` with Gumbel Distribution Regression from [`XQL`](https://div99.github.io/XQL):
106110
- XAWR
107111

108112
![XAWR](XAWR/XAWRAgent_100.gif)
@@ -111,6 +115,10 @@ For example (DDPG):
111115

112116
![XDDPG](XDDPG/XDDPGAgent_200.gif)
113117

118+
- XTD3
119+
120+
![XTD3](XTD3/XTD3Agent_100.gif)
121+
114122
## Reference
115123

116124
- TrainMonitor and Generategif modified from [coax](https://github.com/coax-dev/coax)

TD3/TD3Agent_100.gif

571 KB
Loading

TD3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# pylint: disable=all

TD3/main.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""main executable file for TD3"""
2+
import os
3+
import logging
4+
from itertools import repeat
5+
import gymnasium as gym
6+
import torch
7+
import numpy as np
8+
from util import generate_gif
9+
from util.wrappers import TrainMonitor
10+
from util.buffer import Experience
11+
from collections import deque
12+
# pylint: disable=invalid-name
13+
from TD3.td3 import TD3Agent as TD3_torch
14+
15+
# Agent class instantiated by main(); aliased so that swapping the agent
# implementation only requires rebinding this one name.
Agent = TD3_torch
logging.basicConfig(level=logging.INFO)

# Seed the torch and numpy random number generators with a fixed value.
torch.manual_seed(0)
np.random.seed(0)


# NOTE(review): not referenced anywhere in the visible part of this file;
# kept because other modules may import it -- confirm before removing.
EPSILON_DECAY_STEPS = 100
22+
23+
24+
def main(
25+
n_episodes=2000,
26+
max_t=200,
27+
eps_start=1.0,
28+
eps_end=0.01,
29+
eps_decay=0.996,
30+
score_term_rules=lambda s: False,
31+
time_interval="25ms"
32+
):
33+
# pylint: disable=line-too-long
34+
"""Deep Q-Learning
35+
36+
Params
37+
======
38+
n_episodes (int): maximum number of training epsiodes
39+
max_t (int): maximum number of timesteps per episode
40+
eps_start (float): starting value of epsilon, for epsilon-greedy action selection
41+
eps_end (float): minimum value of epsilon
42+
eps_decay (float): mutiplicative factor (per episode) for decreasing epsilon
43+
44+
"""
45+
scores = [] # list containing score from each episode
46+
scores_window = deque(maxlen=100) # last 100 scores
47+
eps = eps_start
48+
49+
env = gym.make(
50+
"Pendulum-v1",
51+
render_mode="rgb_array",
52+
)
53+
# env = gym.make(
54+
# "LunarLander-v2",
55+
# render_mode="rgb_array",
56+
# continuous=True,
57+
# )
58+
# env = gym.make("MountainCarContinuous-v0", render_mode="rgb_array")
59+
env = TrainMonitor(env, tensorboard_dir="./logs", tensorboard_write_all=True)
60+
61+
gamma = 0.99
62+
batch_size = 64
63+
learn_iteration = 16
64+
update_tau = 0.5
65+
66+
lr_actor = 0.0001
67+
lr_critic = 0.001
68+
69+
mu = 0.0
70+
theta = 0.15
71+
max_sigma = 0.3
72+
min_sigma = 0.3
73+
decay_period = 100000
74+
value_noise_clip = 0.5
75+
value_noise_sigma = 0.5
76+
77+
agent = Agent(
78+
state_dims=env.observation_space,
79+
action_space=env.action_space,
80+
lr_actor=lr_actor,
81+
lr_critic=lr_critic,
82+
gamma=gamma,
83+
batch_size=batch_size,
84+
forget_experience=False,
85+
update_tau=update_tau,
86+
mu=mu,
87+
theta=theta,
88+
max_sigma=max_sigma,
89+
min_sigma=min_sigma,
90+
decay_period=decay_period,
91+
value_noise_clip=value_noise_clip,
92+
value_noise_sigma=value_noise_sigma
93+
)
94+
dump_gif_dir = f"images/{agent.__class__.__name__}/{agent.__class__.__name__}_{{}}.gif"
95+
for i_episode in range(1, n_episodes + 1):
96+
state, _ = env.reset()
97+
score = 0
98+
for t, _ in enumerate(repeat(0, max_t)):
99+
action = agent.take_action(state=state, explore=True, step=t * i_episode)
100+
next_state, reward, done, _, _ = env.step(action)
101+
agent.remember(Experience(state, action, reward, next_state, done))
102+
agent.learn(learn_iteration)
103+
104+
state = next_state
105+
score += reward
106+
107+
if done or score_term_rules(score):
108+
break
109+
110+
scores_window.append(score) ## save the most recent score
111+
scores.append(score) ## sae the most recent score
112+
eps = max(eps * eps_decay, eps_end) ## decrease the epsilon
113+
print(" " * os.get_terminal_size().columns, end="\r")
114+
print(
115+
f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}",
116+
end="\r"
117+
)
118+
119+
if i_episode and i_episode % 100 == 0:
120+
print(" " * os.get_terminal_size().columns, end="\r")
121+
print(
122+
f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}"
123+
)
124+
generate_gif(
125+
env,
126+
filepath=dump_gif_dir.format(i_episode),
127+
policy=lambda s: agent.take_action(s, explore=False),
128+
duration=float(time_interval.split("ms")[0]),
129+
max_episode_steps=max_t
130+
)
131+
132+
return scores
133+
134+
135+
if __name__ == "__main__":
    # Run training with the default arguments when executed as a script.
    main()

0 commit comments

Comments
 (0)