5
5
import numpy as np
6
6
import torch
7
7
from torch .utils .data import Dataset
8
+ from constants import Constants as const
8
9
9
10
10
11
class CaptainCookStepDataset (Dataset ):
@@ -18,124 +19,133 @@ def __init__(self, config, phase, split):
18
19
with open ('../annotations/annotation_json/step_annotations.json' , 'r' ) as f :
19
20
self ._annotations = json .load (f )
20
21
22
+ print ("Loaded annotations...... " )
23
+
21
24
assert self ._phase in ["train" , "val" , "test" ], f"Invalid phase: { self ._phase } "
22
- self ._features_directory = self ._config .features_directory
23
25
24
- if self ._split == 'shuffle' :
25
- self ._recording_ids_file = f"recordings_combined_splits.json"
26
- print (f"Loading recording ids from { self ._recording_ids_file } " )
27
-
28
- with open (f'../er_annotations/{ self ._recording_ids_file } ' , 'r' ) as file :
29
- self ._recording_ids_json = json .load (file )
30
-
31
- self ._recording_ids = self ._recording_ids_json ['train' ] + self ._recording_ids_json ['val' ] + self ._recording_ids_json ['test' ]
32
-
33
- self ._step_dict = {}
34
- step_index_id = 0
35
- for recording_id in self ._recording_ids :
36
- self ._normal_step_dict = {}
37
- self ._error_step_dict = {}
38
- normal_index_id = 0
39
- error_index_id = 0
40
- # 1. Prepare step_id, list(<start, end>) for the recording_id
41
- recording_step_dictionary = {}
42
- for step in self ._annotations [recording_id ]['steps' ]:
43
- if step ['start_time' ] < 0 or step ['end_time' ] < 0 :
44
- # Ignore missing steps
45
- continue
46
- if recording_step_dictionary .get (step ['step_id' ]) is None :
47
- recording_step_dictionary [step ['step_id' ]] = []
48
-
49
- recording_step_dictionary [step ['step_id' ]].append (
50
- (math .floor (step ['start_time' ]), math .ceil (step ['end_time' ]), step ['has_errors' ]))
51
-
52
- # 2. Add step start and end time list to the step_dict
53
- for step_id in recording_step_dictionary .keys ():
54
- # If the step has errors, add it to the error_step_dict, else add it to the normal_step_dict
55
- if recording_step_dictionary [step_id ][0 ][2 ]:
56
- self ._error_step_dict [f'E{ error_index_id } ' ] = (recording_id , recording_step_dictionary [step_id ])
57
- error_index_id += 1
58
- else :
59
- self ._normal_step_dict [f'N{ normal_index_id } ' ] = (
60
- recording_id , recording_step_dictionary [step_id ])
61
- normal_index_id += 1
62
-
63
- np .random .seed (config .seed )
64
- np .random .shuffle (list (self ._normal_step_dict .keys ()))
65
- np .random .shuffle (list (self ._error_step_dict .keys ()))
66
-
67
- normal_step_indices = list (self ._normal_step_dict .keys ())
68
- error_step_indices = list (self ._error_step_dict .keys ())
69
-
70
- self ._split_proportion = [0.75 , 0.16 , 0.9 ]
71
-
72
- num_normal_steps = len (normal_step_indices )
73
- num_error_steps = len (error_step_indices )
74
-
75
- self ._split_proportion_normal = [int (num_normal_steps * self ._split_proportion [0 ]),
76
- int (num_normal_steps * (
77
- self ._split_proportion [0 ] + self ._split_proportion [1 ]))]
78
- self ._split_proportion_error = [int (num_error_steps * self ._split_proportion [0 ]),
79
- int (num_error_steps * (
80
- self ._split_proportion [0 ] + self ._split_proportion [1 ]))]
81
-
82
- if phase == 'train' :
83
- self ._train_normal = normal_step_indices [:self ._split_proportion_normal [0 ]]
84
- self ._train_error = error_step_indices [:self ._split_proportion_error [0 ]]
85
- train_indices = self ._train_normal + self ._train_error
86
- for index_id in train_indices :
87
- self ._step_dict [step_index_id ] = self ._normal_step_dict .get (index_id ,
88
- self ._error_step_dict .get (index_id ))
89
- step_index_id += 1
90
- elif phase == 'test' :
91
- self ._val_normal = normal_step_indices [
92
- self ._split_proportion_normal [0 ]:self ._split_proportion_normal [1 ]]
93
- self ._val_error = error_step_indices [
94
- self ._split_proportion_error [0 ]:self ._split_proportion_error [1 ]]
95
- val_indices = self ._val_normal + self ._val_error
96
- for index_id in val_indices :
97
- self ._step_dict [step_index_id ] = self ._normal_step_dict .get (index_id ,
98
- self ._error_step_dict .get (index_id ))
99
- step_index_id += 1
100
- elif phase == 'val' :
101
- self ._test_normal = normal_step_indices [self ._split_proportion_normal [1 ]:]
102
- self ._test_error = error_step_indices [self ._split_proportion_error [1 ]:]
103
- test_indices = self ._test_normal + self ._test_error
104
- for index_id in test_indices :
105
- self ._step_dict [step_index_id ] = self ._normal_step_dict .get (index_id ,
106
- self ._error_step_dict .get (index_id ))
107
- step_index_id += 1
26
+ self ._features_directory = self ._config .features_directory
108
27
28
+ if self ._split == const .STEP_SPLIT :
29
+ self ._init_step_split (config , phase )
109
30
else :
110
-
111
- self ._recording_ids_file = f"{ self ._split } _combined_splits.json"
112
-
113
- print (f"Loading recording ids from { self ._recording_ids_file } " )
114
-
115
- with open (f'../er_annotations/{ self ._recording_ids_file } ' , 'r' ) as file :
116
- self ._recording_ids_json = json .load (file )
117
-
118
- self ._recording_ids = self ._recording_ids_json [self ._phase ]
119
-
120
- self ._step_dict = {}
121
- index_id = 0
122
- for recording in self ._recording_ids :
123
- # 1. Prepare step_id, list(<start, end>) for the recording_id
124
- recording_step_dictionary = {}
125
- for step in self ._annotations [recording ]['steps' ]:
126
- if step ['start_time' ] < 0 or step ['end_time' ] < 0 :
127
- # Ignore missing steps
128
- continue
129
- if recording_step_dictionary .get (step ['step_id' ]) is None :
130
- recording_step_dictionary [step ['step_id' ]] = []
131
-
132
- recording_step_dictionary [step ['step_id' ]].append (
133
- (math .floor (step ['start_time' ]), math .ceil (step ['end_time' ]), step ['has_errors' ]))
134
-
135
- # 2. Add step start and end time list to the step_dict
136
- for step_id in recording_step_dictionary .keys ():
137
- self ._step_dict [index_id ] = (recording , recording_step_dictionary [step_id ])
138
- index_id += 1
31
+ self ._init_other_split_from_file (config , phase )
32
+
33
def _init_step_split(self, config, phase):
    """Populate ``self._step_dict`` for the step-level split.

    Loads ``recordings_combined_splits.json``, pools its ``'train'``,
    ``'val'`` and ``'test'`` recording ids, and for each recording groups
    the annotated steps by ``step_id``. A group is treated as erroneous when
    the ``has_errors`` flag of its first occurrence is truthy, otherwise as
    normal. The normal and erroneous groups are shuffled (global NumPy RNG
    seeded with ``config.seed``) and sliced separately; the slice selected by
    ``phase`` is appended to ``self._step_dict`` under consecutive integer
    keys as ``(recording_id, [(start, end, has_errors), ...])``.

    Args:
        config: object exposing ``seed`` (passed to ``np.random.seed``).
        phase: ``'train'``, ``'val'`` or ``'test'``. Any other value adds
            nothing to ``self._step_dict``.

    Raises:
        OSError: if the split file cannot be opened.
        KeyError: if a recording id is missing from ``self._annotations`` or
            the split file lacks one of the three phase keys.
    """
    self._recording_ids_file = "recordings_combined_splits.json"
    print(f"Loading recording ids from {self._recording_ids_file}")
    # NOTE(review): machine-specific absolute path. A path relative to this
    # module, e.g. os.path.join(os.path.dirname(__file__), '../er_annotations', ...),
    # was tried and commented out by the author - confirm before switching.
    annotations_file_path = f"/home/rxp190007/CODE/error_recognition/er_annotations/{self._recording_ids_file}"
    with open(annotations_file_path, 'r') as file:
        self._recording_ids_json = json.load(file)

    self._recording_ids = (
        self._recording_ids_json['train']
        + self._recording_ids_json['val']
        + self._recording_ids_json['test']
    )

    # Only the first two entries are read below.
    # NOTE(review): the third entry (0.9) is unused and the three values do
    # not sum to 1 - presumably 0.09 was intended; TODO confirm.
    self._split_proportion = [0.75, 0.16, 0.9]
    first_cut_fraction = self._split_proportion[0]
    second_cut_fraction = self._split_proportion[0] + self._split_proportion[1]

    self._step_dict = {}
    step_index_id = 0
    # NOTE(review): the split is applied once per recording, which is implied
    # by the normal/error dictionaries being reset for every recording -
    # confirm this nesting against the original file.
    for recording_id in self._recording_ids:
        self._normal_step_dict = {}
        self._error_step_dict = {}
        normal_index_id = 0
        error_index_id = 0

        # 1. Group (start, end, has_errors) windows by step_id.
        recording_step_dictionary = {}
        for step in self._annotations[recording_id]['steps']:
            if step['start_time'] < 0 or step['end_time'] < 0:
                # Negative timestamps mark missing steps; ignore them.
                continue
            recording_step_dictionary.setdefault(step['step_id'], []).append(
                (math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))

        # 2. Route each step group to the error or normal dictionary based on
        #    the has_errors flag of its first window.
        for step_windows in recording_step_dictionary.values():
            if step_windows[0][2]:
                self._error_step_dict[f'E{error_index_id}'] = (recording_id, step_windows)
                error_index_id += 1
            else:
                self._normal_step_dict[f'N{normal_index_id}'] = (recording_id, step_windows)
                normal_index_id += 1

        # Materialise the key lists BEFORE shuffling: np.random.shuffle works
        # in place, so shuffling a throw-away list(...) would leave the order
        # used for slicing untouched.
        normal_step_indices = list(self._normal_step_dict)
        error_step_indices = list(self._error_step_dict)
        np.random.seed(config.seed)
        np.random.shuffle(normal_step_indices)
        np.random.shuffle(error_step_indices)

        num_normal_steps = len(normal_step_indices)
        num_error_steps = len(error_step_indices)

        # Cut points [end of first slice, end of second slice] per class.
        self._split_proportion_normal = [int(num_normal_steps * first_cut_fraction),
                                         int(num_normal_steps * second_cut_fraction)]
        self._split_proportion_error = [int(num_error_steps * first_cut_fraction),
                                        int(num_error_steps * second_cut_fraction)]
        normal_first, normal_second = self._split_proportion_normal
        error_first, error_second = self._split_proportion_error

        # NOTE(review): phase 'test' takes the middle slice (stored in the
        # `_val_*` attributes) and phase 'val' takes the tail slice (stored in
        # the `_test_*` attributes) - confirm this mapping is intended.
        if phase == 'train':
            self._train_normal = normal_step_indices[:normal_first]
            self._train_error = error_step_indices[:error_first]
            selected_indices = self._train_normal + self._train_error
        elif phase == 'test':
            self._val_normal = normal_step_indices[normal_first:normal_second]
            self._val_error = error_step_indices[error_first:error_second]
            selected_indices = self._val_normal + self._val_error
        elif phase == 'val':
            self._test_normal = normal_step_indices[normal_second:]
            self._test_error = error_step_indices[error_second:]
            selected_indices = self._test_normal + self._test_error
        else:
            selected_indices = []

        for index_id in selected_indices:
            # 'N…' keys live in the normal dict, 'E…' keys in the error dict.
            self._step_dict[step_index_id] = self._normal_step_dict.get(
                index_id, self._error_step_dict.get(index_id))
            step_index_id += 1
