Skip to content

Commit e395a75

Browse files
committed
Run scripts and dataloader
1 parent 1d62b49 commit e395a75

File tree

10 files changed

+235
-309
lines changed

10 files changed

+235
-309
lines changed

core/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def setup_parser(self):
6767
parser.add_argument('--features_directory', type=str, default='/data/rohith/captain_cook/features/gopro'
6868
'/segments', help='features directory')
6969
parser.add_argument('--ckpt_directory', type=str, default='/data/rohith/captain_cook/checkpoints', help='checkpoint directory')
70-
parser.add_argument('--split', type=str, default='recordings', help='split')
70+
parser.add_argument('--split', type=str, default=const.RECORDINGS_SPLIT, help='split')
7171
parser.add_argument('--variant', type=str, default=const.MLP_VARIANT, help='variant')
7272
parser.add_argument('--model_name', type=str, default=None, help='model name')
7373
parser.add_argument('--task_name', type=str, default=const.ERROR_RECOGNITION, help='task name')

dataloader/CaptainCookStepDataset.py

Lines changed: 124 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import torch
77
from torch.utils.data import Dataset
8+
from constants import Constants as const
89

910

1011
class CaptainCookStepDataset(Dataset):
@@ -18,124 +19,133 @@ def __init__(self, config, phase, split):
1819
with open('../annotations/annotation_json/step_annotations.json', 'r') as f:
1920
self._annotations = json.load(f)
2021

22+
print("Loaded annotations...... ")
23+
2124
assert self._phase in ["train", "val", "test"], f"Invalid phase: {self._phase}"
22-
self._features_directory = self._config.features_directory
2325

24-
if self._split == 'shuffle':
25-
self._recording_ids_file = f"recordings_combined_splits.json"
26-
print(f"Loading recording ids from {self._recording_ids_file}")
27-
28-
with open(f'../er_annotations/{self._recording_ids_file}', 'r') as file:
29-
self._recording_ids_json = json.load(file)
30-
31-
self._recording_ids = self._recording_ids_json['train'] + self._recording_ids_json['val'] + self._recording_ids_json['test']
32-
33-
self._step_dict = {}
34-
step_index_id = 0
35-
for recording_id in self._recording_ids:
36-
self._normal_step_dict = {}
37-
self._error_step_dict = {}
38-
normal_index_id = 0
39-
error_index_id = 0
40-
# 1. Prepare step_id, list(<start, end>) for the recording_id
41-
recording_step_dictionary = {}
42-
for step in self._annotations[recording_id]['steps']:
43-
if step['start_time'] < 0 or step['end_time'] < 0:
44-
# Ignore missing steps
45-
continue
46-
if recording_step_dictionary.get(step['step_id']) is None:
47-
recording_step_dictionary[step['step_id']] = []
48-
49-
recording_step_dictionary[step['step_id']].append(
50-
(math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))
51-
52-
# 2. Add step start and end time list to the step_dict
53-
for step_id in recording_step_dictionary.keys():
54-
# If the step has errors, add it to the error_step_dict, else add it to the normal_step_dict
55-
if recording_step_dictionary[step_id][0][2]:
56-
self._error_step_dict[f'E{error_index_id}'] = (recording_id, recording_step_dictionary[step_id])
57-
error_index_id += 1
58-
else:
59-
self._normal_step_dict[f'N{normal_index_id}'] = (
60-
recording_id, recording_step_dictionary[step_id])
61-
normal_index_id += 1
62-
63-
np.random.seed(config.seed)
64-
np.random.shuffle(list(self._normal_step_dict.keys()))
65-
np.random.shuffle(list(self._error_step_dict.keys()))
66-
67-
normal_step_indices = list(self._normal_step_dict.keys())
68-
error_step_indices = list(self._error_step_dict.keys())
69-
70-
self._split_proportion = [0.75, 0.16, 0.9]
71-
72-
num_normal_steps = len(normal_step_indices)
73-
num_error_steps = len(error_step_indices)
74-
75-
self._split_proportion_normal = [int(num_normal_steps * self._split_proportion[0]),
76-
int(num_normal_steps * (
77-
self._split_proportion[0] + self._split_proportion[1]))]
78-
self._split_proportion_error = [int(num_error_steps * self._split_proportion[0]),
79-
int(num_error_steps * (
80-
self._split_proportion[0] + self._split_proportion[1]))]
81-
82-
if phase == 'train':
83-
self._train_normal = normal_step_indices[:self._split_proportion_normal[0]]
84-
self._train_error = error_step_indices[:self._split_proportion_error[0]]
85-
train_indices = self._train_normal + self._train_error
86-
for index_id in train_indices:
87-
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
88-
self._error_step_dict.get(index_id))
89-
step_index_id += 1
90-
elif phase == 'test':
91-
self._val_normal = normal_step_indices[
92-
self._split_proportion_normal[0]:self._split_proportion_normal[1]]
93-
self._val_error = error_step_indices[
94-
self._split_proportion_error[0]:self._split_proportion_error[1]]
95-
val_indices = self._val_normal + self._val_error
96-
for index_id in val_indices:
97-
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
98-
self._error_step_dict.get(index_id))
99-
step_index_id += 1
100-
elif phase == 'val':
101-
self._test_normal = normal_step_indices[self._split_proportion_normal[1]:]
102-
self._test_error = error_step_indices[self._split_proportion_error[1]:]
103-
test_indices = self._test_normal + self._test_error
104-
for index_id in test_indices:
105-
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
106-
self._error_step_dict.get(index_id))
107-
step_index_id += 1
26+
self._features_directory = self._config.features_directory
10827

