Skip to content

Commit 326f851

Browse files
committed
Add dropout
1 parent e27c4e7 commit 326f851

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

i3d_pt_demo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def get_scores(sample, model):
2424
print(
2525
'Top {} classes and associated probabilities: '.format(args.top_k))
2626
for i in range(args.top_k):
27-
print('[{}]: {}'.format(kinetics_classes[top_idx[0, i]], top_val[
28-
0, i]))
27+
print('[{}]: {:.6E}'.format(kinetics_classes[top_idx[0, i]],
28+
top_val[0, i]))
2929
return out_logit
3030

3131
# Rung RGB model
@@ -55,11 +55,11 @@ def get_scores(sample, model):
5555
top_val, top_idx = torch.sort(out_softmax, 1, descending=True)
5656

5757
print('===== Final predictions ====')
58-
print('proba logits class '.format(args.top_k))
58+
print('logits proba class '.format(args.top_k))
5959
for i in range(args.top_k):
6060
logit_score = out_logit[0, top_idx[0, i]].data[0]
61-
print('{} {} {}'.format(logit_score, top_val[0, i],
62-
kinetics_classes[top_idx[0, i]]))
61+
print('{:.6e} {:.6e} {}'.format(logit_score, top_val[0, i],
62+
kinetics_classes[top_idx[0, i]]))
6363

6464

6565
if __name__ == "__main__":

src/i3dpt.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ def forward(self, inp):
159159

160160

161161
class I3D(torch.nn.Module):
162-
def __init__(self, num_classes, modality='rgb', name='inception'):
162+
def __init__(self,
163+
num_classes,
164+
modality='rgb',
165+
dropout_keep_prob=1,
166+
name='inception'):
163167
super(I3D, self).__init__()
164168

165169
self.name = name
@@ -223,6 +227,7 @@ def __init__(self, num_classes, modality='rgb', name='inception'):
223227
self.mixed_5c = Mixed(832, [384, 192, 384, 48, 128, 128])
224228

225229
self.avg_pool = torch.nn.AvgPool3d((2, 7, 7), (1, 1, 1))
230+
self.dropout = torch.nn.Dropout(dropout_keep_prob)
226231
self.conv3d_0c_1x1 = Unit3Dpy(
227232
in_channels=1024,
228233
out_channels=self.num_classes,
@@ -252,6 +257,7 @@ def forward(self, inp):
252257
out = self.mixed_5b(out)
253258
out = self.mixed_5c(out)
254259
out = self.avg_pool(out)
260+
out = self.dropout(out)
255261
out = self.conv3d_0c_1x1(out)
256262
out = out.squeeze(3)
257263
out = out.squeeze(3)

0 commit comments

Comments
 (0)