|
53 | 53 |
|
54 | 54 | ## Tic-Tac-Toe(井字棋)
|
55 | 55 |
|
56 |
| --  |
| 56 | + |
| 57 | + |
57 | 58 | - 一个简单的应用强化学习的例子。
|
58 | 59 | - 定义policy:任何一种局面下,该如何落子。
|
59 |
| -- 遗传算法解法:试很多种policy,找到最终胜利的几种,然后结合,更新。 |
60 |
| -- 强化学习解法: |
61 |
| - - 1.建立一张表格,state_num × 1,代表每个state下,获胜的概率,这个表格就是所谓的**value function**,即状态到价值的映射。 |
62 |
| - - 2.跟对手下很多局。每次落子的时候,依据是在某个state下,选择所有可能的后继state中,获胜概率最大的(value最大的)。这种方法即贪婪法(Exploit)。偶尔我们也随机选择一些其他的state(Explore)。 |
63 |
| - - 3.“back up”后继state的v到当前state上。$V(s)\leftarrow V(s)+\alpha[V(s')-V(s)]$,这就是所谓的**差分学习**(temporal-difference learning),这么叫是因为$V(s')-V(s)$是两个时间点上的两次估计的差。 |
| 60 | + |
| 61 | +**遗传算法解法**:试很多种policy,找到最终胜利的几种,然后结合,更新。 |
| 62 | + |
| 63 | +**强化学习解法**: |
| 64 | + |
| 65 | +- 1.建立一张表格,state_num × 1,代表每个state下,获胜的概率,这个表格就是所谓的**value function**,即状态到价值的映射。 |
| 66 | +- 2.跟对手下很多局。每次落子的时候,依据是在某个state下,选择所有可能的后继state中,获胜概率最大的(value最大的)。这种方法即贪婪法(Exploit)。偶尔我们也随机选择一些其他的state(Explore)。 |
| 67 | +- 3.**back up**后继state的v到当前state上。$V(s)\leftarrow V(s)+\alpha[V(s')-V(s)]$,这就是所谓的**差分学习**(temporal-difference learning),这么叫是因为$V(s')-V(s)$是两个时间点上的两次估计的差。 |
| 68 | + |
| 69 | +### 代码分析 |
| 70 | + |
| 71 | +[完整源码](https://github.com/ShangtongZhang/reinforcement-learning-an-introduction/blob/master/chapter01/tic_tac_toe.py) |
| 72 | + |
| 73 | +游戏实现: |
| 74 | + |
| 75 | +用`1`代表白棋,`-1`代表黑棋,若有连续的三个数之和为3则白赢,-3则黑赢。若所有绝对值之和为9,则游戏为平局。 |
| 76 | + |
| 77 | +```python |
| 78 | +for result in results: |
| 79 | + if result == 3: |
| 80 | + self.winner = 1 |
| 81 | + self.end = True |
| 82 | + return self.end |
| 83 | + if result == -3: |
| 84 | + self.winner = -1 |
| 85 | + self.end = True |
| 86 | + return self.end |
| 87 | + |
| 88 | +# whether it's a tie |
| 89 | +sum = np.sum(np.abs(self.data)) |
| 90 | +if sum == BOARD_ROWS * BOARD_COLS: |
| 91 | + self.winner = 0 |
| 92 | + self.end = True |
| 93 | + return self.end |
| 94 | +``` |
| 95 | + |
| 96 | +定义状态字典: |
| 97 | + |
| 98 | +```python |
| 99 | +all_states = dict() |
| 100 | +all_states[current_state.hash()] = (current_state, current_state.is_end()) |
| 101 | +``` |
| 102 | + |
| 103 | +其中,键名是状态的哈希值,值是状态对象以及该状态是否是终止状态。哈希值计算: |
| 104 | + |
| 105 | +```python |
| 106 | +# compute the hash value for one state, it's unique |
| 107 | +def hash(self): |
| 108 | + if self.hash_val is None: |
| 109 | + self.hash_val = 0 |
| 110 | + for i in self.data.reshape(BOARD_ROWS * BOARD_COLS): |
| 111 | + if i == -1: |
| 112 | + i = 2 |
| 113 | + self.hash_val = self.hash_val * 3 + i |
| 114 | + return int(self.hash_val) |
| 115 | +``` |
| 116 | + |
| 117 | +可以看到,状态的个数理论上应该是$3^9=19683$个,下面的价值表格的键数也一样是这个数字。 |
| 118 | + |
| 119 | +价值表格也是用dict实现: |
| 120 | + |
| 121 | +```python |
| 122 | +self.estimations = dict() |
| 123 | +... |
| 124 | +for hash_val in all_states.keys(): |
| 125 | + (state, is_end) = all_states[hash_val] |
| 126 | + if is_end: |
| 127 | + if state.winner == self.symbol: |
| 128 | + self.estimations[hash_val] = 1.0 |
| 129 | + elif state.winner == 0: |
| 130 | + # we need to distinguish between a tie and a lose |
| 131 | + self.estimations[hash_val] = 0.5 |
| 132 | + else: |
| 133 | + self.estimations[hash_val] = 0 |
| 134 | + else: |
| 135 | + self.estimations[hash_val] = 0.5 |
| 136 | +``` |
| 137 | + |
| 138 | +backup: |
| 139 | + |
| 140 | +```python |
| 141 | +# update value estimation |
| 142 | +def backup(self): |
| 143 | + self.states = [state.hash() for state in self.states] |
| 144 | + |
| 145 | + for i in reversed(range(len(self.states) - 1)): |
| 146 | + state = self.states[i] |
| 147 | + td_error = self.greedy[i] * (self.estimations[self.states[i + 1]] - self.estimations[state]) |
| 148 | + self.estimations[state] += self.step_size * td_error |
| 149 | +``` |
| 150 | + |
| 151 | +决策使用epsilon-greedy: |
| 152 | + |
| 153 | +```python |
| 154 | +# choose an action based on the state |
| 155 | +def act(self): |
| 156 | + state = self.states[-1] |
| 157 | + next_states = [] |
| 158 | + next_positions = [] |
| 159 | + for i in range(BOARD_ROWS): |
| 160 | + for j in range(BOARD_COLS): |
| 161 | + if state.data[i, j] == 0: |
| 162 | + next_positions.append([i, j]) |
| 163 | + next_states.append(state.next_state(i, j, self.symbol).hash()) |
| 164 | + |
| 165 | + if np.random.rand() < self.epsilon: |
| 166 | + action = next_positions[np.random.randint(len(next_positions))] |
| 167 | + action.append(self.symbol) |
| 168 | + self.greedy[-1] = False |
| 169 | + return action |
| 170 | + |
| 171 | + values = [] |
| 172 | + for hash, pos in zip(next_states, next_positions): |
| 173 | + values.append((self.estimations[hash], pos)) |
| 174 | + # to select one of the actions of equal value at random |
| 175 | + np.random.shuffle(values) |
| 176 | + values.sort(key=lambda x: x[0], reverse=True) |
| 177 | + action = values[0][1] |
| 178 | + action.append(self.symbol) |
| 179 | + return action |
| 180 | +``` |
| 181 | + |
| 182 | +可以在终端和训练好的ai player对弈: |
| 183 | + |
| 184 | + |
| 185 | + |
| 186 | +我试了好几局,都是平局,看来训练的还是不错的。 |
| 187 | + |
| 188 | + |
| 189 | + |
| 190 | +模型训练好后,保存的数据就是价值表格。但我们从中也可以看到一个问题,一个像tic-tac-toe这么简单的问题,使用价值表格保存所有状态的价值,也需要耗费大量的存储。 |
0 commit comments