# custom_model_api.py

import argparse

from gym.spaces import Box, Discrete
import numpy as np

from ray.rllib.examples.models.custom_model_api import (
    DuelingQModel,
    TorchDuelingQModel,
    ContActionQModel,
    TorchContActionQModel,
)
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)

if __name__ == "__main__":
    args = parser.parse_args()

    # Test API wrapper for dueling Q-head.
    obs_space = Box(-1.0, 1.0, (3,))
    action_space = Discrete(3)

    # Run in eager mode for value checking and debugging.
    # Guard against TF not being installed (e.g. in a torch-only setup).
    if tf1 is not None:
        tf1.enable_eager_execution()

    # __sphinx_doc_model_construct_1_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch"
        else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_1_end__
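
    # The factory returns an instance of the requested interface class, so
    # the extra Q-head method is guaranteed to be present on the model.
    # A quick sanity check (not part of the original walkthrough):
    assert isinstance(
        my_dueling_model,
        DuelingQModel if args.framework != "torch" else TorchDuelingQModel,
    )
    assert hasattr(my_dueling_model, "get_q_values")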

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_dueling_model(input_dict=input_dict)
    assert out.shape == (batch_size, 256)
    # Pass `out` into `get_q_values`.
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (batch_size, action_space.n)
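
    # A hedged sketch (not in the original example): the dueling head yields
    # per-action Q-values, so greedy action selection is a plain argmax over
    # the last axis. `detach().numpy()` is only needed on the torch side.
    q_np = (
        q_values.detach().numpy()
        if args.framework == "torch"
        else np.asarray(q_values)
    )
    greedy_actions = np.argmax(q_np, axis=-1)
    assert greedy_actions.shape == (batch_size,)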

    # Test API wrapper for single value Q-head from obs/action input.
    obs_space = Box(-1.0, 1.0, (3,))
    action_space = Box(-1.0, 1.0, (2,))

    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch"
        else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__
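
    # As above, the returned model is an instance of the requested interface
    # class (sanity check mirroring the dueling case; not in the original):
    assert isinstance(
        my_cont_action_q_model,
        ContActionQModel if args.framework != "torch" else TorchContActionQModel,
    )
    assert hasattr(my_cont_action_q_model, "get_single_q_value")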

    batch_size = 10
    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (batch_size, 256)

    # Pass `out` and an action into `get_single_q_value` to get one
    # Q-value per (obs, action) pair in the batch.
    action = np.array([action_space.sample() for _ in range(batch_size)])
    if args.framework == "torch":
        action = torch.from_numpy(action)
    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (batch_size, 1)
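
    # A hedged sketch (not part of the original example): a common way to use
    # such a continuous-action Q-head is to score several sampled candidate
    # actions per observation and keep the index of the best one.
    # `num_candidates` is an illustrative name, not an RLlib parameter.
    num_candidates = 5
    candidate_qs = []
    for _ in range(num_candidates):
        cand = np.array([action_space.sample() for _ in range(batch_size)])
        if args.framework == "torch":
            cand = torch.from_numpy(cand)
        q = my_cont_action_q_model.get_single_q_value(out, cand)
        candidate_qs.append(
            q.detach().numpy() if args.framework == "torch" else np.asarray(q)
        )
    # Stack to (num_candidates, batch_size, 1), then pick the best-scoring
    # candidate per observation.
    best_idx = np.argmax(np.stack(candidate_qs), axis=0)
    assert best_idx.shape == (batch_size, 1)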