Skip to content

Commit d9a3449

Browse files
committed
Saving and loading agent data with pickling.
1 parent e64eb30 commit d9a3449

File tree

3 files changed

+11
-38
lines changed

3 files changed

+11
-38
lines changed

rl/tictactoe/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def learn(builder: AgentBuilder, num_games: int, num_agents: int, policy_filenam
176176
filename = policy_filename
177177
else:
178178
timestamp = time.strftime("%Y%m%d_%H%M%S")
179-
filename = os.getcwd() + "/" + timestamp + ".json"
179+
filename = os.getcwd() + "/" + timestamp + ".pickle"
180180

181181
save_learning_agent(main_agent, filename=filename)
182182

rl/tictactoe/optimal.pickle

15.6 MB
Binary file not shown.

rl/utils/io_utils.py

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
#! /usr/bin/env python3
2-
import json
3-
from collections import defaultdict, Counter
4-
5-
from rl.reprs import Value
2+
import pickle
63

74

85
def load_learning_agent(filename: str):
    """
    Load a learning agent's saved data from a pickle file.

    The file is expected to contain one or more pickled dictionaries, each
    with the keys ``"state_values"`` and ``"transitions"``. A file may hold
    several records back to back (for example when a writer opened it in
    append mode and saved more than once); in that case the last record,
    i.e. the most recently written one, is returned.

    SECURITY: ``pickle.load`` can execute arbitrary code while
    deserialising. Only load files that come from a trusted source.

    :param filename: The name of the pickle file to read from
    :return: A ``(state_values, transitions)`` tuple taken from the most
        recent record in the file
    :raises EOFError: If the file contains no pickled record at all
    """

    data = None
    with open(filename, "rb") as f:
        # peek() returns b"" only at a true end of file, so a record that is
        # truncated part-way still surfaces as an error from pickle.load
        # instead of being silently mistaken for the end of the stream.
        while f.peek(1):
            data = pickle.load(f)

    if data is None:
        # Same exception type and message pickle.load gives for an empty file.
        raise EOFError("Ran out of input")

    return data["state_values"], data["transitions"]
3416

3517

3618
def save_learning_agent(agent, filename: str):
    """
    Save a learning agent's data to a pickle file.

    The agent's ``state_values`` and ``transitions`` attributes are stored
    together in a single pickled dictionary under the keys
    ``"state_values"`` and ``"transitions"``.

    :param agent: The agent whose ``state_values`` and ``transitions`` are
        saved
    :param filename: The name of the file to write to; any existing file
        with this name is overwritten
    """

    data = {
        "state_values": agent.state_values,
        "transitions": agent.transitions,
    }

    # Open with "wb" (truncate), not "ab": appending would stack a new pickle
    # record behind the old ones on every save, so the file would grow without
    # bound and a reader calling pickle.load once would keep getting the
    # oldest snapshot rather than the one just written.
    with open(filename, "wb") as f:
        pickle.dump(data, f)

0 commit comments

Comments
 (0)