Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a utility function to draw a computational graph #166

Merged
merged 18 commits into from
Dec 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions chainerrl/action_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty
import warnings

from cached_property import cached_property
import chainer
Expand Down Expand Up @@ -41,6 +42,15 @@ def evaluate_actions(self, actions):
"""Evaluate Q(s,a) with a = given actions."""
raise NotImplementedError()

@abstractproperty
def params(self):
    """Return the learnable parameters behind this action value.

    This is the abstract declaration; the body only raises, so concrete
    action values are expected to provide their own implementation.

    Returns:
        tuple of chainer.Variable
    """
    raise NotImplementedError


class DiscreteActionValue(ActionValue):
"""Q-function output for discrete action space.
Expand Down Expand Up @@ -95,6 +105,10 @@ def __repr__(self):
self.greedy_actions.data,
self.q_values_formatter(self.q_values.data))

@property
def params(self):
    """Learnable parameters of this action value.

    Returns:
        tuple: One-element tuple holding ``self.q_values``.
    """
    q_values = self.q_values
    return (q_values,)


class QuadraticActionValue(ActionValue):
"""Q-function output for continuous action space.
Expand Down Expand Up @@ -170,6 +184,10 @@ def __repr__(self):
return 'QuadraticActionValue greedy_actions:{} v:{}'.format(
self.greedy_actions.data, self.v.data)

@property
def params(self):
    """Learnable parameters of this action value.

    Returns:
        tuple: ``(self.mu, self.mat, self.v)``, in that order.
    """
    mu, mat, v = self.mu, self.mat, self.v
    return (mu, mat, v)


class SingleActionValue(ActionValue):
"""ActionValue that can evaluate only a single action."""
Expand Down Expand Up @@ -200,3 +218,12 @@ def compute_double_advantage(self, actions, argmax_actions):

def __repr__(self):
return 'SingleActionValue'

@property
def params(self):
    """Return an empty tuple after warning that nothing can be collected.

    The warning tells the caller to use the variable returned by a method
    such as ``evaluate_actions`` when drawing a computation graph.

    Returns:
        tuple: Always ``()``.
    """
    message = (
        'SingleActionValue has no learnable parameters until it'
        ' is evaluated on some action. If you want to draw a computation'
        ' graph that outputs SingleActionValue, use the variable returned'
        ' by its method such as evaluate_actions instead.')
    warnings.warn(message)
    return ()
1 change: 0 additions & 1 deletion chainerrl/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from chainerrl.experiments.hooks import LinearInterpolationHook # NOQA
from chainerrl.experiments.hooks import StepHook # NOQA

from chainerrl.experiments.prepare_output_dir import is_return_code_zero # NOQA
from chainerrl.experiments.prepare_output_dir import is_under_git_control # NOQA
from chainerrl.experiments.prepare_output_dir import prepare_output_dir # NOQA

Expand Down
19 changes: 2 additions & 17 deletions chainerrl/experiments/prepare_output_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,12 @@
import sys
import tempfile


def is_return_code_zero(args):
    """Return true iff the given command's return code is zero.

    All the messages to stdout or stderr are suppressed.

    Args:
        args (list of str): Command and its arguments, handed to
            ``subprocess.check_call`` as is.
    Returns:
        bool: True if the command ran and exited with status zero, False if
            it exited with a non-zero status or could not be started.
    """
    # Open the null device in a ``with`` block so the handle is closed on
    # every exit path instead of being leaked.
    with open(os.devnull, 'w') as fnull:
        try:
            subprocess.check_call(args, stdout=fnull, stderr=fnull)
        except subprocess.CalledProcessError:
            # The given command returned an error
            return False
        except OSError:
            # The given command was not found
            return False
    return True
import chainerrl


def is_under_git_control():
    """Return true iff the current directory is under git control.

    The check runs ``git rev-parse`` through the shared
    ``chainerrl.misc.is_return_code_zero`` helper and reports whether it
    succeeded.

    Returns:
        bool: Result of ``is_return_code_zero(['git', 'rev-parse'])``.
    """
    # Only one return statement is kept: the block previously contained two
    # consecutive returns, the second of which could never execute.
    return chainerrl.misc.is_return_code_zero(['git', 'rev-parse'])


