Skip to content

Commit a9ce012

Browse files
tiny mod
1 parent f375dc2 commit a9ce012

File tree

4 files changed

+4
-12
lines changed

4 files changed

+4
-12
lines changed

rl2/cartpole/q_learning.py

+1-4
Original file line number | Diff line number | Diff line change
@@ -73,10 +73,7 @@ def __init__(self, env, feature_transformer):
7373

7474
def predict(self, s):
7575
X = self.feature_transformer.transform(np.atleast_2d(s))
76-
# print("X.shape", X.shape)
77-
result = np.array([m.predict(X)[0] for m in self.models])
78-
result = np.atleast_2d(result)
79-
assert(len(result.shape) == 2)
76+
result = np.stack([m.predict(X) for m in self.models]).T
8077
return result
8178

8279
def update(self, s, a, G):

rl2/cartpole/td_lambda.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,7 @@ def reset(self):
5151

5252
def predict(self, s):
5353
X = self.feature_transformer.transform([s])
54-
# assert(len(X.shape) == 2)
55-
result = np.array([m.predict(X)[0] for m in self.models])
56-
result = np.atleast_2d(result)
54+
result = np.stack([m.predict(X) for m in self.models]).T
5755
return result
5856

5957
def update(self, s, a, G, gamma, lambda_):

rl2/mountaincar/q_learning.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -75,9 +75,7 @@ def __init__(self, env, feature_transformer, learning_rate):
7575

7676
def predict(self, s):
7777
X = self.feature_transformer.transform([s])
78-
assert(len(X.shape) == 2)
79-
result = np.array([m.predict(X)[0] for m in self.models])
80-
result = np.atleast_2d(result)
78+
result = np.stack([m.predict(X) for m in self.models]).T
8179
assert(len(result.shape) == 2)
8280
return result
8381

rl2/mountaincar/td_lambda.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,7 @@ def __init__(self, env, feature_transformer):
5252
def predict(self, s):
5353
X = self.feature_transformer.transform([s])
5454
assert(len(X.shape) == 2)
55-
result = np.array([m.predict(X)[0] for m in self.models])
56-
result = np.atleast_2d(result)
55+
result = np.stack([m.predict(X) for m in self.models]).T
5756
assert(len(result.shape) == 2)
5857
return result
5958

0 commit comments

Comments
 (0)