Skip to content

Commit

Permalink
added rendering trained agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 17, 2020
1 parent d34a333 commit ef4bb92
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
2 changes: 1 addition & 1 deletion networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def initialize_modules(model, nonlinearity='leaky_relu', init_type='kaiming'):

def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_names=[], return_val=None, return_vals=None):
def put_in_list(item):
if not isinstance(item, list, tuple) and item is not None:
if not isinstance(item, (list, tuple)) and item is not None:
item = [item]
return item

Expand Down
9 changes: 5 additions & 4 deletions td3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,13 @@
score_history.append(score)
moving_avg = sum(score_history) / len(score_history)
agent.add_scalar('Average Score', moving_avg, global_step=e)
tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
Episode Score: {score}, \
Average Score: {moving_avg}, \
Best Score: {best_score}')

# save weights @ best score
if moving_avg > best_score:
best_score = moving_avg
agent.save_networks()

tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
Episode Score: {score}, \
Average Score: {moving_avg}, \
Best Score: {best_score}')
39 changes: 39 additions & 0 deletions td3/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import gym
from tqdm import tqdm

from td3.agent import Agent
from td3.main import args
from networks.utils import load_weights


if __name__ == '__main__':
args.checkpoint_dir += f'/{args.env_name}_td3.pth'
# env & agent
env = gym.make(args.env_name)
agent = Agent(env, args.alpha, args.beta, args.hidden_dims, args.tau, args.batch_size,
args.gamma, args.d, 0, args.max_size, args.c, args.sigma,
args.one_device, args.log_dir, args.checkpoint_dir)
best_score = env.reward_range[0]
load_weights(args.checkpoint_dir,
[agent.actor] , ['actor'])
episodes = tqdm(range(args.n_episodes))
for e in episodes:
# resetting
state = env.reset()
done = False
score = 0

while not done:
action = agent.choose_action(state)
state_, reward, done, _ = env.step(action)

# reset, log & render
score += reward
state = state_
episodes.set_postfix({'Reward': reward})
env.render()
if score > best_score:
best_score = score
tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
Episode Score: {score}, \
Best Score: {best_score}')

0 comments on commit ef4bb92

Please sign in to comment.