def prepare_output_dir(args, user_specified_dir=None, argv=None,
Expand Down
4 changes: 4 additions & 0 deletions chainerrl/misc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from chainerrl.misc.batch_states import batch_states # NOQA
from chainerrl.misc.draw_computational_graph import collect_variables # NOQA
from chainerrl.misc.draw_computational_graph import draw_computational_graph # NOQA
from chainerrl.misc.draw_computational_graph import is_graphviz_available # NOQA
from chainerrl.misc import env_modifiers # NOQA
from chainerrl.misc.is_return_code_zero import is_return_code_zero # NOQA
from chainerrl.misc.random_seed import set_random_seed # NOQA
62 changes: 62 additions & 0 deletions chainerrl/misc/draw_computational_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases()

import subprocess

import chainer.computational_graph
import chainerrl


def collect_variables(obj):
    """Collect Variable objects inside a given object.

    Args:
        obj (object): Object to collect Variable objects from. Supported
            inputs are a Variable, an ActionValue, a Distribution, and
            arbitrarily nested lists/tuples of them.
    Returns:
        List of Variable objects, in the order they are encountered. An
        object of any other type contributes nothing, so the result is
        always a list (possibly empty).
    """
    if isinstance(obj, chainer.Variable):
        return [obj]
    if isinstance(obj, (chainerrl.action_value.ActionValue,
                        chainerrl.distribution.Distribution)):
        # Both kinds of object expose their Variables through ``params``.
        return list(obj.params)
    if isinstance(obj, (list, tuple)):
        variables = []
        for child in obj:
            variables.extend(collect_variables(child))
        return variables
    # Unsupported object: return an empty list explicitly so that the
    # recursive ``extend`` above never receives a non-list value.
    return []


def is_graphviz_available():
    """Return true iff Graphviz's ``dot`` command can be run.

    Availability is decided by whether ``dot -V`` succeeds according to
    ``chainerrl.misc.is_return_code_zero``.
    """
    version_command = ['dot', '-V']
    return chainerrl.misc.is_return_code_zero(version_command)


def draw_computational_graph(outputs, filepath):
    """Draw a computational graph and write it to a given file.

    Args:
        outputs (object): Output(s) of the computational graph. It must be
            a Variable, an ActionValue, a Distribution or a list of them.
        filepath (str): Filepath to write a graph to, without file
            extension. A DOT file is saved with ".gv" appended. If
            Graphviz's dot command is available, a PNG file is also saved
            with ".png" appended.
    """
    graph = chainer.computational_graph.build_computational_graph(
        collect_variables(outputs))
    dot_path = filepath + '.gv'
    with open(dot_path, 'w') as dot_file:
        # future.builtins.str is required to make sure the content is
        # unicode in both py2 and py3
        dot_source = str(graph.dump())
        dot_file.write(dot_source)
    if not is_graphviz_available():
        # Without the dot command only the DOT source is produced.
        return
    image_path = filepath + '.png'
    subprocess.check_call(['dot', '-Tpng', dot_path, '-o', image_path])
26 changes: 26 additions & 0 deletions chainerrl/misc/is_return_code_zero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_aliases()

import os
import subprocess


def is_return_code_zero(args):
    """Tell whether running the given command yields a zero return code.

    Both stdout and stderr of the command are redirected to the null
    device, so nothing it prints is shown.

    Args:
        args (list of str): Command and its arguments, handed to
            ``subprocess.check_call`` as is.
    Returns:
        bool: True iff the command's return code is zero.
    """
    # CalledProcessError: the given command returned an error.
    # OSError: the given command was not found.
    failures = (subprocess.CalledProcessError, OSError)
    with open(os.devnull, 'wb') as devnull:
        try:
            subprocess.check_call(args, stdout=devnull, stderr=devnull)
        except failures:
            return False
        return True
6 changes: 6 additions & 0 deletions examples/gym/train_reinforce_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from future import standard_library
standard_library.install_aliases()
import argparse
import os

import chainer
import gym
Expand Down Expand Up @@ -98,6 +99,11 @@ def make_env(test):
nonlinearity=chainer.functions.leaky_relu,
)

# Draw the computational graph and save it in the output directory.
chainerrl.misc.draw_computational_graph(
[model(np.zeros_like(obs_space.low, dtype=np.float32)[None])],
os.path.join(args.outdir, 'model'))

if args.gpu >= 0:
chainer.cuda.get_device(args.gpu).use()
model.to_gpu(args.gpu)
Expand Down
11 changes: 0 additions & 11 deletions tests/experiments_tests/test_prepare_output_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@ def work_dir(dirname):
os.chdir(orig_dir)


class TestIsReturnCodeZero(unittest.TestCase):
    """Checks is_return_code_zero on success, failure and a missing command."""

    def test(self):
        # Assume ls command exists
        self.assertTrue(chainerrl.experiments.is_return_code_zero(['ls']))
        # The option must be its own argv element; as a single string
        # 'ls --nonexistentoption' would be looked up as a command name and
        # only exercise the "command not found" path again.
        self.assertFalse(chainerrl.experiments.is_return_code_zero(
            ['ls', '--nonexistentoption']))
        self.assertFalse(chainerrl.experiments.is_return_code_zero(
            ['nonexistentcommand']))


class TestIsUnderGitControl(unittest.TestCase):

def test(self):
Expand Down
92 changes: 92 additions & 0 deletions tests/misc_tests/test_draw_computational_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_aliases()

import os
import tempfile
import unittest

import chainer
from chainer import testing
import numpy as np

import chainerrl


# Module-level fixtures shared by the parameterized cases of
# TestCollectVariables below.

