@@ -35,7 +35,7 @@ def check_environment_reset(env_reset):
35
35
data = []
36
36
start , end , deadline = '' , '' , ''
37
37
if env_reset .startswith ('Environment.reset()' ):
38
- data = re_findall ('[\w.]+' , env_reset )
38
+ data = re_findall ('[- \w.]+' , env_reset )
39
39
for i , each in enumerate (data ):
40
40
if each == 'start' :
41
41
start = map (int , [data [i + 1 ], data [i + 2 ]])
@@ -49,15 +49,15 @@ def check_learning_update(learning_update):
49
49
deadline , expected_reward , reward = 0 , 0 , 0
50
50
data = []
51
51
if learning_update .startswith ('LearningAgent.update()' ):
52
- data = re_findall ('[\w.]+' , learning_update )
52
+ data = re_findall ('[- \w.]+' , learning_update )
53
53
for i , each in enumerate (data ):
54
54
if each == 'deadline' :
55
55
deadline = data [i + 1 ]
56
56
elif each == 'expected_reward' :
57
57
expected_reward = data [i + 1 ]
58
58
elif each == 'reward' :
59
59
reward = data [i + 1 ]
60
- return deadline , expected_reward , reward
60
+ return map ( float , [ deadline , expected_reward , reward ])
61
61
62
62
63
63
def check_learning_update_old (learning_update ):
@@ -136,16 +136,18 @@ def success_check(data):
136
136
return ret_dict
137
137
138
138
139
- def total_stats (filename = FILE ):
139
+ def total_stats (filename = FILE , return_dict = False ):
140
140
data = fetch_data (filename )
141
141
game = success_check (data )
142
142
game_stats = pd .DataFrame .from_dict (game )
143
+ if return_dict :
144
+ return game_stats
143
145
144
146
for col in [u'all_deadlines' , u'all_expected_rewards' , u'all_trails' ,
145
147
u'all_main_deadline' , u'all_outcomes' , u'all_rewards' , # u'all_start', u'all_destinations'
146
148
]:
147
149
game_stats [col ] = pd .to_numeric (game_stats [col ])
148
-
150
+
149
151
game_stats ['Q_pred' ] = game_stats .all_expected_rewards - game_stats .all_outcomes
150
152
game_stats ['steps' ] = game_stats .all_main_deadline - game_stats .all_deadlines
151
153
game_stats ['avg_steps' ] = game_stats .all_rewards / game_stats .steps
0 commit comments