28+
if self._split == const.STEP_SPLIT:
29+
self._init_step_split(config, phase)
10930
else:
110-
111-
self._recording_ids_file = f"{self._split}_combined_splits.json"
112-
113-
print(f"Loading recording ids from {self._recording_ids_file}")
114-
115-
with open(f'../er_annotations/{self._recording_ids_file}', 'r') as file:
116-
self._recording_ids_json = json.load(file)
117-
118-
self._recording_ids = self._recording_ids_json[self._phase]
119-
120-
self._step_dict = {}
121-
index_id = 0
122-
for recording in self._recording_ids:
123-
# 1. Prepare step_id, list(<start, end>) for the recording_id
124-
recording_step_dictionary = {}
125-
for step in self._annotations[recording]['steps']:
126-
if step['start_time'] < 0 or step['end_time'] < 0:
127-
# Ignore missing steps
128-
continue
129-
if recording_step_dictionary.get(step['step_id']) is None:
130-
recording_step_dictionary[step['step_id']] = []
131-
132-
recording_step_dictionary[step['step_id']].append(
133-
(math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))
134-
135-
# 2. Add step start and end time list to the step_dict
136-
for step_id in recording_step_dictionary.keys():
137-
self._step_dict[index_id] = (recording, recording_step_dictionary[step_id])
138-
index_id += 1
31+
self._init_other_split_from_file(config, phase)
32+
33+
def _init_step_split(self, config, phase):
34+
self._recording_ids_file = "recordings_combined_splits.json"
35+
print(f"Loading recording ids from {self._recording_ids_file}")
36+
# annotations_file_path = os.path.join(os.path.dirname(__file__), f'../er_annotations/{
37+
# self._recording_ids_file}')
38+
annotations_file_path = f"/home/rxp190007/CODE/error_recognition/er_annotations/{self._recording_ids_file}"
39+
with open(f'{annotations_file_path}', 'r') as file:
40+
self._recording_ids_json = json.load(file)
41+
42+
self._recording_ids = self._recording_ids_json['train'] + self._recording_ids_json['val'] + \
43+
self._recording_ids_json['test']
44+
45+
self._step_dict = {}
46+
step_index_id = 0
47+
for recording_id in self._recording_ids:
48+
self._normal_step_dict = {}
49+
self._error_step_dict = {}
50+
normal_index_id = 0
51+
error_index_id = 0
52+
# 1. Prepare step_id, list(<start, end>) for the recording_id
53+
recording_step_dictionary = {}
54+
for step in self._annotations[recording_id]['steps']:
55+
if step['start_time'] < 0 or step['end_time'] < 0:
56+
# Ignore missing steps
57+
continue
58+
if recording_step_dictionary.get(step['step_id']) is None:
59+
recording_step_dictionary[step['step_id']] = []
60+
61+
recording_step_dictionary[step['step_id']].append(
62+
(math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))
63+
64+
# 2. Add step start and end time list to the step_dict
65+
for step_id in recording_step_dictionary.keys():
66+
# If the step has errors, add it to the error_step_dict, else add it to the normal_step_dict
67+
if recording_step_dictionary[step_id][0][2]:
68+
self._error_step_dict[f'E{error_index_id}'] = (recording_id, recording_step_dictionary[step_id])
69+
error_index_id += 1
70+
else:
71+
self._normal_step_dict[f'N{normal_index_id}'] = (
72+
recording_id, recording_step_dictionary[step_id])
73+
normal_index_id += 1
74+
75+
np.random.seed(config.seed)
76+
np.random.shuffle(list(self._normal_step_dict.keys()))
77+
np.random.shuffle(list(self._error_step_dict.keys()))
78+
79+
normal_step_indices = list(self._normal_step_dict.keys())
80+
error_step_indices = list(self._error_step_dict.keys())
81+
82+
self._split_proportion = [0.75, 0.16, 0.9]
83+
84+
num_normal_steps = len(normal_step_indices)
85+
num_error_steps = len(error_step_indices)
86+
87+
self._split_proportion_normal = [int(num_normal_steps * self._split_proportion[0]),
88+
int(num_normal_steps * (
89+
self._split_proportion[0] + self._split_proportion[1]))]
90+
self._split_proportion_error = [int(num_error_steps * self._split_proportion[0]),
91+
int(num_error_steps * (
92+
self._split_proportion[0] + self._split_proportion[1]))]
93+
94+
if phase == 'train':
95+
self._train_normal = normal_step_indices[:self._split_proportion_normal[0]]
96+
self._train_error = error_step_indices[:self._split_proportion_error[0]]
97+
train_indices = self._train_normal + self._train_error
98+
for index_id in train_indices:
99+
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
100+
self._error_step_dict.get(index_id))
101+
step_index_id += 1
102+
elif phase == 'test':
103+
self._val_normal = normal_step_indices[
104+
self._split_proportion_normal[0]:self._split_proportion_normal[1]]
105+
self._val_error = error_step_indices[
106+
self._split_proportion_error[0]:self._split_proportion_error[1]]
107+
val_indices = self._val_normal + self._val_error
108+
for index_id in val_indices:
109+
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
110+
self._error_step_dict.get(index_id))
111+
step_index_id += 1
112+
elif phase == 'val':
113+
self._test_normal = normal_step_indices[self._split_proportion_normal[1]:]
114+
self._test_error = error_step_indices[self._split_proportion_error[1]:]
115+
test_indices = self._test_normal + self._test_error
116+
for index_id in test_indices:
117+
self._step_dict[step_index_id] = self._normal_step_dict.get(index_id,
118+
self._error_step_dict.get(index_id))
119+
step_index_id += 1
120+
121+
def _init_other_split_from_file(self, config, phase):
122+
self._recording_ids_file = f"{self._split}_combined_splits.json"
123+
# annotations_file_path = os.path.join(os.path.dirname(__file__), f'../er_annotations/{self._recording_ids_file}')
124+
annotations_file_path = f"/home/rxp190007/CODE/error_recognition/er_annotations/{self._recording_ids_file}"
125+
print(f"Loading recording ids from {self._recording_ids_file}")
126+
with open(f'{annotations_file_path}', 'r') as file:
127+
self._recording_ids_json = json.load(file)
128+
129+
self._recording_ids = self._recording_ids_json[phase]
130+
self._step_dict = {}
131+
index_id = 0
132+
for recording in self._recording_ids:
133+
# 1. Prepare step_id, list(<start, end>) for the recording_id
134+
recording_step_dictionary = {}
135+
for step in self._annotations[recording]['steps']:
136+
if step['start_time'] < 0 or step['end_time'] < 0:
137+
# Ignore missing steps
138+
continue
139+
if recording_step_dictionary.get(step['step_id']) is None:
140+
recording_step_dictionary[step['step_id']] = []
141+
142+
recording_step_dictionary[step['step_id']].append(
143+
(math.floor(step['start_time']), math.ceil(step['end_time']), step['has_errors']))
144+
145+
# 2. Add step start and end time list to the step_dict
146+
for step_id in recording_step_dictionary.keys():
147+
self._step_dict[index_id] = (recording, recording_step_dictionary[step_id])
148+
index_id += 1
139149

140150
def __len__(self):
141151
assert len(self._step_dict) > 0, "No data found in the dataset"

0 commit comments

Comments
 (0)