-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
33 lines (21 loc) · 850 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# The wildcard import stays first so the explicit imports below take
# precedence over any same-named exports from AgeNet.models.
from AgeNet.models import *
import torch
from torch import nn
class Model(nn.Module):
    """Combine the gender, age-range and age-estimation networks in one module.

    The age-range network's highest-scoring index per sample is handed to the
    age-estimation network as its second argument.
    """

    def __init__(self):
        super().__init__()
        self.gender_model = GenderClassificationModel()
        self.age_range_model = AgeRangeModel()
        self.age_estimation_model = AgeEstimationModel()

    def forward(self, x):
        """Run all three sub-models on ``x``.

        Args:
            x: Image tensor, documented by the original author as
                ``(batch, 3, 64, 64)``. A 3-D tensor is treated as a single
                unbatched image and gets a leading axis of size 1.

        Returns:
            Tuple ``(predicted_genders, estimated_ages)``: the outputs of the
            gender sub-model and the age-estimation sub-model. The leading
            axis added for a 3-D input is not removed from the outputs here.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)

        predicted_genders = self.gender_model(x)
        age_ranges = self.age_range_model(x)

        # Index of the largest entry along dim 1, flattened to 1-D. Tensor
        # methods are used so this class does not depend on a module-level
        # `torch` name being in scope.
        age_range_ids = age_ranges.argmax(dim=1).view(-1)

        estimated_ages = self.age_estimation_model(x, age_range_ids)
        return predicted_genders, estimated_ages
# Smoke test: push one random 3x64x64 image through the combined model and
# print the shapes of both outputs.
model = Model()

# Switch to evaluation mode before the forward pass. The sub-models live in
# AgeNet.models and are not visible here; if they contain BatchNorm or Dropout
# layers, training mode would make a single-sample pass noisy or, for some
# BatchNorm layers, raise an error. TODO confirm against AgeNet.models.
model.eval()

x = torch.rand(3, 64, 64)

# Only the output shapes are inspected, so gradient tracking is not needed.
with torch.no_grad():
    genders, ages = model(x)

print(genders.shape, ages.shape)