Skip to content

Commit

Permalink
[rllib] Make observation filter optional (ray-project#940)
Browse files Browse the repository at this point in the history
* make observation filter optional

* fix linting
  • Loading branch information
pcmoritz authored and robertnishihara committed Sep 15, 2017
1 parent 413140d commit 6601bb5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/ray/rllib/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
"kl_target": 0.01,
# Config params to pass to the model
"model": {"free_log_std": False},
+ # Which observation filter to apply to the observation
+ "observation_filter": "MeanStdFilter",
# If >1, adds frameskip
"extra_frameskip": 1,
# Number of timesteps collected in each outer loop
Expand Down
12 changes: 9 additions & 3 deletions python/ray/rllib/ppo/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ray.rllib.models import ModelCatalog
from ray.rllib.ppo.env import BatchedEnv
from ray.rllib.ppo.loss import ProximalPolicyLoss
- from ray.rllib.ppo.filter import MeanStdFilter
+ from ray.rllib.ppo.filter import NoFilter, MeanStdFilter
from ray.rllib.ppo.rollout import (
rollouts, add_return_values, add_advantage_values)
from ray.rllib.ppo.utils import flatten, concatenate
Expand Down Expand Up @@ -140,8 +140,14 @@ def build_loss(obs, rets, advs, acts, plog, pvf_preds):
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss, self.sess)
- self.observation_filter = MeanStdFilter(
- self.preprocessor_shape, clip=None)
+ if config["observation_filter"] == "MeanStdFilter":
+ self.observation_filter = MeanStdFilter(
+ self.preprocessor_shape, clip=None)
+ elif config["observation_filter"] == "NoFilter":
+ self.observation_filter = NoFilter()
+ else:
+ raise Exception("Unknown observation_filter: " +
+ str(config["observation_filter"]))
self.reward_filter = MeanStdFilter((), clip=5.0)
self.sess.run(tf.global_variables_initializer())

Expand Down

0 comments on commit 6601bb5

Please sign in to comment.