Skip to content

Commit

Permalink
Export agent policy as TF SavedModel for Atari using policy saver.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 260949267
Change-Id: I9454b7cc15c1902a93188d9bd01873b3a0ff01d2
  • Loading branch information
TF-Agents Team authored and copybara-github committed Jul 31, 2019
1 parent 2951a8c commit 92756d8
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tf_agents/agents/dqn/examples/v1/train_eval_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from tf_agents.metrics import py_metrics
from tf_agents.networks import q_network
from tf_agents.policies import epsilon_greedy_policy
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_policy
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import py_hashed_replay_buffer
Expand Down Expand Up @@ -374,6 +375,9 @@ def __init__(
train_step=self._global_step,
step_metrics=(self._iteration_metric,))

self._train_dir = train_dir
self._policy_exporter = policy_saver.PolicySaver(
agent.policy, train_step=self._global_step)
self._train_checkpointer = common.Checkpointer(
ckpt_dir=train_dir,
agent=agent,
Expand Down Expand Up @@ -444,6 +448,12 @@ def run(self):
self._policy_checkpointer.save(global_step=global_step_val)
self._rb_checkpointer.save(global_step=global_step_val)

export_dir = os.path.join(self._train_dir, 'saved_policy',
'step_' + ('%d' % global_step_val).zfill(8))
self._policy_exporter.save(export_dir)
common.save_spec(self._collect_policy.trajectory_spec,
os.path.join(export_dir, 'trajectory_spec'))

def _initialize_graph(self, sess):
"""Initialize the graph for sess."""
self._train_checkpointer.initialize_or_restore(sess)
Expand Down

0 comments on commit 92756d8

Please sign in to comment.