|
1 | 1 | #! /usr/bin/env python3 |
2 | 2 | from abc import abstractmethod |
3 | 3 | from collections import defaultdict, Counter |
| 4 | +from typing import Tuple |
4 | 5 |
|
5 | 6 | import numpy |
6 | 7 |
|
@@ -79,38 +80,24 @@ def merge(self, agent): |
79 | 80 | for state_action, counts in agent.transitions.items(): |
80 | 81 | self.transitions[state_action] += counts |
81 | 82 |
|
82 | | - def transition_model(self, state: numpy.ndarray, action: int, copy: bool = False) -> numpy.ndarray: |
| 83 | + def transition_model(self, state: numpy.ndarray, action: int) -> Tuple[numpy.ndarray, numpy.ndarray]: |
83 | 84 | """ |
84 | 85 | State transition model that describes how the environment state changes when the |
85 | 86 | agent performs an action depending on the action and the current state. |
86 | 87 | :param state: The state of the environment |
87 | 88 | :param action: An action available to the agent |
88 | | - :param copy: When applying the action to the state, do so with a copy or apply it directly |
89 | 89 | """ |
90 | | - if copy: |
91 | | - next_state = state.copy() |
92 | | - else: |
93 | | - next_state = state |
94 | | - |
95 | | - state_counts = self.transitions[(*next_state, action)] |
| 90 | + state_action_pair = (*state, action) |
| 91 | + state_counts = self.transitions[state_action_pair] |
96 | 92 |
|
97 | 93 | if not state_counts: |
98 | | - return state |
| 94 | + return numpy.array([]), numpy.array([]) |
99 | 95 |
|
100 | | - states = list(state_counts.keys()) |
| 96 | + states = numpy.array(list(state_counts.keys())) |
101 | 97 | counts = numpy.array(list(state_counts.values())) |
102 | 98 |
|
103 | | - counts = numpy.maximum(counts, 0) |
104 | | - sum = counts.sum() |
105 | | - probabilities = counts / sum |
106 | | - |
107 | | - # values = [] |
108 | | - # for p, s in zip(probabilities, states): |
109 | | - # values.append(self.state_values[s].value) |
110 | | - |
111 | | - index = numpy.random.choice(numpy.arange(len(state_counts)), p=probabilities) |
112 | | - # return states[numpy.argmax(numpy.array(values))] |
113 | | - return states[index] |
| 99 | + probabilities = counts / counts.sum() |
| 100 | + return probabilities, states |
114 | 101 |
|
115 | 102 | def value_model(self, state: numpy.ndarray, action: int) -> float: |
116 | 103 | """ |
|
0 commit comments