Skip to content

Commit 1111e9a

Browse files
chendiw authored and Aboudy Kreidieh committed
moved imports under functions in train.py (#903)
* deleting unworking params from SumoChangeLaneParams * deleted unworking params, sublane working in highway : * moved imports inside functions * Apply suggestions from code review * bug fixes * bug fix Co-authored-by: Aboudy Kreidieh <akreidieh@gmail.com>
1 parent 0ade197 commit 1111e9a

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

examples/train.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,9 @@ def run_model_stablebaseline(flow_params,
124124
stable_baselines.*
125125
the trained model
126126
"""
127+
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
128+
from stable_baselines import PPO2
129+
127130
if num_cpus == 1:
128131
constructor = env_constructor(params=flow_params, version=0)()
129132
# The algorithms require a vectorized environment to run
@@ -172,7 +175,12 @@ def setup_exps_rllib(flow_params,
172175
dict
173176
training configuration parameters
174177
"""
178+
from ray import tune
175179
from ray.tune.registry import register_env
180+
try:
181+
from ray.rllib.agents.agent import get_agent_class
182+
except ImportError:
183+
from ray.rllib.agents.registry import get_agent_class
176184

177185
horizon = flow_params['env'].horizon
178186

@@ -404,6 +412,9 @@ def train_h_baselines(flow_params, args, multiagent):
404412

405413
def train_stable_baselines(submodule, flags):
406414
"""Train policies using the PPO algorithm in stable-baselines."""
415+
from stable_baselines.common.vec_env import DummyVecEnv
416+
from stable_baselines import PPO2
417+
407418
flow_params = submodule.flow_params
408419
# Path to the saved files
409420
exp_tag = flow_params['exp_tag']

0 commit comments

Comments (0)