Skip to content

Commit 1d62b49

Browse files
committed
Updated Dataloader for all 4 splits
1 parent e8cda08 commit 1d62b49

File tree

3 files changed

+118
-37
lines changed

dataloader/CaptainCookStepDataset.py

Lines changed: 113 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,127 @@ def __init__(self, config, phase, split):
1515
self._phase = phase
1616
self._split = split
1717

18-
if self._split is None:
19-
self._split = "recordings"
18+
with open('../annotations/annotation_json/step_annotations.json', 'r') as f:
19+
self._annotations = json.load(f)
2020

2121
assert self._phase in ["train", "val", "test"], f"Invalid phase: {self._phase}"
2222
self._features_directory = self._config.features_directory
2323

24-
self._recording_ids_file = f"{self._split}_data_split_combined.json"
24+
if self._split == 'shuffle':
25+
self._recording_ids_file = f"recordings_combined_splits.json"
26+
print(f"Loading recording ids from {self._recording_ids_file}")
27+
28+
with open(f'../er_annotations/{self._recording_ids_file}', 'r') as file:
29+
self._recording_ids_json = json.load(file)
30+
31+
self._recording_ids = self._recording_ids_json['train'] + self._recording_ids_json['val'] + self._recording_ids_json['test']
32+
33+
self._step_dict = {}
34+
step_index_id = 0
35+
for recording_id in self._recording_ids:
36+
self._normal_step_dict = {}
37+
self._error_step_dict = {}
38+
normal_index_id = 0
39+
error_index_id = 0
40+
# 1. Prepare step_id, list(<start, end>) for the recording_id
41+
recording_step_dictionary = {}
42+
for step in self._annotations[recording_id]['steps']:
43+
if step['start_time'] < 0 or step['end_time'] < 0:
44+
# Ignore missing steps
45+
continue
46+
if recording_step_dictionary.get(step['step_id']) is None:
47+
recording_step_dictionary[step['step_id']] = []
48+
49+
recording_step_dictionary[step['step_id']].append(
50+
(math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))
51+
52+
# 2. Add step start and end time list to the step_dict
53+
for step_id in recording_step_dictionary.keys():
54+
# If the step has errors, add it to the error_step_dict, else add it to the normal_step_dict
55+
if recording_step_dictionary[step_id][0][2]:
56+
self._error_step_dict[f'E{error_index_id}'] = (recording_id, recording_step_dictionary[step_id])
57+
error_index_id += 1
58+
else:
59+
self._normal_step_dict[f'N{normal_index_id}'] = (
60+
recording_id, recording_step_dictionary[step_id])
61+
normal_index_id += 1
62+
63+
np.random.seed(config.seed)
64+
np.random.shuffle(list(self._normal_step_dict.keys()))
65+
np.random.shuffle(list(self._error_step_dict.keys()))
66+
67+
normal_step_indices = list(self._normal_step_dict.keys())
68+
error_step_indices = list(self._error_step_dict.keys())
69+
70+
self._split_proportion = [0.75, 0.16, 0.9]
71+
72+
num_normal_steps = len(normal_step_indices)
73+
num_error_steps = len(error_step_indices)
74+
75+
self._split_proportion_normal = [int(num_normal_steps * self._split_proportion[0]),
76+
int(num_normal_steps * (
77+
self._split_proportion[0] + self._split_proportion[1]))]
78+
self._split_proportion_error = [int(num_error_steps * self._split_proportion[0]),
79+
int(num_error_steps * (
80+
self._split_proportion[0] + self._split_proportion[1]))]
81+
82+
if phase == 'train':
83+
self._train_normal = normal_step_indices[:self._split_proportion_normal[0]]
84+
self._train_error = error_step_indices[:self._split_proportion_error[0]]
85+
train_indices = self._train_normal + self._train_error
86+
for index_id in train_indices:
87+
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
88+
self._error_step_dict.get(index_id))
89+
step_index_id += 1
90+
elif phase == 'test':
91+
self._val_normal = normal_step_indices[
92+
self._split_proportion_normal[0]:self._split_proportion_normal[1]]
93+
self._val_error = error_step_indices[
94+
self._split_proportion_error[0]:self._split_proportion_error[1]]
95+
val_indices = self._val_normal + self._val_error
96+
for index_id in val_indices:
97+
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
98+
self._error_step_dict.get(index_id))
99+
step_index_id += 1
100+
elif phase == 'val':
101+
self._test_normal = normal_step_indices[self._split_proportion_normal[1]:]
102+
self._test_error = error_step_indices[self._split_proportion_error[1]:]
103+
test_indices = self._test_normal + self._test_error
104+
for index_id in test_indices:
105+
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
106+
self._error_step_dict.get(index_id))
107+
step_index_id += 1
25108

