@@ -15,46 +15,127 @@ def __init__(self, config, phase, split):
15
15
self ._phase = phase
16
16
self ._split = split
17
17
18
- if self . _split is None :
19
- self ._split = "recordings"
18
+ with open ( '../annotations/annotation_json/step_annotations.json' , 'r' ) as f :
19
+ self ._annotations = json . load ( f )
20
20
21
21
assert self ._phase in ["train" , "val" , "test" ], f"Invalid phase: { self ._phase } "
22
22
self ._features_directory = self ._config .features_directory
23
23
24
- self ._recording_ids_file = f"{ self ._split } _data_split_combined.json"
24
+ if self ._split == 'shuffle' :
25
+ self ._recording_ids_file = f"recordings_combined_splits.json"
26
+ print (f"Loading recording ids from { self ._recording_ids_file } " )
27
+
28
+ with open (f'../er_annotations/{ self ._recording_ids_file } ' , 'r' ) as file :
29
+ self ._recording_ids_json = json .load (file )
30
+
31
+ self ._recording_ids = self ._recording_ids_json ['train' ] + self ._recording_ids_json ['val' ] + self ._recording_ids_json ['test' ]
32
+
33
+ self ._step_dict = {}
34
+ step_index_id = 0
35
+ for recording_id in self ._recording_ids :
36
+ self ._normal_step_dict = {}
37
+ self ._error_step_dict = {}
38
+ normal_index_id = 0
39
+ error_index_id = 0
40
+ # 1. Prepare step_id, list(<start, end>) for the recording_id
41
+ recording_step_dictionary = {}
42
+ for step in self ._annotations [recording_id ]['steps' ]:
43
+ if step ['start_time' ] < 0 or step ['end_time' ] < 0 :
44
+ # Ignore missing steps
45
+ continue
46
+ if recording_step_dictionary .get (step ['step_id' ]) is None :
47
+ recording_step_dictionary [step ['step_id' ]] = []
48
+
49
+ recording_step_dictionary [step ['step_id' ]].append (
50
+ (math .floor (step ['start_time' ]), math .ceil (step ['end_time' ]), step ['has_errors' ]))
51
+
52
+ # 2. Add step start and end time list to the step_dict
53
+ for step_id in recording_step_dictionary .keys ():
54
+ # If the step has errors, add it to the error_step_dict, else add it to the normal_step_dict
55
+ if recording_step_dictionary [step_id ][0 ][2 ]:
56
+ self ._error_step_dict [f'E{ error_index_id } ' ] = (recording_id , recording_step_dictionary [step_id ])
57
+ error_index_id += 1
58
+ else :
59
+ self ._normal_step_dict [f'N{ normal_index_id } ' ] = (
60
+ recording_id , recording_step_dictionary [step_id ])
61
+ normal_index_id += 1
62
+
63
+ np .random .seed (config .seed )
64
+ np .random .shuffle (list (self ._normal_step_dict .keys ()))
65
+ np .random .shuffle (list (self ._error_step_dict .keys ()))
66
+
67
+ normal_step_indices = list (self ._normal_step_dict .keys ())
68
+ error_step_indices = list (self ._error_step_dict .keys ())
69
+
70
+ self ._split_proportion = [0.75 , 0.16 , 0.9 ]
71
+
72
+ num_normal_steps = len (normal_step_indices )
73
+ num_error_steps = len (error_step_indices )
74
+
75
+ self ._split_proportion_normal = [int (num_normal_steps * self ._split_proportion [0 ]),
76
+ int (num_normal_steps * (
77
+ self ._split_proportion [0 ] + self ._split_proportion [1 ]))]
78
+ self ._split_proportion_error = [int (num_error_steps * self ._split_proportion [0 ]),
79
+ int (num_error_steps * (
80
+ self ._split_proportion [0 ] + self ._split_proportion [1 ]))]
81
+
82
+ if phase == 'train' :
83
+ self ._train_normal = normal_step_indices [:self ._split_proportion_normal [0 ]]
84
+ self ._train_error = error_step_indices [:self ._split_proportion_error [0 ]]
85
+ train_indices = self ._train_normal + self ._train_error
86
+ for index_id in train_indices :
87
+ self ._step_dict [step_index_id ] = self ._normal_step_dict .get (index_id ,
88
+ self ._error_step_dict .get (index_id ))
89
+ step_index_id += 1
90
+ elif phase == 'test' :
91
+ self ._val_normal = normal_step_indices [
92
+ self ._split_proportion_normal [0 ]:self ._split_proportion_normal [1 ]]
93
+ self ._val_error = error_step_indices [
94
+ self ._split_proportion_error [0 ]:self ._split_proportion_error [1 ]]
95
+ val_indices = self ._val_normal + self ._val_error
96
+ for index_id in val_indices :
97
+ self ._step_dict [step_index_id ] = self ._normal_step_dict .get (index_id ,
98
+ self ._error_step_dict .get (index_id ))
99
+ step_index_id += 1
100
+ elif phase == 'val' :
101
+ self ._test_normal = normal_step_indices [self ._split_proportion_normal [1 ]:]
102
+ self ._test_error = error_step_indices [self ._split_proportion_error [1 ]:]
103
+ test_indices = self ._test_normal + self ._test_error
104
+ for index_id in test_indices :
105
+ self ._step_dict [step_index_id ] = self ._normal_step_dict .get (index_id ,
106
+ self ._error_step_dict .get (index_id ))
107
+ step_index_id += 1
25
108
26
- print ( f"Loading recording ids from { self . _recording_ids_file } " )
109
+ else :
27
110
28
- with open (f'../annotations/data_splits/{ self ._recording_ids_file } ' , 'r' ) as file :
29
- self ._recording_ids_json = json .load (file )
111
+ self ._recording_ids_file = f"{ self ._split } _combined_splits.json"
112
+
113
+ print (f"Loading recording ids from { self ._recording_ids_file } " )
114
+
115
+ with open (f'../er_annotations/{ self ._recording_ids_file } ' , 'r' ) as file :
116
+ self ._recording_ids_json = json .load (file )
30
117
31
- if self ._phase == 'train' :
32
- self ._recording_ids = self ._recording_ids_json ['train' ] + self ._recording_ids_json ['val' ]
33
- else :
34
118
self ._recording_ids = self ._recording_ids_json [self ._phase ]
35
119
36
- with open ('../annotations/annotation_json/step_annotations.json' , 'r' ) as f :
37
- self ._annotations = json .load (f )
38
-
39
- self ._step_dict = {}
40
- index_id = 0
41
- for recording in self ._recording_ids :
42
- # 1. Prepare step_id, list(<start, end>) for the recording_id
43
- recording_step_dictionary = {}
44
- for step in self ._annotations [recording ]['steps' ]:
45
- if step ['start_time' ] < 0 or step ['end_time' ] < 0 :
46
- # Ignore missing steps
47
- continue
48
- if recording_step_dictionary .get (step ['step_id' ]) is None :
49
- recording_step_dictionary [step ['step_id' ]] = []
50
-
51
- recording_step_dictionary [step ['step_id' ]].append (
52
- (math .floor (step ['start_time' ]), math .ceil (step ['end_time' ]), step ['has_errors' ]))
53
-
54
- # 2. Add step start and end time list to the step_dict
55
- for step_id in recording_step_dictionary .keys ():
56
- self ._step_dict [index_id ] = (recording , recording_step_dictionary [step_id ])
57
- index_id += 1
120
+ self ._step_dict = {}
121
+ index_id = 0
122
+ for recording in self ._recording_ids :
123
+ # 1. Prepare step_id, list(<start, end>) for the recording_id
124
+ recording_step_dictionary = {}
125
+ for step in self ._annotations [recording ]['steps' ]:
126
+ if step ['start_time' ] < 0 or step ['end_time' ] < 0 :
127
+ # Ignore missing steps
128
+ continue
129
+ if recording_step_dictionary .get (step ['step_id' ]) is None :
130
+ recording_step_dictionary [step ['step_id' ]] = []
131
+
132
+ recording_step_dictionary [step ['step_id' ]].append (
133
+ (math .floor (step ['start_time' ]), math .ceil (step ['end_time' ]), step ['has_errors' ]))
134
+
135
+ # 2. Add step start and end time list to the step_dict
136
+ for step_id in recording_step_dictionary .keys ():
137
+ self ._step_dict [index_id ] = (recording , recording_step_dictionary [step_id ])
138
+ index_id += 1
58
139
59
140
def __len__ (self ):
60
141
assert len (self ._step_dict ) > 0 , "No data found in the dataset"
@@ -97,4 +178,4 @@ def collate_fn(batch):
97
178
step_features = torch .cat (step_features , dim = 0 )
98
179
step_labels = torch .cat (step_labels , dim = 0 )
99
180
100
- return step_features , step_labels
181
+ return step_features , step_labels
0 commit comments