-
Notifications
You must be signed in to change notification settings - Fork 1
/
radar_datasets.py
152 lines (122 loc) · 5.79 KB
/
radar_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import h5py
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class DopNetH5Dataset(Dataset):
def __init__(self, h5_file, person_list, gestures, transform=None, target_width=144):
self.h5_file = h5_file
self.person_list = person_list
self.gestures = gestures
self.transform = transform
self.target_width = target_width # Fixed width to pad to
self.data_index = self._build_index()
def _build_index(self):
"""Builds an index for easy access to each sample."""
index = []
with h5py.File(self.h5_file, 'r') as h5f:
for person in self.person_list:
for gesture in self.gestures:
# Dynamically find all samples under each gesture
gesture_group = h5f[f"{person}/{gesture}"]
samples = list(gesture_group.keys())
for sample in samples:
index.append((person, gesture, sample))
return index
def __len__(self):
return len(self.data_index)
def __getitem__(self, idx):
# Open the HDF5 file in read mode
with h5py.File(self.h5_file, 'r') as h5f:
# Get the sample details
person, gesture, sample = self.data_index[idx]
data = h5f[f"{person}/{gesture}/{sample}"][:]
# Pad or trim data to the target width
padded_data = np.zeros((data.shape[0], self.target_width), dtype=np.float32)
padded_data[:, :data.shape[1]] = data[:, :self.target_width]
# Convert to tensor and apply transform if specified
data = torch.tensor(data, dtype=torch.float32)
label = gesture # Use gesture as label
if self.transform:
data = self.transform(data)
return data, label
def collate_fn(batch):
# Find the maximum width in the batch
# max_width = max(data.shape[1] for data, _ in batch)
# max_width = 540
max_width = 150
# Pad each sample in the batch to the max width
padded_batch = []
lengths = []
for data, label in batch:
lengths.append(data.shape[1])
padded_data = np.zeros((data.shape[0], max_width), dtype=np.float32)
padded_data[:, :data.shape[1]] = data
padded_batch.append((torch.tensor(padded_data, dtype=torch.float32), label))
# Separate data and labels for the batch
data, labels = zip(*padded_batch)
return torch.stack(data), torch.tensor(labels), lengths
class SoliH5Dataset(Dataset):
def __init__(self, h5_file, mode, transform=None):
self.h5_file = h5_file
self.transform = transform
assert mode in ['mdoppler', 'unrolled']
self.mode = mode
self.data_index = self._build_index()
def _build_index(self):
"""Builds an index for easy access to each sample."""
index = []
with h5py.File(self.h5_file, 'r') as h5f:
for key in h5f.keys():
index.append(key)
return index
def __len__(self):
return len(self.data_index)
def __getitem__(self, idx):
# Open the HDF5 file in read mode
with h5py.File(self.h5_file, 'r') as h5f:
data = h5f[self.data_index[idx]][self.mode][:]
label = h5f[self.data_index[idx]]['label'][0]
# Convert to tensor and apply transform if specified
data = torch.tensor(data, dtype=torch.float32)
if self.transform:
data = self.transform(data)
return data.T, label
def main():
dataset = 'soli'
if dataset == 'DopNet':
# Load the dataset and DataLoader
h5_file = os.path.join('data', 'gesture', 'train', 'preprocessed_gesture_data_train.h5')
person_list = ['A', 'B', 'C', 'D', 'E', 'F']
gestures = [0, 1, 2, 3] # Wave, Pinch, Swipe, Click
dopnet_h5_dataset = DopNetH5Dataset(h5_file=h5_file, person_list=person_list, gestures=gestures)
dopnet_h5_dataloader = DataLoader(dopnet_h5_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
for batch_x, batch_y in dopnet_h5_dataloader:
print("Batch spectrograms shape:", batch_x.shape)
print("Batch labels:", batch_y)
break
# Load the test dataset and DataLoader
test_h5_file = os.path.join('data', 'gesture', 'test', 'preprocessed_gesture_data_test.h5')
dopnet_test_dataset = DopNetH5Dataset(h5_file=test_h5_file, person_list=person_list, gestures=gestures)
dopnet_test_dataloader = DataLoader(dopnet_test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
for test_batch_x, _ in dopnet_test_dataloader:
print("Test batch spectrograms shape:", test_batch_x.shape)
break
elif dataset == 'soli':
# Load the dataset and DataLoader
h5_file = os.path.join('data', 'Soli_Preprocessed', 'soli_Preprocessed.h5')
soli_h5_dataset = SoliH5Dataset(h5_file=h5_file, mode='unrolled')
# split the dataset into training and validation
train_dataset, val_dataset = torch.utils.data.random_split(soli_h5_dataset, [int(len(soli_h5_dataset)*0.8), len(soli_h5_dataset) - int(len(soli_h5_dataset)*0.8)])
soli_h5_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
for batch_x, batch_y, length in soli_h5_dataloader:
print("Batch mdoppler shape:", batch_x.shape)
print("Batch labels:", batch_y)
break
for test_batch_x, test_batch_y, length in val_dataloader:
print("Validation batch mdoppler shape:", test_batch_x.shape)
print("Validation batch labels:", test_batch_y)
break
if __name__ == "__main__":
main()