26-
print(f"Loading recording ids from {self._recording_ids_file}")
109+
else:
27110

28-
with open(f'../annotations/data_splits/{self._recording_ids_file}', 'r') as file:
29-
self._recording_ids_json = json.load(file)
111+
self._recording_ids_file = f"{self._split}_combined_splits.json"
112+
113+
print(f"Loading recording ids from {self._recording_ids_file}")
114+
115+
with open(f'../er_annotations/{self._recording_ids_file}', 'r') as file:
116+
self._recording_ids_json = json.load(file)
30117

31-
if self._phase == 'train':
32-
self._recording_ids = self._recording_ids_json['train'] + self._recording_ids_json['val']
33-
else:
34118
self._recording_ids = self._recording_ids_json[self._phase]
35119

36-
with open('../annotations/annotation_json/step_annotations.json', 'r') as f:
37-
self._annotations = json.load(f)
38-
39-
self._step_dict = {}
40-
index_id = 0
41-
for recording in self._recording_ids:
42-
# 1. Prepare step_id, list(<start, end>) for the recording_id
43-
recording_step_dictionary = {}
44-
for step in self._annotations[recording]['steps']:
45-
if step['start_time'] < 0 or step['end_time'] < 0:
46-
# Ignore missing steps
47-
continue
48-
if recording_step_dictionary.get(step['step_id']) is None:
49-
recording_step_dictionary[step['step_id']] = []
50-
51-
recording_step_dictionary[step['step_id']].append(
52-
(math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))
53-
54-
# 2. Add step start and end time list to the step_dict
55-
for step_id in recording_step_dictionary.keys():
56-
self._step_dict[index_id] = (recording, recording_step_dictionary[step_id])
57-
index_id += 1
120+
self._step_dict = {}
121+
index_id = 0
122+
for recording in self._recording_ids:
123+
# 1. Prepare step_id, list(<start, end>) for the recording_id
124+
recording_step_dictionary = {}
125+
for step in self._annotations[recording]['steps']:
126+
if step['start_time'] < 0 or step['end_time'] < 0:
127+
# Ignore missing steps
128+
continue
129+
if recording_step_dictionary.get(step['step_id']) is None:
130+
recording_step_dictionary[step['step_id']] = []
131+
132+
recording_step_dictionary[step['step_id']].append(
133+
(math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))
134+
135+
# 2. Add step start and end time list to the step_dict
136+
for step_id in recording_step_dictionary.keys():
137+
self._step_dict[index_id] = (recording, recording_step_dictionary[step_id])
138+
index_id += 1
58139

59140
def __len__(self):
60141
assert len(self._step_dict) > 0, "No data found in the dataset"
@@ -97,4 +178,4 @@ def collate_fn(batch):
97178
step_features = torch.cat(step_features, dim=0)
98179
step_labels = torch.cat(step_labels, dim=0)
99180

100-
return step_features, step_labels
181+
return step_features, step_labels

train_er.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,18 @@ def train_step_test_step_er(config):
162162
# val_dataset = CaptainCookStepDataset(config, const.TEST, config.split)
163163
# val_loader = DataLoader(val_dataset, collate_fn=collate_fn, **test_kwargs)
164164

165-
train_dataset = CaptainCookStepShuffleDataset(config, const.TRAIN)
165+
train_dataset = CaptainCookStepDataset(config, const.TRAIN, config.split)
166166
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, **train_kwargs)
167-
val_dataset = CaptainCookStepShuffleDataset(config, const.VAL)
167+
val_dataset = CaptainCookStepDataset(config, const.VAL, config.split)
168168
val_loader = DataLoader(val_dataset, collate_fn=collate_fn, **test_kwargs)
169-
test_dataset = CaptainCookStepShuffleDataset(config, const.TEST)
169+
test_dataset = CaptainCookStepDataset(config, const.TEST, config.split)
170170
test_loader = DataLoader(test_dataset, collate_fn=collate_fn, **test_kwargs)
171171

172172
train_er_model(train_loader, val_loader, device, config, test_loader=test_loader)
173173

174174

175175
if __name__ == "__main__":
176176
conf = Config()
177-
init_logger_and_wandb(conf)
177+
# init_logger_and_wandb(conf)
178178
train_step_test_step_er(conf)
179-
wandb.finish()
179+
# wandb.finish()

0 commit comments

Comments (0)