
Commit bff1d26

update master
1 parent 01b9804 commit bff1d26

99 files changed (+2686, -333 lines)


code/dlgo/agent/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
+from .alphago import *
 from .base import *
 from .pg import *
 from .predict import *
 from .naive import *
 from .naive_fast import *
+from .termination import *

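With the two new wildcard imports, the package root now re-exports the AlphaGo agent (and the termination agents) directly from dlgo.agent. A quick illustration, assuming the dlgo package is on the import path:

from dlgo.agent import AlphaGoMCTS
from dlgo.agent.alphago import AlphaGoMCTS as AlphaGoMCTSDirect

# Both names refer to the same class after this change.
assert AlphaGoMCTS is AlphaGoMCTSDirect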
code/dlgo/agent/alphago.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+# tag::alphago_imports[]
+import numpy as np
+from dlgo.agent.base import Agent
+from dlgo.goboard_fast import Move
+from dlgo import kerasutil
+import operator
+# end::alphago_imports[]
+
+
+__all__ = [
+    'AlphaGoNode',
+    'AlphaGoMCTS'
+]
+
+
+# tag::init_alphago_node[]
+class AlphaGoNode:
+    def __init__(self, parent=None, probability=1.0):
+        self.parent = parent  # <1>
+        self.children = {}  # <1>
+
+        self.visit_count = 0
+        self.q_value = 0
+        self.prior_value = probability  # <2>
+        self.u_value = probability  # <3>
+# <1> Tree nodes have one parent and potentially many children.
+# <2> A node is initialized with a prior probability.
+# <3> The utility function will be updated during search.
+# end::init_alphago_node[]
+
+    # tag::select_node[]
+    def select_child(self):
+        return max(self.children.items(),
+                   key=lambda child: child[1].q_value + child[1].u_value)
+    # end::select_node[]
+
+    # tag::expand_children[]
+    def expand_children(self, moves, probabilities):
+        for move, prob in zip(moves, probabilities):
+            if move not in self.children:
+                self.children[move] = AlphaGoNode(probability=prob)
+    # end::expand_children[]
+
+    # tag::update_values[]
+    def update_values(self, leaf_value):
+        if self.parent is not None:
+            self.parent.update_values(leaf_value)  # <1>
+
+        self.visit_count += 1  # <2>
+
+        self.q_value += leaf_value / self.visit_count  # <3>
+
+        if self.parent is not None:
+            c_u = 5
+            self.u_value = c_u * np.sqrt(self.parent.visit_count) \
+                * self.prior_value / (1 + self.visit_count)  # <4>
+# <1> We update parents first to ensure we traverse the tree top to bottom.
+# <2> Increment the visit count for this node.
+# <3> Add the specified leaf value to the Q-value, normalized by visit count.
+# <4> Update utility with current visit counts.
+# end::update_values[]
+
+
+# tag::alphago_mcts_init[]
+class AlphaGoMCTS(Agent):
+    def __init__(self, policy_agent, fast_policy_agent, value_agent,
+                 lambda_value=0.5, num_simulations=1000,
+                 depth=50, rollout_limit=100):
+        self.policy = policy_agent
+        self.rollout_policy = fast_policy_agent
+        self.value = value_agent
+
+        self.lambda_value = lambda_value
+        self.num_simulations = num_simulations
+        self.depth = depth
+        self.rollout_limit = rollout_limit
+        self.root = AlphaGoNode()
+# end::alphago_mcts_init[]
+
+    # tag::alphago_mcts_rollout[]
+    def select_move(self, game_state):
+        for simulation in range(self.num_simulations):  # <1>
+            current_state = game_state
+            node = self.root
+            for depth in range(self.depth):  # <2>
+                if not node.children:  # <3>
+                    if current_state.is_over():
+                        break
+                    moves, probabilities = self.policy_probabilities(current_state)  # <4>
+                    node.expand_children(moves, probabilities)  # <4>
+
+                move, node = node.select_child()  # <5>
+                current_state = current_state.apply_move(move)  # <5>
+
+            value = self.value.predict(current_state)  # <6>
+            rollout = self.policy_rollout(current_state)  # <6>
+
+            weighted_value = (1 - self.lambda_value) * value + \
+                self.lambda_value * rollout  # <7>
+
+            node.update_values(weighted_value)  # <8>
+        # <1> From the current state, play out a number of simulations.
+        # <2> Play moves until the specified depth is reached.
+        # <3> If the current node doesn't have any children...
+        # <4> ... expand them with probabilities from the strong policy.
+        # <5> If there are children, we can select one and play the corresponding move.
+        # <6> Compute the output of the value network and a rollout by the fast policy.
+        # <7> Determine the combined value function.
+        # <8> Update values for this node in the backup phase.
+        # end::alphago_mcts_rollout[]
+
+        # tag::alphago_mcts_selection[]
+        move = max(self.root.children, key=lambda move:  # <1>
+                   self.root.children.get(move).visit_count)  # <1>
+
+        if move in self.root.children:  # <2>
+            self.root = self.root.children[move]
+            self.root.parent = None
+        else:
+            self.root = AlphaGoNode()
+
+        return move
+        # <1> Pick the most visited child of the root as the next move.
+        # <2> If the picked move has a child node, reuse it as the new root; otherwise start from a fresh root.
+        # end::alphago_mcts_selection[]
+
+    # tag::alphago_policy_probs[]
+    def policy_probabilities(self, game_state):
+        encoder = self.policy._encoder
+        outputs = self.policy.predict(game_state)
+        legal_moves = game_state.legal_moves()
+        if not legal_moves:
+            return [], []
+        encoded_points = [encoder.encode_point(move.point)
+                          for move in legal_moves if move.point]
+        legal_outputs = outputs[encoded_points]
+        normalized_outputs = legal_outputs / np.sum(legal_outputs)
+        return legal_moves, normalized_outputs
+    # end::alphago_policy_probs[]
+
+    # tag::alphago_policy_rollout[]
+    def policy_rollout(self, game_state):
+        for step in range(self.rollout_limit):
+            if game_state.is_over():
+                break
+            move_probabilities = self.rollout_policy.predict(game_state)
+            encoder = self.rollout_policy.encoder
+            # Keep board indices paired with their probabilities so the argmax
+            # maps back to the correct board point.
+            valid_moves = [(idx, prob) for idx, prob in enumerate(move_probabilities)
+                           if Move(encoder.decode_point_index(idx)) in game_state.legal_moves()]
+            if not valid_moves:
+                break
+            max_index, max_value = max(valid_moves, key=operator.itemgetter(1))
+            max_point = encoder.decode_point_index(max_index)
+            greedy_move = Move(max_point)
+            if greedy_move in game_state.legal_moves():
+                game_state = game_state.apply_move(greedy_move)
+
+        next_player = game_state.next_player
+        winner = game_state.winner()
+        if winner is not None:
+            return 1 if winner == next_player else -1
+        else:
+            return 0
+    # end::alphago_policy_rollout[]
+
+    def serialize(self, h5file):
+        raise IOError("AlphaGoMCTS agent can't be serialized; "
+                      "consider serializing the three underlying "
+                      "neural networks instead.")

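Taken together, the new module expects three previously trained agents: a strong policy network, a fast rollout policy, and a value network. A minimal usage sketch, assuming agents serialized the way the test file below does it (the .h5 file names here are illustrative only):

import h5py

from dlgo.agent import load_prediction_agent, load_policy_agent, AlphaGoMCTS
from dlgo.rl import load_value_agent
from dlgo.goboard_fast import GameState

# Illustrative file names; any agents serialized as in alphago_test.py will do.
fast_policy = load_prediction_agent(h5py.File('alphago_sl_policy.h5', 'r'))
strong_policy = load_policy_agent(h5py.File('alphago_rl_policy.h5', 'r'))
value = load_value_agent(h5py.File('alphago_value.h5', 'r'))

# A small search budget keeps the sketch fast; the defaults
# (1000 simulations, depth 50) are much more expensive.
alphago = AlphaGoMCTS(strong_policy, fast_policy, value,
                      num_simulations=20, depth=5, rollout_limit=10)

game = GameState.new_game(19)
next_move = alphago.select_move(game)  # the most-visited move at the root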
code/dlgo/agent/alphago_test.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+import unittest
+
+from dlgo.data.processor import GoDataProcessor
+from dlgo.agent.predict import DeepLearningAgent
+from dlgo.networks.alphago import alphago_model
+from dlgo.agent.pg import PolicyAgent
+from dlgo.agent.predict import load_prediction_agent
+from dlgo.encoders.alphago import AlphaGoEncoder
+from dlgo.rl.simulate import experience_simulation
+from dlgo.rl import ValueAgent, load_experience
+from dlgo.agent import load_prediction_agent, load_policy_agent, AlphaGoMCTS
+from dlgo.rl import load_value_agent
+from dlgo.goboard_fast import GameState
+
+from keras.callbacks import ModelCheckpoint
+import h5py
+import numpy as np
+
+
+class AlphaGoAgentTest(unittest.TestCase):
+    def test_1_supervised_learning(self):
+        rows, cols = 19, 19
+        encoder = AlphaGoEncoder()
+
+        input_shape = (encoder.num_planes, rows, cols)
+        alphago_sl_policy = alphago_model(input_shape, is_policy_net=True)
+
+        alphago_sl_policy.compile('sgd', 'categorical_crossentropy', metrics=['accuracy'])
+
+        alphago_sl_agent = DeepLearningAgent(alphago_sl_policy, encoder)
+
+        inputs = np.ones((10,) + input_shape)
+        outputs = alphago_sl_policy.predict(inputs)
+        assert outputs.shape == (10, 361)
+
+        with h5py.File('test_alphago_sl_policy.h5', 'w') as sl_agent_out:
+            alphago_sl_agent.serialize(sl_agent_out)
+
+    def test_2_reinforcement_learning(self):
+        encoder = AlphaGoEncoder()
+
+        sl_agent = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5', 'r'))
+        sl_opponent = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5', 'r'))
+
+        alphago_rl_agent = PolicyAgent(sl_agent.model, encoder)
+        opponent = PolicyAgent(sl_opponent.model, encoder)
+
+        num_games = 1
+        experience = experience_simulation(num_games, alphago_rl_agent, opponent)
+
+        alphago_rl_agent.train(experience)
+
+        with h5py.File('test_alphago_rl_policy.h5', 'w') as rl_agent_out:
+            alphago_rl_agent.serialize(rl_agent_out)
+
+        with h5py.File('test_alphago_rl_experience.h5', 'w') as exp_out:
+            experience.serialize(exp_out)
+
+    def test_3_alphago_value(self):
+        rows, cols = 19, 19
+        encoder = AlphaGoEncoder()
+        input_shape = (encoder.num_planes, rows, cols)
+        alphago_value_network = alphago_model(input_shape)
+
+        alphago_value = ValueAgent(alphago_value_network, encoder)
+
+        experience = load_experience(h5py.File('test_alphago_rl_experience.h5', 'r'))
+
+        alphago_value.train(experience)
+
+        with h5py.File('test_alphago_value.h5', 'w') as value_agent_out:
+            alphago_value.serialize(value_agent_out)
+
+    def test_4_alphago_mcts(self):
+        fast_policy = load_prediction_agent(h5py.File('test_alphago_sl_policy.h5', 'r'))
+        strong_policy = load_policy_agent(h5py.File('test_alphago_rl_policy.h5', 'r'))
+        value = load_value_agent(h5py.File('test_alphago_value.h5', 'r'))
+
+        alphago = AlphaGoMCTS(strong_policy, fast_policy, value,
+                              num_simulations=20, depth=5, rollout_limit=10)
+        start = GameState.new_game(19)
+        alphago.select_move(start)
+
+
+if __name__ == '__main__':
+    unittest.main()

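The test methods are numbered deliberately: unittest runs methods in alphabetical order by default, and each later test loads the HDF5 files written by the earlier ones (SL policy, then RL policy and experience, then value network). A small sketch of running the whole chain in that order, assuming dlgo is importable from the working directory:

import unittest

# The default loader sorts test_1 .. test_4 alphabetically, which preserves
# the file dependencies between the four tests.
suite = unittest.TestLoader().loadTestsFromName('dlgo.agent.alphago_test')
unittest.TextTestRunner(verbosity=2).run(suite)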
code/dlgo/agent/base.py

Lines changed: 4 additions & 2 deletions
@@ -4,8 +4,10 @@
 
 
 # tag::agent[]
-class Agent():
-    """Interface for a go-playing bot."""
+class Agent:
+    def __init__(self):
+        pass
+
     def select_move(self, game_state):
         raise NotImplementedError()
 # end::agent[]

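The base Agent class now has an explicit constructor, and the concrete bots below call Agent.__init__(self) from their own __init__ methods. A minimal custom agent following the same pattern (illustrative only, not part of this commit):

from dlgo.agent.base import Agent
from dlgo import goboard_fast as goboard


class PassingAgent(Agent):
    """Trivial agent that always passes, showing the expected interface."""
    def __init__(self):
        Agent.__init__(self)

    def select_move(self, game_state):
        return goboard.Move.pass_turn()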
code/dlgo/agent/naive_fast.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 
 class FastRandomBot(Agent):
     def __init__(self):
+        Agent.__init__(self)
         self.dim = None
         self.point_cache = []

code/dlgo/agent/pg.py

Lines changed: 8 additions & 2 deletions
@@ -32,11 +32,17 @@ def normalize(x):
 class PolicyAgent(Agent):
     """An agent that uses a deep policy network to select moves."""
     def __init__(self, model, encoder):
+        Agent.__init__(self)
         self._model = model
         self._encoder = encoder
         self._collector = None
         self._temperature = 0.0
 
+    def predict(self, game_state):
+        encoded_state = self._encoder.encode(game_state)
+        input_tensor = np.array([encoded_state])
+        return self._model.predict(input_tensor)[0]
+
     def set_temperature(self, temperature):
         self._temperature = temperature
 
@@ -47,14 +53,14 @@ def select_move(self, game_state):
         num_moves = self._encoder.board_width * self._encoder.board_height
 
         board_tensor = self._encoder.encode(game_state)
-        X = np.array([board_tensor])
+        x = np.array([board_tensor])
 
         if np.random.random() < self._temperature:
             # Explore random moves.
             move_probs = np.ones(num_moves) / num_moves
         else:
             # Follow our current policy.
-            move_probs = self._model.predict(X)[0]
+            move_probs = self._model.predict(x)[0]
 
         # Prevent move probs from getting stuck at 0 or 1.
         eps = 1e-5

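The new predict helper on PolicyAgent encodes a single game state, wraps it in a batch of one, and returns the network's move-probability vector; this is what AlphaGoMCTS.policy_probabilities calls above. A short sketch of consuming it directly (agent and game_state are assumed to be a trained PolicyAgent and a GameState; they are not defined in this commit):

import numpy as np

probs = agent.predict(game_state)        # one probability per board point
best_idx = int(np.argmax(probs))         # index of the highest-probability point
best_point = agent._encoder.decode_point_index(best_idx)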
code/dlgo/agent/predict.py

Lines changed: 16 additions & 12 deletions
@@ -16,16 +16,20 @@
 # tag::dl_agent_init[]
 class DeepLearningAgent(Agent):
     def __init__(self, model, encoder):
-        self._model = model
-        self._encoder = encoder
+        Agent.__init__(self)
+        self.model = model
+        self.encoder = encoder
 # end::dl_agent_init[]
 
 # tag::dl_agent_predict[]
+    def predict(self, game_state):
+        encoded_state = self.encoder.encode(game_state)
+        input_tensor = np.array([encoded_state])
+        return self.model.predict(input_tensor)[0]
+
     def select_move(self, game_state):
-        num_moves = self._encoder.board_width * self._encoder.board_height
-        board_tensor = self._encoder.encode(game_state)
-        X = np.array([board_tensor])
-        move_probs = self._model.predict(X)[0]
+        num_moves = self.encoder.board_width * self.encoder.board_height
+        move_probs = self.predict(game_state)
 # end::dl_agent_predict[]
 
 # tag::dl_agent_probabilities[]
@@ -43,11 +47,11 @@ def select_move(self, game_state):
         ranked_moves = np.random.choice(
             candidates, num_moves, replace=False, p=move_probs)  # <2>
         for point_idx in ranked_moves:
-            point = self._encoder.decode_point_index(point_idx)
+            point = self.encoder.decode_point_index(point_idx)
             if game_state.is_valid_move(goboard.Move.play(point)) and \
                     not is_point_an_eye(game_state.board, point, game_state.next_player):  # <3>
                 return goboard.Move.play(point)
-        return goboard.Move.pass_turn()  # <4> No legal, non-self-destructive moves left.
+        return goboard.Move.pass_turn()  # <4>
         # <1> Turn the probabilities into a ranked list of moves.
         # <2> Sample potential candidates.
         # <3> Starting from the top, find a valid move that doesn't reduce eye-space.
@@ -57,11 +61,11 @@ def select_move(self, game_state):
     # tag::dl_agent_serialize[]
     def serialize(self, h5file):
         h5file.create_group('encoder')
-        h5file['encoder'].attrs['name'] = self._encoder.name()
-        h5file['encoder'].attrs['board_width'] = self._encoder.board_width
-        h5file['encoder'].attrs['board_height'] = self._encoder.board_height
+        h5file['encoder'].attrs['name'] = self.encoder.name()
+        h5file['encoder'].attrs['board_width'] = self.encoder.board_width
+        h5file['encoder'].attrs['board_height'] = self.encoder.board_height
         h5file.create_group('model')
-        kerasutil.save_model_to_hdf5_group(self._model, h5file['model'])
+        kerasutil.save_model_to_hdf5_group(self.model, h5file['model'])
 # end::dl_agent_serialize[]
6771

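The rename from self._model / self._encoder to the public self.model / self.encoder is what lets the new test (and any reinforcement-learning script) reuse a loaded SL agent's network directly. A minimal sketch of that pattern, assuming a serialized SL policy exists on disk (the file name is illustrative):

import h5py

from dlgo.agent.predict import load_prediction_agent
from dlgo.agent.pg import PolicyAgent
from dlgo.encoders.alphago import AlphaGoEncoder

sl_agent = load_prediction_agent(h5py.File('alphago_sl_policy.h5', 'r'))

# Reuse the supervised model as the starting point for an RL policy agent,
# which is possible now that .model is a public attribute.
rl_agent = PolicyAgent(sl_agent.model, AlphaGoEncoder())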