Skip to content

Commit 1027546

Browse files
committed
black
1 parent 0c6283b commit 1027546

File tree

1 file changed

+56
-46
lines changed

1 file changed

+56
-46
lines changed

openadapt/strategies/stateful.py

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from copy import deepcopy
1010
from pprint import pformat
11-
#import datetime
11+
12+
# import datetime
1213

1314
from loguru import logger
1415
import deepdiff
@@ -25,7 +26,6 @@ class StatefulReplayStrategy(
2526
OpenAIReplayStrategyMixin,
2627
strategies.base.BaseReplayStrategy,
2728
):
28-
2929
def __init__(
3030
self,
3131
recording: models.Recording,
@@ -39,8 +39,7 @@ def __init__(
3939
for action_event in self.recording.processed_action_events
4040
][:-1]
4141
self.recording_action_diff_tups = zip(
42-
self.recording_window_state_diffs,
43-
self.recording_action_strs
42+
self.recording_window_state_diffs, self.recording_action_strs
4443
)
4544
self.recording_action_idx = 0
4645

@@ -52,53 +51,63 @@ def get_next_action_event(
5251
logger.debug(f"{self.recording_action_idx=}")
5352
if self.recording_action_idx == len(self.recording.processed_action_events):
5453
raise StopIteration()
55-
reference_action = (
56-
self.recording.processed_action_events[self.recording_action_idx]
57-
)
54+
reference_action = self.recording.processed_action_events[
55+
self.recording_action_idx
56+
]
5857
reference_window = reference_action.window_event
5958

60-
reference_window_dict = deepcopy({
61-
key: val
62-
for key, val in utils.row2dict(reference_window, follow=False).items()
63-
if val is not None
64-
and not key.endswith("timestamp")
65-
and not key.endswith("id")
66-
#and not isinstance(getattr(models.WindowEvent, key), property)
67-
})
59+
reference_window_dict = deepcopy(
60+
{
61+
key: val
62+
for key, val in utils.row2dict(reference_window, follow=False).items()
63+
if val is not None
64+
and not key.endswith("timestamp")
65+
and not key.endswith("id")
66+
# and not isinstance(getattr(models.WindowEvent, key), property)
67+
}
68+
)
6869
if reference_action.children:
6970
reference_action_dicts = [
70-
deepcopy({
71-
key: val
72-
for key, val in utils.row2dict(child, follow=False).items()
73-
if val is not None
74-
and not key.endswith("timestamp")
75-
and not key.endswith("id")
76-
and not isinstance(getattr(models.ActionEvent, key), property)
77-
})
71+
deepcopy(
72+
{
73+
key: val
74+
for key, val in utils.row2dict(child, follow=False).items()
75+
if val is not None
76+
and not key.endswith("timestamp")
77+
and not key.endswith("id")
78+
and not isinstance(getattr(models.ActionEvent, key), property)
79+
}
80+
)
7881
for child in reference_action.children
7982
]
8083
else:
8184
reference_action_dicts = [
82-
deepcopy({
83-
key: val
84-
for key, val in utils.row2dict(reference_action, follow=False).items()
85-
if val is not None
86-
and not key.endswith("timestamp")
87-
and not key.endswith("id")
88-
#and not isinstance(getattr(models.ActionEvent, key), property)
89-
})
85+
deepcopy(
86+
{
87+
key: val
88+
for key, val in utils.row2dict(
89+
reference_action, follow=False
90+
).items()
91+
if val is not None
92+
and not key.endswith("timestamp")
93+
and not key.endswith("id")
94+
# and not isinstance(getattr(models.ActionEvent, key), property)
95+
}
96+
)
9097
]
91-
active_window_dict = deepcopy({
92-
key: val
93-
for key, val in utils.row2dict(active_window, follow=False).items()
94-
if val is not None
95-
and not key.endswith("timestamp")
96-
and not key.endswith("id")
97-
#and not isinstance(getattr(models.WindowEvent, key), property)
98-
})
98+
active_window_dict = deepcopy(
99+
{
100+
key: val
101+
for key, val in utils.row2dict(active_window, follow=False).items()
102+
if val is not None
103+
and not key.endswith("timestamp")
104+
and not key.endswith("id")
105+
# and not isinstance(getattr(models.WindowEvent, key), property)
106+
}
107+
)
99108
if reference_window_dict and "state" in reference_window_dict:
100109
reference_window_dict["state"].pop("data")
101-
if active_window_dict and "state" in active_window_dict :
110+
if active_window_dict and "state" in active_window_dict:
102111
active_window_dict["state"].pop("data")
103112

104113
prompt = (
@@ -145,16 +154,16 @@ def get_window_state_diffs(
145154
ignore_window_ids = set()
146155
if ignore_boundary_windows:
147156
first_window_event = action_events[0].window_event
148-
if first_window_event.state :
157+
if first_window_event.state:
149158
first_window_id = first_window_event.state["window_id"]
150-
else :
159+
else:
151160
first_window_id = None
152161
first_window_title = first_window_event.title
153162
last_window_event = action_events[-1].window_event
154-
if last_window_event.state :
163+
if last_window_event.state:
155164
last_window_id = last_window_event.state["window_id"]
156165
last_window_title = last_window_event.title
157-
else :
166+
else:
158167
last_window_id = None
159168
last_window_title = last_window_event.title
160169
if first_window_id != last_window_id:
@@ -164,7 +173,8 @@ def get_window_state_diffs(
164173
logger.info(f"ignoring {first_window_title=} {last_window_title=}")
165174
window_event_states = [
166175
action_event.window_event.state
167-
if action_event.window_event.state is not None and action_event.window_event.state["window_id"] not in ignore_window_ids
176+
if action_event.window_event.state is not None
177+
and action_event.window_event.state["window_id"] not in ignore_window_ids
168178
else {}
169179
for action_event in action_events
170180
]
@@ -174,4 +184,4 @@ def get_window_state_diffs(
174184
window_event_states, window_event_states[1:]
175185
)
176186
]
177-
return diffs
187+
return diffs

0 commit comments

Comments
 (0)