-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
85 lines (65 loc) · 2.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
from tqdm import tqdm
import time
import os
from snake_env import SnakeEnv
from DQN_agent import DQNAgent
# Seed NumPy's global RNG *before* building the environment and agent, so any
# NumPy randomness consumed inside their constructors is reproducible as well.
# NOTE(review): only NumPy's global RNG is seeded here; Python's `random`
# module and any framework RNG used inside DQNAgent are not — confirm whether
# full reproducibility is intended.
np.random.seed(1)

env = SnakeEnv()
agent = DQNAgent()

# Saved models are written under "models/" (see the save step in the loop).
# exist_ok=True replaces the check-then-create pattern without a race; it still
# raises if "models" exists as a regular file, exactly like the old check did.
os.makedirs("models", exist_ok=True)

# Training settings
EPISODES = 20_000  # number of episodes to train for

# Exploration settings (epsilon-greedy)
epsilon = 1  # current exploration rate; decayed after every episode
EPSILON_DECAY = 0.99975  # multiplicative decay factor applied per episode
MIN_EPSILON = 1e-3  # lower bound for epsilon

# Rendering / logging / checkpoint settings
# NOTE(review): with SHOW_PREVIEW on and SHOW_EVERY = 1, every step of every
# episode is rendered, which is likely to slow training considerably.
SHOW_PREVIEW = True  # render the environment while training
SHOW_EVERY = 1  # render every Nth episode (only when SHOW_PREVIEW is True)
STATS_EVERY = 50  # aggregate and log reward stats every N episodes
MIN_AVG_REWARD_FOR_SAVE = 200  # save a model once the windowed average reward reaches this

ep_rewards = []  # total reward of each finished episode, in order
step = 0  # global environment-step counter; incremented in the loop but not read in this script
for episode in tqdm(range(1, EPISODES + 1), ascii=True, unit="episode"):
    # Update the TensorBoard callback's step so logged stats line up with episodes.
    agent.callback.step = episode

    # Restart the episode.
    episode_reward = 0
    done = False
    obs = env.reset()

    # Whether this episode is rendered cannot change mid-episode, so decide once
    # here instead of re-evaluating the condition on every step.
    render = SHOW_PREVIEW and not episode % SHOW_EVERY

    while not done:
        # Epsilon-greedy action selection: with probability (1 - epsilon) take
        # the action with the highest Q-value from the agent, otherwise pick a
        # uniformly random action.
        if np.random.random() > epsilon:
            action = np.argmax(agent.get_qs(obs))
        else:
            action = np.random.randint(0, env.ACTION_SPACE_SIZE)

        # Take a step in the environment.
        new_obs, reward, done = env.step(action)
        step += 1

        # Update total episode reward.
        episode_reward += reward

        if render:
            env.render()

        # Store the transition in the replay buffer, then let the agent train.
        agent.update_replay_memory((obs, action, reward, new_obs, done))
        agent.train(done)

        # Advance to the next state.
        obs = new_obs

    # Keep track of per-episode rewards.
    ep_rewards.append(episode_reward)

    # Every STATS_EVERY episodes: log min/avg/max reward over the most recent
    # window and checkpoint the model if the average is good enough.
    if not episode % STATS_EVERY:
        recent_rewards = ep_rewards[-STATS_EVERY:]
        average_reward = np.mean(recent_rewards)
        min_reward = min(recent_rewards)
        max_reward = max(recent_rewards)
        agent.callback.update_stats(
            reward_avg=average_reward,
            reward_min=min_reward,
            reward_max=max_reward,
            epsilon=epsilon,
        )

        # Save model.
        # NOTE(review): if agent.model is a Keras 3 model, `save` only accepts
        # `.keras`/`.h5` paths and would reject the `.model` suffix — confirm
        # against the installed Keras/TensorFlow version.
        if average_reward >= MIN_AVG_REWARD_FOR_SAVE:
            agent.model.save(f"models/{agent.MODEL_NAME}{max_reward:_>6.2f}max{average_reward:_>6.2f}avg{min_reward:_>6.2f}min{int(time.time())}.model")

    # Decay epsilon once per episode, never letting it drop below MIN_EPSILON.
    if epsilon > MIN_EPSILON:
        epsilon *= EPSILON_DECAY
        epsilon = max(epsilon, MIN_EPSILON)