Skip to content

Commit

Permalink
add uniformSampling option
Browse files — browse the repository at this point in the history
  • Branch information:
rm-wu committed Jan 18, 2021
1 parent 55721c8 commit e644827
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
11 changes: 7 additions & 4 deletions main-run-twoStream.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
DEVICE = 'cuda'

def main_run(dataset, flowModel, rgbModel, stackSize, seqLen, memSize, trainDatasetDir, valDatasetDir, outDir,
trainBatchSize, valBatchSize, lr1, numEpochs, decay_step, decay_factor):
trainBatchSize, valBatchSize, lr1, numEpochs, decay_step, decay_factor, uniformSampling):
# GTEA 61
num_classes = 61

Expand Down Expand Up @@ -46,15 +46,16 @@ def main_run(dataset, flowModel, rgbModel, stackSize, seqLen, memSize, trainData
normalize])

vid_seq_train = makeDataset(directory, train_splits, spatial_transform=spatial_transform,
sequence=False, numSeg=1, stackSize=stackSize, fmt='.png', seqLen=seqLen)
sequence=False, numSeg=1, stackSize=stackSize, fmt='.png', seqLen=seqLen,
uniform_sampling=uniformSampling)

train_loader = torch.utils.data.DataLoader(vid_seq_train, batch_size=trainBatchSize,
shuffle=True, num_workers=4, pin_memory=True)

vid_seq_val = makeDataset(directory, val_splits,
spatial_transform=Compose([Scale(256), CenterCrop(224), ToTensor(), normalize]),
sequence=False, numSeg=1, stackSize=stackSize, fmt='.png', phase='Test',
seqLen=seqLen)
seqLen=seqLen, uniform_sampling=uniformSampling)

val_loader = torch.utils.data.DataLoader(vid_seq_val, batch_size=valBatchSize,
shuffle=False, num_workers=2, pin_memory=True)
Expand Down Expand Up @@ -216,6 +217,7 @@ def __main__():
parser.add_argument('--stepSize', type=float, default=1, help='Learning rate decay step')
parser.add_argument('--decayRate', type=float, default=0.99, help='Learning rate decay rate')
parser.add_argument('--memSize', type=int, default=512, help='ConvLSTM hidden state size')
parser.add_argument('--uniformSampling', action="store_true")

args = parser.parse_args()

Expand All @@ -234,9 +236,10 @@ def __main__():
decay_step = args.stepSize
decay_factor = args.decayRate
memSize = args.memSize
uniformSampling = args.uniformSampling

main_run(dataset, flowModel, rgbModel, stackSize, seqLen, memSize, trainDatasetDir, valDatasetDir, outDir,
trainBatchSize, valBatchSize, lr1, numEpochs, decay_step, decay_factor)
trainBatchSize, valBatchSize, lr1, numEpochs, decay_step, decay_factor, uniformSampling)


__main__()
37 changes: 24 additions & 13 deletions makeDatasetTwoStream.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def __init__(self, root_dir, splits,
numSeg=5,
fmt='.png',
phase='train',
seqLen = 25):
seqLen = 25,
uniform_sampling=True):
"""
Args:
root_dir (string): Directory with all the images.
Expand All @@ -67,6 +68,7 @@ def __init__(self, root_dir, splits,
self.fmt = fmt
self.phase = phase
self.seqLen = seqLen
self.uniform_sampling = uniform_sampling

def __len__(self):
return len(self.imagesX)
Expand Down Expand Up @@ -110,18 +112,27 @@ def __getitem__(self, idx):
else:
startFrame = np.ceil((numFrame - self.stackSize)/2)
inpSeq = []
#for k in range(self.stackSize):
# i = k + int(startFrame)
# TODO: Make it optional
for k in sorted(np.random.choice(np.arange(startFrame, numFrame+1), size=self.stackSize, replace=False)):
i = k
fl_name = vid_nameX + '/flow_x_' + str(int(round(i))).zfill(5) + self.fmt
img = Image.open(fl_name)
inpSeq.append(self.spatial_transform(img.convert('L'), inv=True, flow=True))
# fl_names.append(fl_name)
fl_name = vid_nameY + '/flow_y_' + str(int(round(i))).zfill(5) + self.fmt
img = Image.open(fl_name)
inpSeq.append(self.spatial_transform(img.convert('L'), inv=False, flow=True))
if self.uniform_sampling:
for k in sorted(np.random.choice(np.arange(startFrame, numFrame + 1), size=self.stackSize, replace=False)):
i = k
fl_name = vid_nameX + '/flow_x_' + str(int(round(i))).zfill(5) + self.fmt
img = Image.open(fl_name)
inpSeq.append(self.spatial_transform(img.convert('L'), inv=True, flow=True))
# fl_names.append(fl_name)
fl_name = vid_nameY + '/flow_y_' + str(int(round(i))).zfill(5) + self.fmt
img = Image.open(fl_name)
inpSeq.append(self.spatial_transform(img.convert('L'), inv=False, flow=True))
else:
for k in range(self.stackSize):
i = k + int(startFrame)
fl_name = vid_nameX + '/flow_x_' + str(int(round(i))).zfill(5) + self.fmt
img = Image.open(fl_name)
inpSeq.append(self.spatial_transform(img.convert('L'), inv=True, flow=True))
# fl_names.append(fl_name)
fl_name = vid_nameY + '/flow_y_' + str(int(round(i))).zfill(5) + self.fmt
img = Image.open(fl_name)
inpSeq.append(self.spatial_transform(img.convert('L'), inv=False, flow=True))

inpSeqSegs = torch.stack(inpSeq, 0).squeeze(1)

# Collect the rgb frames
Expand Down

0 comments on commit e644827

Please sign in to comment.