Skip to content

Commit

Permalink
[RLlib] Make JSONReader default, users will have to use the DatasetRe…
Browse files Browse the repository at this point in the history
…ader for any speedups. (ray-project#26541)
  • Loading branch information
avnishn authored Jul 14, 2022
1 parent c168c09 commit a322ac4
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 57 deletions.
50 changes: 0 additions & 50 deletions rllib/evaluation/worker_set.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from pathlib import Path
import re

import gym
import logging
import importlib.util
Expand Down Expand Up @@ -108,53 +105,6 @@ def __init__(
self._local_worker = None
if num_workers == 0:
local_worker = True
if (
(
isinstance(trainer_config["input"], str)
or isinstance(trainer_config["input"], list)
)
and ("d4rl" not in trainer_config["input"])
and (not "sampler" == trainer_config["input"])
and (not "dataset" == trainer_config["input"])
and (
not (
isinstance(trainer_config["input"], str)
and registry_contains_input(trainer_config["input"])
)
)
and (
not (
isinstance(trainer_config["input"], str)
and self._valid_module(trainer_config["input"])
)
)
):
paths = trainer_config["input"]
if isinstance(paths, str):
inputs = Path(paths).absolute()
if inputs.is_dir():
paths = list(inputs.glob("*.json")) + list(inputs.glob("*.zip"))
paths = [str(path) for path in paths]
else:
paths = [paths]
ends_with_zip_or_json = all(
re.search("\\.zip$", path) or re.search("\\.json$", path)
for path in paths
)
ends_with_parquet = all(
re.search("\\.parquet$", path) for path in paths
)
trainer_config["input"] = "dataset"
input_config = {"paths": paths}
if ends_with_zip_or_json:
input_config["format"] = "json"
elif ends_with_parquet:
input_config["format"] = "parquet"
else:
raise ValueError(
"Input path must end with .zip, .parquet, or .json"
)
trainer_config["input_config"] = input_config
self._local_config = merge_dicts(
trainer_config,
{"tf_session_args": trainer_config["local_tf_session_args"]},
Expand Down
2 changes: 1 addition & 1 deletion rllib/tests/run_regression_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
required=True,
help="The directory in which to find all yamls to test.",
)
parser.add_argument("--num-cpus", type=int, default=6)
parser.add_argument("--num-cpus", type=int, default=8)
parser.add_argument(
"--local-mode",
action="store_true",
Expand Down
5 changes: 4 additions & 1 deletion rllib/tuned_examples/cql/pendulum-cql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ pendulum-cql:
framework: tf

# Use one or more offline files or "input: sampler" for online learning.
input: ["tests/data/pendulum/enormous.zip"]
input: 'dataset'
input_config:
paths: ["tests/data/pendulum/enormous.zip"]
format: 'json'
# Our input file above comes from an SAC run. Actions in there
# are already normalized (produced by SquashedGaussian).
actions_in_input_normalized: true
Expand Down
6 changes: 4 additions & 2 deletions rllib/tuned_examples/crr/cartpole-v0-crr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ cartpole_crr:
evaluation/episode_reward_mean: 200
training_iteration: 100
config:
input:
- 'tests/data/cartpole/large.json'
input: 'dataset'
input_config:
paths: 'tests/data/cartpole/large.json'
format: 'json'
num_workers: 3
framework: torch
gamma: 0.99
Expand Down
6 changes: 4 additions & 2 deletions rllib/tuned_examples/crr/cartpole-v0-crr_expectation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ cartpole_crr:
evaluation/episode_reward_mean: 200
training_iteration: 100
config:
input:
- 'tests/data/cartpole/large.json'
input: 'dataset'
input_config:
paths: 'tests/data/cartpole/large.json'
format: 'json'
framework: torch
num_workers: 3
gamma: 0.99
Expand Down
5 changes: 4 additions & 1 deletion rllib/tuned_examples/crr/pendulum-v1-crr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ pendulum_crr:
evaluation/episode_reward_mean: -300
timesteps_total: 2000000
config:
input: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
input: 'dataset'
input_config:
paths: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
format: 'json'
framework: torch
num_workers: 3
gamma: 0.99
Expand Down

0 comments on commit a322ac4

Please sign in to comment.