Skip to content

Commit

Permalink
Separate train and eval environments for ddpg and sac.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 259407411
Change-Id: Ia9d3215d00e908613390b4cb0230df32e8250bc2
  • Loading branch information
TF-Agents Team authored and copybara-github committed Jul 22, 2019
1 parent 80a2172 commit 11a875c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
4 changes: 3 additions & 1 deletion tf_agents/agents/ddpg/examples/v1/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
def train_eval(
root_dir,
env_name='HalfCheetah-v2',
eval_env_name=None,
env_load_fn=suite_mujoco.load,
num_iterations=2000000,
actor_fc_layers=(400, 300),
Expand Down Expand Up @@ -132,7 +133,8 @@ def train_eval(
[lambda: env_load_fn(env_name)] * num_parallel_environments))
else:
tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
eval_py_env = env_load_fn(env_name)
eval_env_name = eval_env_name or env_name
eval_py_env = env_load_fn(eval_env_name)

actor_net = actor_network.ActorNetwork(
tf_env.time_step_spec().observation,
Expand Down
4 changes: 3 additions & 1 deletion tf_agents/agents/ddpg/examples/v2/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
def train_eval(
root_dir,
env_name='HalfCheetah-v2',
eval_env_name=None,
env_load_fn=suite_mujoco.load,
num_iterations=2000000,
actor_fc_layers=(400, 300),
Expand Down Expand Up @@ -132,7 +133,8 @@ def train_eval(
[lambda: env_load_fn(env_name)] * num_parallel_environments))
else:
tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
eval_env_name = eval_env_name or env_name
eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(eval_env_name))

actor_net = actor_network.ActorNetwork(
tf_env.time_step_spec().observation,
Expand Down
4 changes: 3 additions & 1 deletion tf_agents/agents/sac/examples/v1/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def normal_projection_net(action_spec,
def train_eval(
root_dir,
env_name='HalfCheetah-v2',
eval_env_name=None,
env_load_fn=suite_mujoco.load,
num_iterations=1000000,
actor_fc_layers=(256, 256),
Expand Down Expand Up @@ -144,7 +145,8 @@ def train_eval(
lambda: tf.math.equal(global_step % summary_interval, 0)):
# Create the environment.
tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
eval_py_env = env_load_fn(env_name)
eval_env_name = eval_env_name or env_name
eval_py_env = env_load_fn(eval_env_name)

# Get the data specs from the environment
time_step_spec = tf_env.time_step_spec()
Expand Down
4 changes: 3 additions & 1 deletion tf_agents/agents/sac/examples/v2/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def normal_projection_net(action_spec,
def train_eval(
root_dir,
env_name='HalfCheetah-v2',
eval_env_name=None,
env_load_fn=suite_mujoco.load,
num_iterations=1000000,
actor_fc_layers=(256, 256),
Expand Down Expand Up @@ -142,7 +143,8 @@ def train_eval(
with tf.compat.v2.summary.record_if(
lambda: tf.math.equal(global_step % summary_interval, 0)):
tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
eval_env_name = eval_env_name or env_name
eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(eval_env_name))

time_step_spec = tf_env.time_step_spec()
observation_spec = time_step_spec.observation
Expand Down

0 comments on commit 11a875c

Please sign in to comment.