Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Add more detailed Documentation on Model building API #13261

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
61b1423
WIP.
sven1977 Dec 29, 2020
0b1eb57
WIP.
sven1977 Dec 29, 2020
ce180fd
wip
sven1977 Dec 29, 2020
c6cfe3b
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 2, 2021
1254e16
WIP.
sven1977 Jan 4, 2021
4136afe
WIP.
sven1977 Jan 4, 2021
6af7eb9
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 5, 2021
4e61f5d
WIP.
sven1977 Jan 5, 2021
ec2d010
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 5, 2021
98521e3
WIP.
sven1977 Jan 6, 2021
01fa66f
WIP.
sven1977 Jan 6, 2021
8b5d280
WIP.
sven1977 Jan 6, 2021
51ef2c6
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 7, 2021
7335055
Merge branch 'master' of https://github.com/ray-project/ray
sven1977 Jan 7, 2021
5819e2f
WIP.
sven1977 Jan 7, 2021
54d3703
WIP.
sven1977 Jan 7, 2021
c372a54
WIP.
sven1977 Jan 7, 2021
0cd09c1
Merge branch 'documentation_model_building_prep_01' into documentatio…
sven1977 Jan 7, 2021
389df5b
WIP.
sven1977 Jan 7, 2021
348b012
WIP.
sven1977 Jan 7, 2021
98706bd
WIP.
sven1977 Jan 7, 2021
b190d8e
Merge branch 'documentation_model_building_prep_01' into documentatio…
sven1977 Jan 7, 2021
1b126a9
WIP.
sven1977 Jan 7, 2021
8536fe7
Fixes and LINT.
sven1977 Jan 8, 2021
2d836ac
Merge branch 'documentation_model_building_prep_01' into documentatio…
sven1977 Jan 8, 2021
241f17d
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 8, 2021
379a791
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 8, 2021
0b7bc77
Fix.
sven1977 Jan 8, 2021
5310292
Merge branch 'documentation_model_building_prep_01' into documentatio…
sven1977 Jan 8, 2021
60df06c
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 8, 2021
b9e59f7
WIP.
sven1977 Jan 8, 2021
cda9fff
WIP.
sven1977 Jan 9, 2021
5c1e044
Merge branch 'master' of https://github.com/ray-project/ray into docu…
sven1977 Jan 9, 2021
fadf6d4
LINT.
sven1977 Jan 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP.
  • Loading branch information
sven1977 committed Jan 7, 2021
commit 5819e2fd798a816bbfe4064ab54c01d6aab212b7
104 changes: 104 additions & 0 deletions rllib/examples/custom_model_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Example showing how to wrap RLlib default models with a custom model API.

Builds two models via ``ModelCatalog.get_model_v2`` using the
``model_interface`` arg:
1) a dueling Q-head model (``get_q_values`` on top of ``forward``), and
2) a continuous-action Q-model (``get_single_q_value`` from obs + action).
Then runs a forward pass through each and sanity-checks the output shapes.
"""
import argparse
from gym.spaces import Box, Discrete
import numpy as np

from ray.rllib.examples.models.custom_model_api import DuelingQModel, \
    TorchDuelingQModel, ContActionQModel, TorchContActionQModel
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")


if __name__ == "__main__":
    args = parser.parse_args()

    # Test API wrapper for dueling Q-head.

    obs_space = Box(-1.0, 1.0, (3, ))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    # Only needed (and only safe) for the tf frameworks; with
    # `--framework=torch`, tf may not even be installed.
    if args.framework != "torch":
        tf1.enable_eager_execution()

    # __sphinx_doc_model_construct_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel if args.framework != "torch"
        else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    out, state_outs = my_dueling_model(input_dict=input_dict)
    # 256 = default last hidden layer size of the wrapped FCNet.
    assert out.shape == (10, 256)
    # Pass `out` into `get_q_values`
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (10, action_space.n)

    # Test API wrapper for single value Q-head from obs/action input.

    obs_space = Box(-1.0, 1.0, (3, ))
    # NOTE: was `Box(-1.0, -1.0, (2, ))` (low == high == -1.0) — a typo
    # that made every sampled action the constant -1.0.
    action_space = Box(-1.0, 1.0, (2, ))

    # Use a distinct marker pair here: Sphinx `start-after`/`end-before`
    # always match the first occurrence, so reusing the tags above would
    # make this snippet unaddressable from the docs.
    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel if args.framework != "torch"
        else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])

    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = {
        "obs": input_,
        "is_training": False,
    }
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (10, 256)
    # Pass `out` and an action into the model's Q-head.
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)

    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (10, 1)
41 changes: 0 additions & 41 deletions rllib/examples/custom_model_api_tf.py

This file was deleted.

45 changes: 0 additions & 45 deletions rllib/examples/custom_model_api_torch.py

This file was deleted.

20 changes: 16 additions & 4 deletions rllib/examples/models/custom_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import TorchFullyConnectedNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as \
TorchFullyConnectedNetwork
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
Expand Down Expand Up @@ -102,15 +103,20 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,

# Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
# to be used for Q-value calculation.
# Use the current value of self.num_outputs, which is the wrapped
# model's output layer size.
combined_space = Box(
-1.0, 1.0, (self.num_outputs + action_space.shape[0], ))
self.q_head = FullyConnectedNetwork(
combined_space, action_space, 1, model_config, "q_head")

# Missing here: Probably still have to provide action output layer
# and value layer and make sure self.num_outputs is correctly set.

def get_single_q_value(self, underlying_output, action):
# Calculate the q-value after concating the underlying output with
# the given action.
input_ = torch.cat([underlying_output, action], dim=-1)
input_ = tf.concat([underlying_output, action], axis=-1)
# Construct a simple input_dict (needed for self.q_head as it's an
# RLlib ModelV2).
input_dict = {"obs": input_}
Expand All @@ -120,30 +126,36 @@ def get_single_q_value(self, underlying_output, action):


# __sphinx_doc_model_api_torch_start__
class TorchContActionQModel(TorchModelV2): # or: TFModelV2
class TorchContActionQModel(TorchModelV2):
"""A simple, q-value-from-cont-action model (for e.g. SAC type algos)."""

def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
nn.Module.__init__(self)
# Pass num_outputs=None into super constructor (so that no action/
# logits output layer is built).
# Alternatively, you can pass in num_outputs=[last layer size of
# config[model][fcnet_hiddens]] AND set no_last_linear=True, but
# this seems more tedious as you will have to explain users of this
# class that num_outputs is NOT the size of your Q-output layer.
super(ContActionQModel, self).__init__(
super(TorchContActionQModel, self).__init__(
obs_space, action_space, None, model_config, name)

# Now: self.num_outputs contains the last layer's size, which
# we can use to construct the single q-value computing head.

# Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
# to be used for Q-value calculation.
# Use the current value of self.num_outputs, which is the wrapped
# model's output layer size.
combined_space = Box(
-1.0, 1.0, (self.num_outputs + action_space.shape[0], ))
self.q_head = TorchFullyConnectedNetwork(
combined_space, action_space, 1, model_config, "q_head")

# Missing here: Probably still have to provide action output layer
# and value layer and make sure self.num_outputs is correctly set.

def get_single_q_value(self, underlying_output, action):
# Calculate the q-value after concating the underlying output with
# the given action.
Expand Down