direct_method.py
from typing import Tuple, List, Generator
from ray.rllib.offline.estimators.off_policy_estimator import (
OffPolicyEstimator,
OffPolicyEstimate,
)
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.estimators.qreg_torch_model import QRegTorchModel
from gym.spaces import Discrete
import numpy as np

torch, nn = try_import_torch()


# TODO (rohan): replace with AIR/parallel workers
# (And find a better name than `should_train`)
@DeveloperAPI
def k_fold_cv(
    batch: SampleBatchType, k: int, should_train: bool = True
) -> Generator[Tuple[List[SampleBatch], List[SampleBatch]], None, None]:
    """Utility function that returns a k-fold cross-validation generator
    over episodes from the given batch. If the number of episodes in the
    batch is less than `k`, or `should_train` is False, yields an empty
    list for train_episodes and all of the episodes as test_episodes.

    Args:
        batch: A SampleBatch of episodes to split.
        k: Number of cross-validation splits.
        should_train: True by default. If False, yield ([], episodes).

    Yields:
        A tuple of two lists of SampleBatches: (train_episodes, test_episodes).
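
    Example (a minimal sketch; the episode data below is made up and the
    batch is assumed to store complete episodes keyed by `eps_id`):

        batch = SampleBatch({
            SampleBatch.EPS_ID: [0, 0, 1, 1, 2, 2, 3, 3],
            SampleBatch.OBS: [[0.0]] * 8,
            SampleBatch.ACTIONS: [0] * 8,
            SampleBatch.REWARDS: [1.0] * 8,
            SampleBatch.DONES: [False, True] * 4,
        })
        for train_eps, test_eps in k_fold_cv(batch, k=2):
            # With 4 episodes and k=2, each fold holds out 2 episodes for
            # testing and trains on the other 2.
            assert len(train_eps) == 2 and len(test_eps) == 2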
"""
episodes = batch.split_by_episode()
n_episodes = len(episodes)
if n_episodes < k or not should_train:
yield [], episodes
return
n_fold = n_episodes // k
for i in range(k):
train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
if i != k - 1:
test_episodes = episodes[i * n_fold : (i + 1) * n_fold]
else:
# Append remaining episodes onto the last test_episodes
test_episodes = episodes[i * n_fold :]
yield train_episodes, test_episodes
return
@DeveloperAPI
class DirectMethod(OffPolicyEstimator):
"""The Direct Method estimator.
DM estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
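
    Example (a minimal usage sketch; `policy` is assumed to be a trained torch
    Policy configured with `batch_mode="complete_episodes"`, and `batch` a
    SampleBatch of complete logged episodes from the behavior policy):

        estimator = DirectMethod(
            name="dm", policy=policy, gamma=0.99, q_model_type="fqe", k=5
        )
        estimates = estimator.estimate(batch)
        for estimate in estimates:
            print(estimate)
    """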

    @override(OffPolicyEstimator)
    def __init__(
        self,
        name: str,
        policy: Policy,
        gamma: float,
        q_model_type: str = "fqe",
        k: int = 5,
        **kwargs,
    ):
        """Initializes a Direct Method OPE Estimator.

        Args:
            name: String to save OPE results under.
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
            q_model_type: Either "fqe" for Fitted Q-Evaluation,
                "qreg" for Q-Regression, or a custom model class that
                implements:
                - `estimate_q(states, actions)`
                - `estimate_v(states, action_probs)`
            k: Number of folds for k-fold cross validation when training
                the Q-model and computing the OPE estimates.
            kwargs: Optional arguments for the specified Q-model.
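
        Example of a custom `q_model_type` (an illustrative skeleton only; the
        class name and method bodies are placeholders, but the listed methods
        are the ones this estimator calls on its model):

            class MyQModel:
                def __init__(self, policy, gamma, **kwargs):
                    self.policy = policy
                    self.gamma = gamma

                def reset(self) -> None:
                    # Reinitialize the model before training on a new fold.
                    pass

                def train_q(self, batch) -> list:
                    # Fit the Q-model on `batch`; return per-iteration losses.
                    return []

                def estimate_q(self, states, actions):
                    # Return Q(s, a) for each (state, action) pair.
                    return np.zeros(len(actions))

                def estimate_v(self, states, action_probs):
                    # Return sum_a pi(a|s) * Q(s, a) for each state.
                    return np.zeros(len(states))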
"""
        super().__init__(name, policy, gamma)
        # TODO (rohan): Add support for continuous action spaces
        assert isinstance(
            policy.action_space, Discrete
        ), "DM Estimator only supports discrete action spaces!"
        assert (
            policy.config["batch_mode"] == "complete_episodes"
        ), "DM Estimator only supports `batch_mode`=`complete_episodes`"

        # TODO (rohan): Add support for TF!
        if policy.framework == "torch":
            if q_model_type == "qreg":
                model_cls = QRegTorchModel
            elif q_model_type == "fqe":
                model_cls = FQETorchModel
            else:
                assert hasattr(
                    q_model_type, "estimate_q"
                ), "q_model_type must implement `estimate_q`!"
                assert hasattr(
                    q_model_type, "estimate_v"
                ), "q_model_type must implement `estimate_v`!"
                # Use the custom Q-model class directly.
                model_cls = q_model_type
        else:
            raise ValueError(
                f"{self.__class__.__name__} "
                "estimator only supports `policy.framework`=`torch`"
            )

        self.model = model_cls(
            policy=policy,
            gamma=gamma,
            **kwargs,
        )
        self.k = k
        self.losses = []

    @override(OffPolicyEstimator)
    def estimate(
        self, batch: SampleBatchType, should_train: bool = True
    ) -> List[OffPolicyEstimate]:
        self.check_can_estimate_for(batch)
        estimates = []
        # Split data into train and test folds using k-fold cross validation.
        for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
            # Train the Q-model on the training fold.
            if train_episodes:
                # Reinitialize the model before training on a new fold.
                self.model.reset()
                train_batch = SampleBatch.concat_samples(train_episodes)
                losses = self.train(train_batch)
                self.losses.append(losses)

            # Calculate direct method OPE estimates on the test fold.
            for episode in test_episodes:
                rewards = episode["rewards"]
                v_old = 0.0
                v_new = 0.0
                # v_old: discounted return actually achieved by the behavior
                # policy in this episode.
                for t in range(episode.count):
                    v_old += rewards[t] * self.gamma ** t

                # v_new: the Q-model's value estimate for the target policy at
                # the episode's initial observation.
                init_step = episode[0:1]
                init_obs = np.array([init_step[SampleBatch.OBS]])
                all_actions = np.arange(self.policy.action_space.n, dtype=float)
                init_step[SampleBatch.ACTIONS] = all_actions
                action_probs = np.exp(self.action_log_likelihood(init_step))
                v_value = self.model.estimate_v(init_obs, action_probs)
                v_new = convert_to_numpy(v_value).item()

                estimates.append(
                    OffPolicyEstimate(
                        self.name,
                        {
                            "v_old": v_old,
                            "v_new": v_new,
                            "v_gain": v_new / max(1e-8, v_old),
                        },
                    )
                )
        return estimates

    @override(OffPolicyEstimator)
    def train(self, batch: SampleBatchType):
        # Train the underlying Q-model on the given batch and return its losses.
        return self.model.train_q(batch)