Skip to content

Commit 8a75fb1

Browse files
authored
Fixed RL examples to work with new gym API (#2706)
1 parent 7de559c commit 8a75fb1

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

examples/reinforcement_learning/actor_critic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def main(env, args):
8282
def run_single_timestep(engine, timestep):
8383
observation = engine.state.observation
8484
action = select_action(model, observation)
85-
engine.state.observation, reward, done, _ = env.step(action)
85+
engine.state.observation, reward, done, _, _ = env.step(action)
8686
if args.render:
8787
env.render()
8888
model.rewards.append(reward)
@@ -99,7 +99,8 @@ def initialize(engine):
9999

100100
@trainer.on(EPISODE_STARTED)
101101
def reset_environment_state(engine):
102-
engine.state.observation = env.reset()
102+
torch.manual_seed(args.seed + trainer.state.epoch)
103+
engine.state.observation, _ = env.reset(seed=args.seed + trainer.state.epoch)
103104

104105
@trainer.on(EPISODE_COMPLETED)
105106
def update_model(engine):
@@ -147,7 +148,5 @@ def should_finish_training(engine):
147148
args = parser.parse_args()
148149

149150
env = gym.make("CartPole-v1")
150-
env.seed(args.seed)
151-
torch.manual_seed(args.seed)
152151

153152
main(env, args)

examples/reinforcement_learning/reinforce.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def main(env, args):
7272
def run_single_timestep(engine, timestep):
7373
observation = engine.state.observation
7474
action = select_action(model, observation)
75-
engine.state.observation, reward, done, _ = env.step(action)
75+
engine.state.observation, reward, done, _, _ = env.step(action)
7676
if args.render:
7777
env.render()
7878
model.rewards.append(reward)
@@ -89,7 +89,8 @@ def initialize(engine):
8989

9090
@trainer.on(EPISODE_STARTED)
9191
def reset_environment_state(engine):
92-
engine.state.observation = env.reset()
92+
torch.manual_seed(args.seed + trainer.state.epoch)
93+
engine.state.observation, _ = env.reset(seed=args.seed + trainer.state.epoch)
9394

9495
@trainer.on(EPISODE_COMPLETED)
9596
def update_model(engine):
@@ -137,7 +138,5 @@ def should_finish_training(engine):
137138
args = parser.parse_args()
138139

139140
env = gym.make("CartPole-v1")
140-
env.seed(args.seed)
141-
torch.manual_seed(args.seed)
142141

143142
main(env, args)

0 commit comments

Comments
 (0)