[RLlib] Bump tf version in ML docker to tf==2.5.0; add tfp to ML-docker. (#18544)
sven1977 authored Sep 15, 2021
1 parent c5d2084 commit 8a00154
Showing 8 changed files with 30 additions and 12 deletions.
3 changes: 3 additions & 0 deletions docker/ray-ml/Dockerfile
@@ -29,5 +29,8 @@ RUN sudo apt-get update \
     && sudo rm requirements_tune.txt && sudo rm requirements_rllib.txt \
     && sudo apt-get clean
 
+# Make sure tfp is installed correctly and matches tf version.
+RUN python -c "import tensorflow_probability"
+
 # Install Atari ROMs. Previously these have been shipped with atari_py \
 RUN ./install_atari_roms.sh
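
The RUN check above only proves that tensorflow_probability imports. A slightly stronger sanity check, as a sketch that is not part of this commit (the EXPECTED_* pins are assumptions copied from the requirements files), would also assert that the installed tf/tfp pair matches the pins:

# Hypothetical sanity check: verify the pinned tf/tfp pair actually landed.
import tensorflow as tf
import tensorflow_probability as tfp

EXPECTED_TF = "2.5.0"    # pin from requirements_rllib.txt / requirements_ml_docker.txt
EXPECTED_TFP = "0.13.0"  # pin from the same files

assert tf.__version__ == EXPECTED_TF, \
    f"Expected tensorflow=={EXPECTED_TF}, found {tf.__version__}"
assert tfp.__version__ == EXPECTED_TFP, \
    f"Expected tensorflow-probability=={EXPECTED_TFP}, found {tfp.__version__}"

# Touching a distribution forces tfp to initialize fully against this TF.
_ = tfp.distributions.Normal(loc=0.0, scale=1.0)
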
5 changes: 3 additions & 2 deletions python/requirements/rllib/requirements_rllib.txt
@@ -1,7 +1,8 @@
 # Deep learning.
 # --------------
-tensorflow==2.4.3
-tensorflow-probability==0.12.2
+tensorflow==2.5.0
+tensorflow-probability==0.13.0
 
 torch==1.8.1;sys_platform=="darwin"
 torchvision==0.9.1;sys_platform=="darwin"
+
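
These two pins move in lockstep because each tensorflow-probability release targets one TF minor release (0.12.x pairs with TF 2.4, 0.13.x with TF 2.5), and mixing them is a common source of import-time failures. A minimal stdlib-only sketch for checking an environment against these pins (the PINS dict is an assumption mirroring the file above):

# Sketch: compare installed versions against the pins (Python >= 3.8).
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "tensorflow": "2.5.0",
    "tensorflow-probability": "0.13.0",
}

for pkg, pinned in PINS.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed (pinned {pinned})")
        continue
    status = "ok" if installed == pinned else f"MISMATCH, pinned {pinned}"
    print(f"{pkg}: {installed} ({status})")
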
1 change: 0 additions & 1 deletion python/requirements/tune/requirements_tune.txt
@@ -33,7 +33,6 @@
 scikit-learn==0.24.2
 scikit-optimize==0.8.1
 sigopt==7.5.0
 smart_open==5.1.0
-tensorflow-probability==0.13.0
 timm==0.4.5
 
5 changes: 4 additions & 1 deletion python/requirements_ml_docker.txt
@@ -1,6 +1,9 @@
 ipython
 
 # In TF >v2, GPU support is included in the base package.
-tensorflow==2.4.3
+tensorflow==2.5.0
+tensorflow-probability==0.13.0
 
 -f https://download.pytorch.org/whl/torch_stable.html
 torch==1.8.1+cu111
 -f https://download.pytorch.org/whl/torch_stable.html
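
The comment here explains why no separate tensorflow-gpu pin appears in the GPU docker image: since TF 2.0 the plain tensorflow wheel ships with CUDA support. A quick sanity check for the built image, as a sketch that is not part of the commit:

# Sketch: confirm the base tensorflow package was built with CUDA and sees a GPU.
import tensorflow as tf

print("TF version:", tf.__version__)
print("Built with CUDA:", tf.test.is_built_with_cuda())
print("Visible GPUs:", tf.config.list_physical_devices("GPU") or "none")
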
10 changes: 8 additions & 2 deletions rllib/agents/cql/cql.py
@@ -16,8 +16,12 @@
 from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils import merge_dicts
+from ray.rllib.utils.framework import try_import_tfp
 from ray.rllib.utils.typing import TrainerConfigDict
 
+tfp = try_import_tfp()
+replay_buffer = None
+
 # yapf: disable
 # __sphinx_doc_begin__
 CQL_DEFAULT_CONFIG = merge_dicts(
@@ -57,8 +61,10 @@ def validate_config(config: TrainerConfigDict):
             config["framework"] == "torch":
         config["simple_optimizer"] = True
 
-
-    replay_buffer = None
+    if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
+        raise ModuleNotFoundError(
+            "You need `tensorflow_probability` in order to run CQL with tf! "
+            "Install it via `pip install tensorflow_probability`.")
 
 
 def execution_plan(workers, config):
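
Both cql.py and sac.py lean on try_import_tfp returning None instead of raising when tensorflow_probability is absent; that is what lets validate_config turn a missing dependency into the readable error above. A minimal sketch of that soft-import pattern (the real helper lives in ray.rllib.utils.framework and may differ in details):

# Minimal sketch of the soft-import pattern behind try_import_tfp().
def try_import_tfp(error: bool = False):
    """Return the tensorflow_probability module, or None if not installed."""
    try:
        import tensorflow_probability as tfp
        return tfp
    except ImportError:
        if error:
            raise
        return None

tfp = try_import_tfp()
if tfp is None:
    print("tensorflow_probability missing; tf-based CQL will refuse to run.")
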
6 changes: 3 additions & 3 deletions rllib/agents/dqn/tests/test_apex_dqn.py
@@ -107,17 +107,17 @@ def _step_n_times(trainer, n: int):
         lr = _step_n_times(trainer, 5)  # 50 timesteps
         # PiecewiseSchedule does interpolation. So roughly 0.1 here.
         self.assertLessEqual(lr, 0.15)
-        self.assertGreaterEqual(lr, 0.05)
+        self.assertGreaterEqual(lr, 0.04)
 
         lr = _step_n_times(trainer, 5)  # 100 timesteps
         # PiecewiseSchedule does interpolation. So roughly 0.01 here.
         self.assertLessEqual(lr, 0.02)
-        self.assertGreaterEqual(lr, 0.005)
+        self.assertGreaterEqual(lr, 0.004)
 
         lr = _step_n_times(trainer, 5)  # 150 timesteps
         # PiecewiseSchedule does interpolation. So roughly 0.001 here.
         self.assertLessEqual(lr, 0.002)
-        self.assertGreaterEqual(lr, 0.0005)
+        self.assertGreaterEqual(lr, 0.0004)
 
         trainer.stop()
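
The loosened lower bounds (0.05 to 0.04, and so on) make room for PiecewiseSchedule's linear interpolation: a measurement taken even a few timesteps past a breakpoint has already moved part of the way toward the next, smaller value, so the asserted window needs slack. A self-contained sketch of that interpolation (the endpoints are hypothetical, chosen only to match the comments in the test):

# Sketch of piecewise-linear interpolation as PiecewiseSchedule applies it
# between breakpoints. Endpoints are hypothetical, picked to reproduce the
# test's comments (~0.1 at 50 steps, ~0.01 at 100, ~0.001 at 150).
def piecewise_lr(t, endpoints):
    for (t0, v0), (t1, v1) in zip(endpoints, endpoints[1:]):
        if t0 <= t < t1:
            alpha = (t - t0) / (t1 - t0)
            return v0 + alpha * (v1 - v0)
    return endpoints[-1][1]  # clamp past the last breakpoint

endpoints = [(0, 1.0), (50, 0.1), (100, 0.01), (150, 0.001)]
print(piecewise_lr(50, endpoints))  # 0.1 exactly at the breakpoint
print(piecewise_lr(55, endpoints))  # 0.091, already below 0.1 shortly after
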
8 changes: 8 additions & 0 deletions rllib/agents/sac/sac.py
@@ -17,8 +17,11 @@
 from ray.rllib.agents.sac.sac_tf_policy import SACTFPolicy
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
+from ray.rllib.utils.framework import try_import_tfp
 from ray.rllib.utils.typing import TrainerConfigDict
 
+tfp = try_import_tfp()
+
 logger = logging.getLogger(__name__)
 
 OPTIMIZER_SHARED_CONFIGS = [
@@ -190,6 +193,11 @@ def validate_config(config: TrainerConfigDict) -> None:
     if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
         raise ValueError("`grad_clip` value must be > 0.0!")
 
+    if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
+        raise ModuleNotFoundError(
+            "You need `tensorflow_probability` in order to run SAC! "
+            "Install it via `pip install tensorflow_probability`.")
+
 
 def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
     """Policy class picker function. Class is chosen based on DL-framework.
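
For users, the new guard means SAC fails fast during trainer construction instead of deep inside TF policy building. A hedged sketch of the user-visible behavior, using the RLlib 1.x entry points of this era (SACTrainer and gym's Pendulum-v0 are illustrative; adjust to your setup):

# Sketch (RLlib 1.x API): the guard surfaces before any TF graph is built.
import ray
from ray.rllib.agents.sac import SACTrainer

ray.init()
try:
    trainer = SACTrainer(env="Pendulum-v0", config={"framework": "tf"})
except ModuleNotFoundError as e:
    # Raised by validate_config() when tensorflow_probability is missing.
    print(e)
else:
    trainer.stop()
ray.shutdown()
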
4 changes: 1 addition & 3 deletions rllib/agents/sac/sac_tf_policy.py
@@ -26,15 +26,13 @@
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.error import UnsupportedSpaceException
-from ray.rllib.utils.framework import get_variable, try_import_tf, \
-    try_import_tfp
+from ray.rllib.utils.framework import get_variable, try_import_tf
 from ray.rllib.utils.spaces.simplex import Simplex
 from ray.rllib.utils.tf_ops import huber_loss
 from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
     TensorType, TrainerConfigDict
 
 tf1, tf, tfv = try_import_tf()
-tfp = try_import_tfp()
 
 logger = logging.getLogger(__name__)
