-
Notifications
You must be signed in to change notification settings - Fork 56
/
model.py
194 lines (162 loc) · 6.34 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from torch import nn
import torch
import math
class SGN(nn.Module):
    """SGN classifier over sequences of 3-D joint coordinates.

    Pipeline (see ``forward``):
      1. Embed joint positions and their frame-to-frame differences.
      2. Concatenate a learned embedding of each joint's one-hot index.
      3. Apply three ``gcn_spa`` layers driven by a data-dependent adjacency
         from ``compute_g_spa``.
      4. Add a learned embedding of each frame's one-hot index.
      5. Run the frame-level ``local`` block, global-max-pool, and classify
         with a linear layer.

    Args:
        num_classes: size of the output logits.
        dataset: stored as ``self.dataset``; not otherwise used in this class.
        seg: number of frames per sample. ``forward`` requires exactly this
            many time steps, because the frame one-hot codes are built for it.
        args: configuration object, accepted for backward compatibility. It is
            no longer read. The one-hot codes are built once for a single
            sample and broadcast to whatever batch size ``forward`` receives,
            so ``args.batch_size`` / ``args.train`` are not needed.
        bias: passed to the embedding, ``local`` and graph sub-layers.
    """

    def __init__(self, num_classes, dataset, seg, args, bias=True):
        super(SGN, self).__init__()
        self.dim1 = 256
        self.dataset = dataset
        self.seg = seg
        num_joint = 25

        # One-hot joint ("spa") and frame ("tem") index codes. They are
        # identical for every sample, so only one copy (batch size 1) is kept
        # and forward() broadcasts the embedded result over the real batch.
        # Previously they were materialised for a fixed batch size, which
        # crashed whenever the incoming batch size differed.
        # spa: (1, num_joint, num_joint, seg); tem: (1, seg, num_joint, seg).
        # They are plain attributes (not buffers) so state_dict keys are
        # unchanged.
        self.spa = self.one_hot(1, num_joint, self.seg).permute(0, 3, 2, 1).contiguous()
        self.tem = self.one_hot(1, self.seg, num_joint).permute(0, 3, 1, 2).contiguous()
        # The original unconditional .cuda() failed on CPU-only machines.
        # Keep GPU placement when available; forward() also moves the codes
        # to the input's device.
        if torch.cuda.is_available():
            self.spa = self.spa.cuda()
            self.tem = self.tem.cuda()

        # Keep submodule creation order: it fixes RNG consumption during init
        # and the state_dict layout.
        self.tem_embed = embed(self.seg, 64 * 4, norm=False, bias=bias)
        self.spa_embed = embed(num_joint, 64, norm=False, bias=bias)
        self.joint_embed = embed(3, 64, norm=True, bias=bias)
        self.dif_embed = embed(3, 64, norm=True, bias=bias)
        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
        self.cnn = local(self.dim1, self.dim1 * 2, bias=bias)
        self.compute_g1 = compute_g_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn1 = gcn_spa(self.dim1 // 2, self.dim1 // 2, bias=bias)
        self.gcn2 = gcn_spa(self.dim1 // 2, self.dim1, bias=bias)
        self.gcn3 = gcn_spa(self.dim1, self.dim1, bias=bias)
        self.fc = nn.Linear(self.dim1 * 2, num_classes)

        # Normal init with std sqrt(2 / (k_h * k_w * out_channels)) for every conv.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
        # Zero the graph-aggregation convs so each gcn_spa initially relies
        # only on its w1 path.
        nn.init.constant_(self.gcn1.w.cnn.weight, 0)
        nn.init.constant_(self.gcn2.w.cnn.weight, 0)
        nn.init.constant_(self.gcn3.w.cnn.weight, 0)

    def forward(self, input):
        """Compute class logits for a batch of joint-coordinate sequences.

        Args:
            input: tensor of shape (bs, step, num_joints * 3). It is viewed as
                (bs, step, num_joints, 3). ``step`` must equal ``self.seg``,
                and ``num_joints`` must be 25 to match the one-hot codes.
                Any batch size ``bs`` is accepted.

        Returns:
            Logits of shape (bs, num_classes).
        """
        bs, step, dim = input.size()
        num_joints = dim // 3
        # (bs, step, J, 3) -> (bs, 3, J, step): channels-first for the 1x1 convs.
        x = input.view((bs, step, num_joints, 3))
        x = x.permute(0, 3, 2, 1).contiguous()

        # Dynamic representation: frame-to-frame difference, zero for frame 0.
        dif = x[:, :, :, 1:] - x[:, :, :, 0:-1]
        dif = torch.cat([dif.new_zeros(bs, dif.size(1), num_joints, 1), dif], dim=-1)
        pos = self.joint_embed(x)
        dif = self.dif_embed(dif)
        dy = pos + dif

        # Index embeddings are sample-independent: embed the single stored
        # copy, then broadcast it to the actual batch size.
        tem1 = self.tem_embed(self.tem.to(x.device)).expand(bs, -1, -1, -1)
        spa1 = self.spa_embed(self.spa.to(x.device)).expand(bs, -1, -1, -1)

        # Joint-level module.
        x = torch.cat([dy, spa1], 1)
        g = self.compute_g1(x)
        x = self.gcn1(x, g)
        x = self.gcn2(x, g)
        x = self.gcn3(x, g)

        # Frame-level module.
        x = x + tem1
        x = self.cnn(x)

        # Classification.
        output = self.maxpool(x)
        output = torch.flatten(output, 1)
        return self.fc(output)

    def one_hot(self, bs, spa, tem):
        """Return ``bs * tem`` stacked copies of the ``spa x spa`` identity matrix.

        Row ``i`` of each ``spa x spa`` slice is the one-hot code of index ``i``.

        Args:
            bs: size of the first (repeat) axis.
            spa: number of indices to encode.
            tem: size of the second (repeat) axis.

        Returns:
            float32 CPU tensor of shape (bs, tem, spa, spa).
        """
        y_onehot = torch.eye(spa, dtype=torch.float32, device='cpu')
        return y_onehot.unsqueeze(0).unsqueeze(0).repeat(bs, tem, 1, 1)
class norm_data(nn.Module):
    """Batch-normalise a (bs, c, num_joints, step) tensor per (channel, joint) pair.

    Channels and joints are flattened into one feature axis of size
    ``dim * num_joints``. Each (channel, joint) combination therefore gets its
    own ``BatchNorm1d`` statistics, computed over the batch and step axes.

    Args:
        dim: number of input channels ``c``.
        num_joints: size of the input's third axis. Defaults to 25, the value
            that was previously hard-coded.
    """

    def __init__(self, dim=64, num_joints=25):
        super(norm_data, self).__init__()
        self.bn = nn.BatchNorm1d(dim * num_joints)

    def forward(self, x):
        bs, c, num_joints, step = x.size()
        # Flatten (c, num_joints) into the BatchNorm feature axis. reshape (not
        # view) so non-contiguous inputs are also accepted.
        x = x.reshape(bs, -1, step)
        x = self.bn(x)
        return x.reshape(bs, -1, num_joints, step)
class embed(nn.Module):
    """Two-layer pointwise embedding: dim -> 64 -> dim1 channels, ReLU after each.

    Args:
        dim: input channel count.
        dim1: output channel count.
        norm: if True, prepend ``norm_data(dim)`` batch normalisation.
        bias: whether the 1x1 convolutions have a bias.
    """

    def __init__(self, dim=3, dim1=128, norm=True, bias=False):
        super(embed, self).__init__()
        # The optional normaliser comes first, so Sequential indices (and thus
        # state_dict keys) match the norm / no-norm layouts.
        layers = [norm_data(dim)] if norm else []
        layers += [
            cnn1x1(dim, 64, bias=bias),
            nn.ReLU(),
            cnn1x1(64, dim1, bias=bias),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)
class cnn1x1(nn.Module):
    """Pointwise (1x1) 2-D convolution mapping ``dim1`` channels to ``dim2``.

    Thin wrapper around ``nn.Conv2d``; the wrapped layer is stored as ``cnn``.
    """

    def __init__(self, dim1=3, dim2=3, bias=True):
        super(cnn1x1, self).__init__()
        self.cnn = nn.Conv2d(
            in_channels=dim1,
            out_channels=dim2,
            kernel_size=1,
            bias=bias,
        )

    def forward(self, x):
        return self.cnn(x)
class local(nn.Module):
    """Frame-level block: pool to (1, 20), then conv-BN-ReLU-dropout and conv-BN-ReLU.

    The first convolution has a (1, 3) kernel with (0, 1) padding, so it keeps
    the pooled width. The second is 1x1 and maps ``dim1`` to ``dim2`` channels.

    Args:
        dim1: input (and hidden) channel count.
        dim2: output channel count.
        bias: whether both convolutions have a bias.
    """

    def __init__(self, dim1=3, dim2=3, bias=False):
        super(local, self).__init__()
        # Registration order is kept: it fixes init RNG order and state_dict layout.
        self.maxpool = nn.AdaptiveMaxPool2d((1, 20))
        self.cnn1 = nn.Conv2d(dim1, dim1, kernel_size=(1, 3), padding=(0, 1), bias=bias)
        self.bn1 = nn.BatchNorm2d(dim1)
        self.relu = nn.ReLU()
        self.cnn2 = nn.Conv2d(dim1, dim2, kernel_size=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(dim2)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x1):
        # Collapse axis 2 (joints, when called from SGN) to 1 and adaptively
        # max-pool the last axis to 20 bins.
        pooled = self.maxpool(x1)
        hidden = self.dropout(self.relu(self.bn1(self.cnn1(pooled))))
        return self.relu(self.bn2(self.cnn2(hidden)))
class gcn_spa(nn.Module):
    """Graph convolution with a per-frame adjacency: ReLU(BN(w(g @ x) + w1(x))).

    When called from SGN, ``x1`` is (bs, C, joints, frames) and ``g`` is
    (bs, frames, joints, joints), as produced by ``compute_g_spa``.

    Args:
        in_feature: input channel count.
        out_feature: output channel count.
        bias: bias flag for the direct-path conv ``w1``. The aggregated-path
            conv ``w`` never has a bias.
    """

    def __init__(self, in_feature, out_feature, bias=False):
        super(gcn_spa, self).__init__()
        self.bn = nn.BatchNorm2d(out_feature)
        self.relu = nn.ReLU()
        self.w = cnn1x1(in_feature, out_feature, bias=False)
        self.w1 = cnn1x1(in_feature, out_feature, bias=bias)

    def forward(self, x1, g):
        # Reorder to (bs, frames, joints, C) so g mixes joints within each frame.
        per_frame = x1.permute(0, 3, 2, 1).contiguous()
        aggregated = g.matmul(per_frame)
        # Back to channels-first layout for the 1x1 convolutions.
        aggregated = aggregated.permute(0, 3, 2, 1).contiguous()
        summed = self.w(aggregated) + self.w1(x1)
        return self.relu(self.bn(summed))
class compute_g_spa(nn.Module):
    """Predict a softmax-normalised joint-by-joint affinity for every frame.

    Two 1x1 projections of the input are multiplied frame by frame. Softmax is
    taken over the last axis. For an input of shape (bs, C, J, T), as passed by
    SGN, the result is (bs, T, J, J).

    Args:
        dim1: input channel count.
        dim2: channel count of each projection.
        bias: whether the projections have a bias.
    """

    def __init__(self, dim1=64 * 3, dim2=64 * 3, bias=False):
        super(compute_g_spa, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.g1 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.g2 = cnn1x1(self.dim1, self.dim2, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1):
        # proj_a: (bs, T, J, dim2); proj_b: (bs, T, dim2, J).
        proj_a = self.g1(x1).permute(0, 3, 2, 1).contiguous()
        proj_b = self.g2(x1).permute(0, 3, 1, 2).contiguous()
        affinity = proj_a.matmul(proj_b)
        return self.softmax(affinity)