Skip to content

Commit

Permalink
fixed CNNPolicy network architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongu committed Feb 19, 2016
1 parent 28497b7 commit 823481b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
22 changes: 13 additions & 9 deletions AlphaGo/models/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ def _model_forward(self):
c.f. https://github.com/fchollet/keras/issues/1426
"""
model_input = self.model.get_input(train=False)
model_output = self.model.get_input(train=False)
return K.function([model_input], [model_output])
model_output = self.model.get_output(train=False)
forward_function = K.function([model_input], [model_output])

# the forward_function returns a list of tensors
# the first [0] gets the front tensor.
# this tensor, however, has dimensions (1, width, height)
# and we just want (width,height) hence the second [0]
return lambda inpt: forward_function(inpt)[0][0]

def batch_eval_state(self, state_gen, batch=16):
"""Given a stream of states in state_gen, evaluates them in batches
Expand All @@ -47,7 +53,7 @@ def eval_state(self, state):
tensor = self.preprocessor.state_to_tensor(state)

# run the tensor through the network
network_output = self.forward([tensor])[0]
network_output = self.forward([tensor])

# get network activations at legal move locations
# note: may not be a proper distribution by ignoring illegal moves
Expand Down Expand Up @@ -82,11 +88,8 @@ def create_network(**kwargs):
network = Sequential()

# create first layer
half_width = int(params["filter_width_1"] / 2)
network.add(convolutional.ZeroPadding2D(
input_shape=(params["input_dim"], params["board"], params["board"]),
padding=(half_width, half_width)))
network.add(convolutional.Convolution2D(
input_shape=(params["input_dim"], params["board"], params["board"]),
nb_filter=params["filters_per_layer"],
nb_row=params["filter_width_1"],
nb_col=params["filter_width_1"],
Expand All @@ -99,8 +102,6 @@ def create_network(**kwargs):
# use filter_width_K if it is there, otherwise use 3
filter_key = "filter_width_%d" % i
filter_width = params.get(filter_key, 3)
half_width = int(filter_width / 2)
network.add(convolutional.ZeroPadding2D(padding=(half_width, half_width)))
network.add(convolutional.Convolution2D(
nb_filter=params["filters_per_layer"],
nb_row=filter_width,
Expand Down Expand Up @@ -142,3 +143,6 @@ def save_params(self, h5_file):
"""save model parameters (weights) to the specified file
"""
raise NotImplementedError()

if __name__ == '__main__':
pol = CNNPolicy(["board", "sensibleness"])
4 changes: 2 additions & 2 deletions tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def test_default_policy(self):

def test_output_size(self):
policy19 = CNNPolicy(["board", "liberties", "sensibleness", "capture_size"], board=19)
output = policy19.forward([policy19.preprocessor.state_to_tensor(GameState(19))])[0]
output = policy19.forward([policy19.preprocessor.state_to_tensor(GameState(19))])
self.assertEqual(output.shape, (19,19))

policy13 = CNNPolicy(["board", "liberties", "sensibleness", "capture_size"], board=13)
output = policy13.forward([policy13.preprocessor.state_to_tensor(GameState(13))])[0]
output = policy13.forward([policy13.preprocessor.state_to_tensor(GameState(13))])
self.assertEqual(output.shape, (13,13))

0 comments on commit 823481b

Please sign in to comment.