
Commit cd372ac

Tic tac toe can now learn, but slow.
1 parent e9a7c8e commit cd372ac

File tree

6 files changed
+36 -32 lines changed

rl/agents/agent_builder.py

Lines changed: 11 additions & 5 deletions
@@ -1,4 +1,5 @@
 #! /usr/bin/env python3
+
 from rl import agents
 
 
@@ -35,18 +36,24 @@ def add(self, agent_type: str):
     def set(self, *args, **kwargs):
         self.args = args
         self.kwargs = kwargs
-
-    def make(self) -> agents.Agent:
         policy_name = self.registry[self.policy_agent].__name__
         learning_name = self.registry[self.learning_agent].__name__
+
         exec(f"""
+global {policy_name}{learning_name}
 class {policy_name}{learning_name}(agents.{policy_name}, agents.{learning_name}):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+""")
+
+    def make(self) -> agents.Agent:
+        policy_name = self.registry[self.policy_agent].__name__
+        learning_name = self.registry[self.learning_agent].__name__
+
+        exec(f"""
 global agent
-agent = {policy_name}{learning_name}(*self.args, **self.kwargs)
-""")
+agent = {policy_name}{learning_name}(*self.args, **self.kwargs)""")
 
         return agent
 
@@ -55,4 +62,3 @@ def __init__(self, *args, **kwargs):
 builder = AgentBuilder(policy="EGreedy", learning="TemporalDifference")
 builder.set(exploratory_rate=0.1, learning_rate=0.5)
 agent = builder.make()
-print()
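Note: set() now pre-builds the combined agent class by exec-ing a generated class statement that inherits from both the policy and learning mixins, while make() exec-s the instantiation into a global. For reference, the same multiple-inheritance composition can be expressed without exec via type(); the helper below is an illustrative sketch, not part of this commit, and assumes the builder's registry maps the names passed in ("EGreedy", "TemporalDifference") to agent classes.

def compose_agent(registry, policy_name, learning_name, *args, **kwargs):
    # Hypothetical helper: build the combined class at runtime with type()
    # instead of exec-ing a generated class statement.
    policy_cls = registry[policy_name]
    learning_cls = registry[learning_name]
    combined = type(f"{policy_cls.__name__}{learning_cls.__name__}",
                    (policy_cls, learning_cls), {})
    return combined(*args, **kwargs)

# e.g. compose_agent(builder.registry, "EGreedy", "TemporalDifference",
#                    exploratory_rate=0.1, learning_rate=0.5)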

rl/agents/learning/learning_agent.py

Lines changed: 8 additions & 21 deletions
@@ -1,6 +1,7 @@
 #! /usr/bin/env python3
 from abc import abstractmethod
 from collections import defaultdict, Counter
+from typing import Tuple
 
 import numpy
 
@@ -79,38 +80,24 @@ def merge(self, agent):
         for state_action, counts in agent.transitions.items():
             self.transitions[state_action] += counts
 
-    def transition_model(self, state: numpy.ndarray, action: int, copy: bool = False) -> numpy.ndarray:
+    def transition_model(self, state: numpy.ndarray, action: int) -> Tuple[numpy.ndarray, numpy.ndarray]:
         """
         State transition model that describes how the environment state changes when the
         agent performs an action depending on the action and the current state.
         :param state: The state of the environment
         :param action: An action available to the agent
-        :param copy: When applying the action to the state, do so with a copy or apply it directly
         """
-        if copy:
-            next_state = state.copy()
-        else:
-            next_state = state
-
-        state_counts = self.transitions[(*next_state, action)]
+        state_action_pair = (*state, action)
+        state_counts = self.transitions[state_action_pair]
 
         if not state_counts:
-            return state
+            return numpy.array([]), numpy.array([])
 
-        states = list(state_counts.keys())
+        states = numpy.array(list(state_counts.keys()))
         counts = numpy.array(list(state_counts.values()))
 
-        counts = numpy.maximum(counts, 0)
-        sum = counts.sum()
-        probabilities = counts / sum
-
-        # values = []
-        # for p, s in zip(probabilities, states):
-        #     values.append(self.state_values[s].value)
-
-        index = numpy.random.choice(numpy.arange(len(state_counts)), p=probabilities)
-        # return states[numpy.argmax(numpy.array(values))]
-        return states[index]
+        probabilities = counts / counts.sum()
+        return probabilities, states
 
     def value_model(self, state: numpy.ndarray, action: int) -> float:
         """

rl/agents/policy/decaying_egreedy_policy_agent.py

Lines changed: 6 additions & 1 deletion
@@ -51,7 +51,12 @@ def greedy_action(self, state: numpy.ndarray, available_actions: numpy.ndarray)
         max_index: int = 0
 
         for index, action in enumerate(available_actions):
-            next_state: numpy.ndarray = self.transition_model(state, action, copy=True)
+            probabilities: numpy.ndarray
+            states: numpy.ndarray
+            probabilities, states = self.transition_model(state.copy(), action)
+            index = numpy.random.choice(numpy.arange(len(states)), p=probabilities)
+
+            next_state: numpy.ndarray = states[index]
             next_value: float = self.value_model(next_state, action)
 
             if next_value > max_value:

rl/agents/policy/egreedy_policy_agent.py

Lines changed: 10 additions & 2 deletions
@@ -50,8 +50,16 @@ def greedy_action(self, state: numpy.ndarray, available_actions: numpy.ndarray)
         max_index: int = 0
 
         for index, action in enumerate(available_actions):
-            next_state: numpy.ndarray = self.transition_model(state, action, copy=True)
-            next_value: float = self.value_model(next_state, action)
+            probabilities: numpy.ndarray
+            states: numpy.ndarray
+            probabilities, states = self.transition_model(state.copy(), action)
+
+            if probabilities.any():
+                index = numpy.random.choice(numpy.arange(len(states)), p=probabilities)
+                next_state: numpy.ndarray = states[index]
+                next_value: float = self.value_model(next_state, action)
+            else:
+                continue
 
             if next_value > max_value:
                 max_index: int = index
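Note: both ε-greedy agents now draw a candidate next state per action from the empirical transition model and score it with value_model, skipping actions with no recorded transitions. A condensed sketch of how the updated loop reads as a whole; the surrounding method body (max_value initialisation, the final return) is not in this hunk and is assumed here, and the sampled index is kept separate from the loop index:

import numpy

def greedy_action(self, state, available_actions):
    max_value = float("-inf")
    max_index = 0
    for index, action in enumerate(available_actions):
        probabilities, states = self.transition_model(state.copy(), action)
        if not probabilities.any():
            continue  # no observed transitions yet for this (state, action)
        sampled = numpy.random.choice(numpy.arange(len(states)), p=probabilities)
        next_value = self.value_model(states[sampled], action)
        if next_value > max_value:
            max_value = next_value
            max_index = index
    return available_actions[max_index]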

rl/agents/policy/human_policy_agent.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def act(self, state: numpy.ndarray, available_actions: numpy.ndarray) -> int:
         :return: The action selected
         """
         while True:
-            user_input: str = input(f"available actions: {available_actions}")
+            user_input: str = input(f"available actions: {[action + 1 for action in available_actions]}")
 
             if user_input.startswith("q") or "quit" in user_input:
                 print("quitting!")

rl/tictactoe/main.py

Lines changed: 0 additions & 2 deletions
@@ -80,8 +80,6 @@ def learn_from_game(args):
     env = gym.make("TicTacToe-v0")
     obs: numpy.ndarray = env.reset()
 
-    builder = AgentBuilder(policy="EGreedy", learning="TemporalDifference")
-
     players: Dict[Mark, Agent] = {
         Mark.X: builder.make(),
         Mark.O: builder.make(),
