
Commit 14a83e6

committed
All agents must inherit from Agent. Added null learning and null policy agents for cases where only learning, or only acting, is necessary (e.g. random or human agents).
1 parent 3b9130e commit 14a83e6
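
As a rough, illustrative sketch of the intent behind this commit (not code from the commit itself): a driver loop talks only to the Agent interface, and agents that cannot learn, such as random or human agents, still expose learn() because the null learning side turns it into a no-op. The env object and its reset()/step() API below are assumptions for illustration, not part of this repository.

# Illustrative only: `env` and its reset()/step() API are assumed, not defined in this repo.
def play_episode(env, agent):
    state = env.reset()
    agent.reset()
    done = False
    while not done:
        action = agent.act(state=state)           # policy side: random, human, learned, ...
        state, reward, done = env.step(action)
        agent.learn(state=state, reward=reward)   # no-op for agents built on a null learning agent
    return agent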

10 files changed: 81 additions & 76 deletions

rl/agents/agent.py

Lines changed: 21 additions & 2 deletions
@@ -1,6 +1,25 @@
 #! /usr/bin/env python3
-from abc import ABC
+from abc import ABC, abstractmethod


 class Agent(ABC):
-    pass
+    """
+    Agent class serves as an interface definition. Every concrete Agent must
+    implement these four functions: act, learn, render, and reset.
+    """
+
+    @abstractmethod
+    def __init__(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def act(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def learn(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def reset(self, **kwargs):
+        pass
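
For illustration only (this class is hypothetical, not part of the commit): a concrete agent must override every abstract method before it can be instantiated, otherwise Python raises a TypeError.

from rl.agents.agent import Agent


class IdleAgent(Agent):
    """A hypothetical do-nothing agent that satisfies the Agent interface."""

    def __init__(self, **kwargs):
        pass

    def act(self, **kwargs):
        return None  # never selects a real action

    def learn(self, **kwargs):
        pass  # nothing to learn

    def reset(self, **kwargs):
        pass


agent = IdleAgent()  # instantiable; omitting any override above would raise TypeError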

rl/agents/learning/learning_agent.py

Lines changed: 7 additions & 13 deletions
@@ -1,7 +1,6 @@
 #! /usr/bin/env python3
 from abc import abstractmethod
 from collections import defaultdict, Counter
-from typing import Dict, Tuple, Union

 from rl.agents.agent import Agent
 from rl.reprs import Transition
@@ -13,24 +12,19 @@ class LearningAgent(Agent):
     The learning agent implements a learning method and is used for purposes of building a state value map
     """

-    def __init__(self, transitions=None):
+    def __init__(self, state_values=None, transitions=None):
         self.trajectory = []

-        if not transitions:
+        if state_values is None:
+            self.state_values = defaultdict(Value)
+        else:
+            self.state_values = state_values
+
+        if transitions is None:
             self.transitions = defaultdict(Counter)
         else:
             self.transitions = transitions

-    @property
-    @abstractmethod
-    def state_values(self) -> Dict[Tuple[Union[int, float]], Value]:
-        pass
-
-    @state_values.setter
-    @abstractmethod
-    def state_values(self, state_values: Dict[Tuple[Union[int, float]], Value]):
-        pass
-
     @abstractmethod
     def learn_value(self):
         pass

rl/agents/learning/null_learning_agent.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+#! /usr/bin/env python3
+
+from rl.agents.learning import LearningAgent
+
+
+class NullLearningAgent(LearningAgent):
+    """
+    The learning agent implements a learning method and is used for purposes of building a state value map
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(args, kwargs)
+
+    def learn_value(self):
+        pass

rl/agents/learning/sample_averaging_agent.py

Lines changed: 1 addition & 14 deletions
@@ -1,5 +1,4 @@
 #! /usr/bin/env python3
-from collections import defaultdict
 from typing import Dict, Tuple, Union

 from rl.agents.learning import LearningAgent
@@ -18,19 +17,7 @@ def __init__(self, state_values: Dict[Tuple[Union[int, float]], Value] = None, t
         Represents an agent learning with temporal difference
         :param state_values: A mapping of states to and their associated values
         """
-        super().__init__(transitions=transitions)
-        if not state_values:
-            self._state_values = defaultdict(Value)
-        else:
-            self._state_values = state_values
-
-    @property
-    def state_values(self) -> Dict[Tuple[Union[int, float]], Value]:
-        return self._state_values
-
-    @state_values.setter
-    def state_values(self, state_values: Dict[Tuple[Union[int, float]], Value]):
-        self._state_values = state_values
+        super().__init__(state_values=state_values, transitions=transitions)

     def learn_value(self):
         """

rl/agents/learning/temporal_difference_agent.py

Lines changed: 1 addition & 15 deletions
@@ -1,5 +1,4 @@
 #! /usr/bin/env python3
-from collections import defaultdict
 from typing import Dict, Tuple, Union

 from rl.agents.learning import LearningAgent
@@ -19,22 +18,9 @@ def __init__(self, learning_rate: float, state_values: Dict[Tuple[Union[int, flo
         :param learning_rate: How much to learn from the most recent action
         :param state_values: A mapping of states to and their associated values
         """
-        super().__init__(transitions=transitions)
-        if not state_values:
-            self._state_values = defaultdict(Value)
-        else:
-            self._state_values = state_values
-
+        super().__init__(state_values=state_values, transitions=transitions)
         self.learning_rate: float = learning_rate

-    @property
-    def state_values(self) -> Dict[Tuple[Union[int, float]], Value]:
-        return self._state_values
-
-    @state_values.setter
-    def state_values(self, state_values: Dict[Tuple[Union[int, float]], Value]):
-        self._state_values = state_values
-
     def learn_value(self):
         """
         Apply temporal difference learning and update the state and values of this agent

rl/agents/learning/temporal_difference_averaging_agent.py

Lines changed: 1 addition & 15 deletions
@@ -1,5 +1,4 @@
 #! /usr/bin/env python3
-from collections import defaultdict
 from typing import Dict, Tuple, Union

 from rl.agents.learning import LearningAgent
@@ -20,19 +19,7 @@ def __init__(self, state_values: Dict[Tuple[Union[int, float]], Value] = None,
         :param learning_rate: How much to learn from the most recent action
         :param state_values: A mapping of states to and their associated values
         """
-        super().__init__(transitions=transitions)
-        if not state_values:
-            self._state_values = defaultdict(Value)
-        else:
-            self._state_values = state_values
-
-    @property
-    def state_values(self) -> Dict[Tuple[Union[int, float]], Value]:
-        return self._state_values
-
-    @state_values.setter
-    def state_values(self, state_values: Dict[Tuple[Union[int, float]], Value]):
-        self._state_values = state_values
+        super().__init__(state_values=state_values, transitions=transitions)

     def learn_value(self):
         """
@@ -46,7 +33,6 @@ def learn_value(self):

         if current_value.value != 0:
             for i in range(-2, -1 * len(self.trajectory), -1):
-
                 previous_transition = self.trajectory[i - 1]
                 previous_value: Value = self.state_values[previous_transition.state]

rl/agents/learning/weighted_averaging_agent.py

Lines changed: 1 addition & 15 deletions
@@ -1,5 +1,4 @@
 #! /usr/bin/env python3
-from collections import defaultdict
 from typing import Dict, Tuple, Union

 from rl.agents.learning import LearningAgent
@@ -18,22 +17,9 @@ def __init__(self, learning_rate, state_values: Dict[Tuple[Union[int, float]], V
         Represents an agent learning with temporal difference
         :param state_values: A mapping of states to and their associated values
         """
+        super().__init__(state_values=state_values, transitions=transitions)
         self.learning_rate = learning_rate

-        super().__init__(transitions=transitions)
-        if not state_values:
-            self._state_values = defaultdict(Value)
-        else:
-            self._state_values = state_values
-
-    @property
-    def state_values(self) -> Dict[Tuple[Union[int, float]], Value]:
-        return self._state_values
-
-    @state_values.setter
-    def state_values(self, state_values: Dict[Tuple[Union[int, float]], Value]):
-        self._state_values = state_values
-
     def learn_value(self):
         """
         Apply temporal difference learning and update the state and values of this agent

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#! /usr/bin/env python3
+
+import numpy
+
+from rl.agents.policy import PolicyAgent
+
+
+class NullPolicyAgent(PolicyAgent):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(args, kwargs)
+
+    def act(self, state: numpy.ndarray):
+        """
+        A policy for this agent that maps an state to an action
+        :param state: The state of the environment
+        """
+        pass
+
+    def available_actions(self, state: numpy.ndarray) -> numpy.ndarray:
+        """
+        Given a state, determine the available actions
+        :param state: The state of the environment
+        """
+        pass

rl/tictactoe/base_agent.py

Lines changed: 2 additions & 1 deletion
@@ -4,10 +4,11 @@
 import numpy

 from rl.agents import RandomPolicyAgent
+from rl.agents.learning.null_learning_agent import NullLearningAgent
 from rl.envs.tictactoe import Mark


-class BaseAgent(RandomPolicyAgent):
+class BaseAgent(RandomPolicyAgent, NullLearningAgent):
     def available_actions(self, state: numpy.ndarray) -> numpy.ndarray:
         """
         Determines the available actions for the agent given the state

rl/tictactoe/human_agent.py

Lines changed: 7 additions & 1 deletion
@@ -4,10 +4,16 @@
 import numpy

 from rl.agents import HumanPolicyAgent
+from rl.agents.learning.null_learning_agent import NullLearningAgent
 from rl.envs.tictactoe import Mark


-class HumanAgent(HumanPolicyAgent):
+class HumanAgent(HumanPolicyAgent, NullLearningAgent):
+
+    def __init__(self):
+        HumanPolicyAgent.__init__(self)
+        NullLearningAgent.__init__(self)
+
     def available_actions(self, state: numpy.ndarray) -> numpy.ndarray:
         """
         Determines the available actions for the agent given the state
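
The two tic-tac-toe agents above combine an acting base with NullLearningAgent through multiple inheritance; HumanAgent calls each parent's __init__ explicitly because the two bases take different constructor arguments. A minimal sketch of the same pattern, where ScriptedPolicyAgent is a hypothetical acting-only base invented for illustration:

from rl.agents.learning.null_learning_agent import NullLearningAgent


class ScriptedPolicyAgent:
    """Hypothetical acting-only base that replays a fixed list of actions."""

    def __init__(self, actions):
        self.actions = list(actions)

    def act(self, state):
        return self.actions.pop(0)


class ScriptedAgent(ScriptedPolicyAgent, NullLearningAgent):
    def __init__(self, actions):
        # Initialise each base separately, as HumanAgent does above,
        # because their __init__ signatures differ.
        ScriptedPolicyAgent.__init__(self, actions)
        NullLearningAgent.__init__(self)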
