Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
[rllib] Fix logging to Athena (ray-project#1058)
Browse files Browse the repository at this point in the history
* Fix logging to Athena

* fixes
  • Loading branch information
pcmoritz authored and richardliaw committed Oct 3, 2017
1 parent 1488975 commit b94d85f
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
5 changes: 3 additions & 2 deletions python/ray/rllib/a3c/a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ray
from ray.rllib.a3c.runner import RunnerThread, process_rollout
from ray.rllib.a3c.envs import create_env
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
from ray.rllib.a3c.shared_model import SharedModel
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM

Expand Down Expand Up @@ -73,8 +73,9 @@ def get_completed_rollout_metrics(self):
return completed

def start(self):
logdir = get_tensorflow_log_dir(self.logdir)
summary_writer = tf.summary.FileWriter(
os.path.join(self.logdir, "agent_%d" % self.id))
os.path.join(logdir, "agent_%d" % self.id))
self.summary_writer = summary_writer
self.runner.start_runner(self.policy.sess, summary_writer)

Expand Down
39 changes: 34 additions & 5 deletions python/ray/rllib/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,43 @@
logger.setLevel(logging.INFO)


def get_tensorflow_log_dir(logdir):
if logdir.startswith("s3"):
print("WARNING: TensorFlow logging to S3 not supported by"
"TensorFlow, logging to /tmp/ray/ instead")
logdir = "/tmp/ray/"
if not os.path.exists(logdir):
os.makedirs(logdir)
return logdir


class RLLibEncoder(json.JSONEncoder):

def __init__(self, nan_str="null", **kwargs):
super(RLLibEncoder, self).__init__(**kwargs)
self.nan_str = nan_str

def iterencode(self, o, _one_shot=False):
if self.ensure_ascii:
_encoder = json.encoder.encode_basestring_ascii
else:
_encoder = json.encoder.encode_basestring

def floatstr(o, allow_nan=self.allow_nan, nan_str=self.nan_str):
return repr(o) if not np.isnan(o) else nan_str

_iterencode = json.encoder._make_iterencode(
None, self.default, _encoder, self.indent, floatstr,
self.key_separator, self.item_separator, self.sort_keys,
self.skipkeys, _one_shot)
return _iterencode(o, 0)

def default(self, value):
if np.isnan(value):
return None
if np.issubdtype(value, float):
if np.isnan(value):
return None
else:
return float(value)
elif np.issubdtype(value, int):
return float(value)
if np.issubdtype(value, int):
return int(value)


Expand Down
5 changes: 3 additions & 2 deletions python/ray/rllib/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tensorflow.python import debug as tf_debug

import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.common import Agent, TrainingResult, get_tensorflow_log_dir
from ray.rllib.ppo.runner import Runner, RemoteRunner
from ray.rllib.ppo.rollout import collect_samples
from ray.rllib.ppo.utils import shuffle
Expand Down Expand Up @@ -99,8 +99,9 @@ def _init(self):
for _ in range(self.config["num_workers"])]
self.start_time = time.time()
if self.config["write_logs"]:
logdir = get_tensorflow_log_dir(self.logdir)
self.file_writer = tf.summary.FileWriter(
self.logdir, self.model.sess.graph)
logdir, self.model.sess.graph)
else:
self.file_writer = None
self.saver = tf.train.Saver(max_to_keep=None)
Expand Down

0 comments on commit b94d85f

Please sign in to comment.