
Commit

bug fix
vwxyzjn committed Nov 13, 2023
1 parent 33a5609 commit 70702cf
Showing 1 changed file with 9 additions and 9 deletions.
cleanrl/ppo_atari_multigpu.py (9 additions, 9 deletions)
@@ -165,8 +165,8 @@ def get_action_and_value(self, x, action=None):
 if __name__ == "__main__":
     # torchrun --standalone --nnodes=1 --nproc_per_node=2 ppo_atari_multigpu.py
     # taken from https://pytorch.org/docs/stable/elastic/run.html
-    local_rank = int(os.getenv("LOCAL_RANK", "0"))
     args = tyro.cli(Args)
+    local_rank = int(os.getenv("LOCAL_RANK", "0"))
     args.world_size = int(os.getenv("WORLD_SIZE", "1"))
     args.local_batch_size = int(args.local_num_envs * args.num_steps)
     args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
@@ -226,7 +226,7 @@ def get_action_and_value(self, x, action=None):
 
     # env setup
     envs = gym.vector.SyncVectorEnv(
-        [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)],
+        [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.local_num_envs)],
     )
     assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
 
@@ -235,19 +235,19 @@ def get_action_and_value(self, x, action=None):
     optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
 
     # ALGO Logic: Storage setup
-    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
-    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
-    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    obs = torch.zeros((args.num_steps, args.local_num_envs) + envs.single_observation_space.shape).to(device)
+    actions = torch.zeros((args.num_steps, args.local_num_envs) + envs.single_action_space.shape).to(device)
+    logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
+    rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
+    dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
+    values = torch.zeros((args.num_steps, args.local_num_envs)).to(device)
 
     # TRY NOT TO MODIFY: start the game
     global_step = 0
     start_time = time.time()
     next_obs, _ = envs.reset(seed=args.seed)
     next_obs = torch.Tensor(next_obs).to(device)
-    next_done = torch.zeros(args.num_envs).to(device)
+    next_done = torch.zeros(args.local_num_envs).to(device)
 
     for iteration in range(1, args.num_iterations + 1):
         # Annealing the rate if instructed to do so.
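The changed lines all address the same issue: in a multi-GPU run each process drives only args.local_num_envs environments, so the per-process SyncVectorEnv, the rollout buffers, and next_done must be sized with local_num_envs rather than the global args.num_envs. Below is a minimal, self-contained sketch (not code from this commit) of the shape mismatch the old sizing would produce once world_size > 1; it assumes the usual data-parallel split where num_envs = local_num_envs * world_size, and all names and values are illustrative.

# Sketch only: shows why per-process rollout storage must use local_num_envs.
import torch

world_size = 2                             # processes started by torchrun
local_num_envs = 4                         # envs simulated by *this* process
num_envs = local_num_envs * world_size     # global env count across processes
num_steps = 128
obs_shape = (4, 84, 84)                    # stacked grayscale Atari frames

# Correct: the buffer matches what this process's vectorized env returns.
obs = torch.zeros((num_steps, local_num_envs) + obs_shape)
step_obs = torch.zeros((local_num_envs,) + obs_shape)   # one vectorized env step
obs[0] = step_obs                                        # shapes agree

# Buggy variant: sizing with the global num_envs makes obs[0] expect num_envs
# rows while the local env only produces local_num_envs, so the assignment
# fails with a shape-mismatch RuntimeError whenever world_size > 1.
obs_buggy = torch.zeros((num_steps, num_envs) + obs_shape)
try:
    obs_buggy[0] = step_obs
except RuntimeError as err:
    print("shape mismatch:", err)

Under that split, args.local_batch_size = local_num_envs * num_steps (as computed in the first hunk) is the per-process rollout size; the effective global batch is presumably world_size times larger, with gradients synchronized across processes.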