# A bare Variable wrapping a length-5 zero array.
_v = chainer.Variable(np.zeros(5))
# Discrete action value built from a (5, 5) zero array.
_dav = chainerrl.action_value.DiscreteActionValue(
chainer.Variable(np.zeros((5, 5))))
# Quadratic action value built from three float32 Variables of shapes
# (5, 5), (5, 5, 5) and (5, 1).
# NOTE(review): presumably these are mu, mat and v in that order -- confirm
# against the QuadraticActionValue constructor.
_qav = chainerrl.action_value.QuadraticActionValue(
chainer.Variable(np.zeros((5, 5), dtype=np.float32)),
chainer.Variable(np.ones((5, 5, 5), dtype=np.float32)),
chainer.Variable(np.zeros((5, 1), dtype=np.float32)),
)
# Softmax distribution built from a (5, 5) zero array.
_sdis = chainerrl.distribution.SoftmaxDistribution(
chainer.Variable(np.zeros((5, 5))))
# Gaussian distribution built from a zero array and a ones array, both
# float32 with shape (5, 5).
# NOTE(review): presumably mean and variance -- confirm against the
# GaussianDistribution constructor.
_gdis = chainerrl.distribution.GaussianDistribution(
chainer.Variable(np.zeros((5, 5), dtype=np.float32)),
chainer.Variable(np.ones((5, 5), dtype=np.float32)))


@testing.parameterize(
    {'obj': [], 'expected': []},
    {'obj': (), 'expected': []},
    {'obj': _v, 'expected': [_v]},
    {'obj': _dav, 'expected': list(_dav.params)},
    {'obj': _qav, 'expected': list(_qav.params)},
    {'obj': _sdis, 'expected': list(_sdis.params)},
    {'obj': _gdis, 'expected': list(_gdis.params)},
    {'obj': [_v, _dav, _sdis],
     'expected': [_v] + list(_dav.params) + list(_sdis.params)},
)
class TestCollectVariables(unittest.TestCase):
    """Checks collect_variables on each parameterized ``obj``.

    Every case is tried bare and wrapped in one or two levels of list or
    tuple; the collected Variables must be the same objects, in the same
    order, as ``expected``.
    """

    def _assert_eq_var_list(self, a, b):
        """Assert that two lists hold the very same Variable objects."""
        self.assertEqual(len(a), len(b))
        for candidate in (a, b):
            self.assertIsInstance(candidate, list)
        for candidate in (a, b):
            for item in candidate:
                self.assertIsInstance(item, chainer.Variable)
        # Compare by identity, not by value.
        for left, right in zip(a, b):
            self.assertIs(left, right)

    def test_collect_variables(self):
        wrapped_inputs = (
            self.obj,         # as is
            [self.obj],       # wrapped by a list
            [[self.obj]],     # wrapped by two lists
            (self.obj,),      # wrapped by a tuple
            ((self.obj,),),   # wrapped by two tuples
        )
        for wrapped in wrapped_inputs:
            collected = chainerrl.misc.collect_variables(wrapped)
            self._assert_eq_var_list(collected, self.expected)


class TestDrawComputationalGraph(unittest.TestCase):
    """Checks which files draw_computational_graph writes."""

    def test_draw_computational_graph(self):
        # Imported locally so this block does not depend on a change to the
        # module's import section.
        import shutil

        x = chainer.Variable(np.zeros(5))
        y = x ** 2 + chainer.Variable(np.ones(5))
        dirname = tempfile.mkdtemp()
        # Remove the temporary directory even if an assertion below fails,
        # so test runs do not leave directories behind.
        self.addCleanup(shutil.rmtree, dirname, ignore_errors=True)
        filepath = os.path.join(dirname, 'graph')
        chainerrl.misc.draw_computational_graph(y, filepath)
        # The DOT file is always written.
        self.assertTrue(os.path.exists(filepath + '.gv'))
        # The PNG file is written only when Graphviz is available.
        if chainerrl.misc.is_graphviz_available():
            self.assertTrue(os.path.exists(filepath + '.png'))
        else:
            self.assertFalse(os.path.exists(filepath + '.png'))
21 changes: 21 additions & 0 deletions tests/misc_tests/test_is_return_code_zero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from future import standard_library
standard_library.install_aliases()

import unittest

import chainerrl


class TestIsReturnCodeZero(unittest.TestCase):
    """Checks is_return_code_zero on success, failure and a missing command."""

    def test(self):
        # Assume ls command exists
        self.assertTrue(chainerrl.misc.is_return_code_zero(['ls']))
        # The option must be its own argv element; as a single string
        # 'ls --nonexistentoption' would be looked up as a command name and
        # only exercise the "command not found" path again.
        self.assertFalse(chainerrl.misc.is_return_code_zero(
            ['ls', '--nonexistentoption']))
        self.assertFalse(chainerrl.misc.is_return_code_zero(
            ['nonexistentcommand']))
4 changes: 4 additions & 0 deletions tests/test_action_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def test_compute_advantage(self):
adv = q - v
self.assertAlmostEqual(ret.data[b], adv)

def test_params(self):
    """``params`` holds exactly one item: the ``q_values`` object itself."""
    params = self.qout.params
    self.assertEqual(len(params), 1)
    self.assertIs(params[0], self.qout.q_values)


class TestQuadraticActionValue(unittest.TestCase):
def test_max_unbounded(self):
Expand Down