Skip to content

Commit

Permalink
modified agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 23, 2020
1 parent a1f4087 commit 289fde2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
Empty file added td3/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion td3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _get_noise(self, clip=True):
if clip:
noise = noise.clamp(-self.c, self.c)
return noise

def _clamp_action_bound(self, action):
return action.clamp(self.min_action, self.max_action)

Expand Down
15 changes: 11 additions & 4 deletions td3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
from collections import deque
import torch
import numpy as np

from td3.utils import preprocess_img
from td3.agent import Agent
Expand All @@ -23,6 +24,7 @@
parser.add_argument('--gamma', type=float, default=0.99, help='Reward discount factor')
parser.add_argument('--sigma', type=float, default=0.2, help='Gaussian noise standard deviation')
parser.add_argument('--c', type=float, default=0.5, help='Noise clip')
parser.add_argument('--virtual_display', action="store_false", default=True, help='Enable virtual display')
# for image input agent
parser.add_argument('--img_input', action="store_true", default=False, help='Use image as states')
parser.add_argument('--in_channels', type=int, default=3, help='Number of image channels for image input')
Expand Down Expand Up @@ -52,6 +54,10 @@


if __name__ == '__main__':
if args.virtual_display:
import pyvirtualdisplay
_display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
_ = _display.start()
# paths
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
Expand All @@ -76,16 +82,17 @@
# resetting
state = env.reset()
if args.img_input:
state_queue = next_state_queue = deque(
state_queue = deque(
[preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
maxlen=args.order)
state = torch.cat(list(state_queue), 1)
done = False
score = 0
state = torch.cat(list(state_queue), 1).cpu().numpy()
done, score = False, 0

while not done:
action = agent.choose_action(state)
state_, reward, done, _ = env.step(action)
if isinstance(reward, np.ndarray):
reward = reward[0]
if args.img_input:
state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
state_ = torch.cat(list(state_queue), 1).cpu().numpy()
Expand Down

0 comments on commit 289fde2

Please sign in to comment.