120
+
121
def _init_other_split_from_file(self, config, phase):
    """Populate ``self._step_dict`` from a ``<split>_combined_splits.json`` file.

    Reads the recording ids listed under ``phase`` in the split file named
    after ``self._split``, then, for each recording, groups its annotated
    steps by ``step_id`` and stores every group in ``self._step_dict`` under
    consecutive integer keys as
    ``(recording, [(start, end, has_errors), ...])``.

    Args:
        config: not read by this method.
        phase: key of the split file whose recording ids are loaded.
    """
    self._recording_ids_file = f"{self._split}_combined_splits.json"
    # NOTE(review): machine-specific absolute path. A path relative to this
    # module (os.path.dirname(__file__) + '../er_annotations/') was tried and
    # commented out by the author - confirm before switching.
    annotations_file_path = f"/home/rxp190007/CODE/error_recognition/er_annotations/{self._recording_ids_file}"
    print(f"Loading recording ids from {self._recording_ids_file}")
    with open(annotations_file_path, 'r') as file:
        self._recording_ids_json = json.load(file)

    self._recording_ids = self._recording_ids_json[phase]
    self._step_dict = {}
    for recording in self._recording_ids:
        # Collect the time windows of this recording, keyed by step_id.
        windows_by_step = {}
        for step in self._annotations[recording]['steps']:
            if step['start_time'] < 0 or step['end_time'] < 0:
                # Negative timestamps mark missing steps; ignore them.
                continue
            window = (
                math.floor(step['start_time']),
                math.ceil(step['end_time']),
                step['has_errors'],
            )
            windows_by_step.setdefault(step['step_id'], []).append(window)

        # Keys are consecutive integers starting at 0, so the current size of
        # the dictionary is always the next free index.
        for windows in windows_by_step.values():
            self._step_dict[len(self._step_dict)] = (recording, windows)
139
149
140
150
def __len__ (self ):
141
151
assert len (self ._step_dict ) > 0 , "No data found in the dataset"
0 commit comments