Skip to content

Commit

Permalink
Adds multilabel efficientnet
Browse files Browse the repository at this point in the history
  • Loading branch information
LashaO committed Jun 13, 2024
1 parent 7cce85c commit 6a1ad02
Showing 1 changed file with 77 additions and 3 deletions.
80 changes: 77 additions & 3 deletions wbia/algo/detect/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
'deer_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_deer_effnet.v0.zip',
'leopard_shark_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_leopard_shark_effnet.v0.zip',
'trout_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_trout_effnet.v0.zip',
'shark_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_shark_effnet.v0.zip'
'shark_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_shark_effnet.v0.zip',
'msv2_multilabel_effnet_v0': 'https://cthulhu.dyn.wildme.io/public/models/labeler_msv2_multilabel_effnet.v2.zip',
}


Expand Down Expand Up @@ -126,13 +127,49 @@ def _init_transforms(**kwargs):


class EfficientnetModel(nn.Module):
def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False):
def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False, multilabel=False):
super().__init__()
self.model = timm.create_model(model_arch, pretrained=pretrained)
self.multilabel = multilabel

self.labels = np.array(['back', 'down', 'front', 'left', 'right', 'up'])
self.sort_weights = {
'up': 1,
'down': 1,
'front': 2,
'back': 2,
'left': 3,
'right': 3,
}

self.reverse_label_map = {
'up':1,
'down':2,
'front':3,
'back':4,
'left':5,
'right':6,
'upfront':7,
'upback':8,
'upleft':9,
'upright':10,
'downfront':11,
'downback':12,
'downleft':13,
'downright':14,
'frontleft':15,
'frontright':16,
'backleft':17,
'backright':18,
}

if multilabel:
n_class = len(self.sort_weights)

if n_class is not None:
n_features = self.model.classifier.in_features
self.model.classifier = nn.Linear(n_features, n_class)

else:
self.model.classifier = nn.Identity(n_features, n_class)
'''
Expand All @@ -142,8 +179,44 @@ def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False
nn.Linear(n_features, n_class, bias=True)
)
'''

def process_row(self, row_labels, preds, sort_weights, labels):
multi_labels = labels[row_labels == True]
preds = preds[row_labels == True]
# Combine the multi_labels and preds into a list of tuples
label_pred_weight = [(label, pred, sort_weights[label]) for label, pred in zip(multi_labels, preds)]
# Sort by weights first, then by prediction values in descending order
label_pred_weight.sort(key=lambda x: (x[2], -x[1]))
# Create a dictionary to keep the highest value string for each weight
best_labels = {}
for label, pred, weight in label_pred_weight:
if weight not in best_labels or pred > best_labels[weight][1]:
best_labels[weight] = (label, pred)
# Extract the labels in the order of weights - top 2 labels
sorted_labels = [best_labels[weight][0] for weight in sorted(best_labels)[:2]]
return sorted_labels

def process_multilabel_preds(self, image_preds, sort_weights, labels, reverse_label_map):
multi_label_matrix = (image_preds > 0.5).cpu().numpy()
sorted_labels = [self.process_row(row, preds, sort_weights, labels) for row, preds in zip(multi_label_matrix, image_preds)]
fused_labels = [''.join(x) for x in sorted_labels]

num_labels = len(reverse_label_map)
one_hot_matrix = np.zeros((len(fused_labels), num_labels))

for i, label in enumerate(fused_labels):
if label in reverse_label_map:
one_hot_matrix[i, reverse_label_map[label] - 1] = 1

one_hot_tensor = torch.tensor(one_hot_matrix, dtype=torch.float32)

return one_hot_tensor

def forward(self, x):
x = self.model(x)
if self.multilabel:
x = self.process_multilabel_preds(x, self.sort_weights, self.labels, self.reverse_label_map)

return x


Expand Down Expand Up @@ -553,7 +626,8 @@ def test_single(filepath_list, weights_path, batch_size=1792, multi=PARALLEL, **
num_classes = len(classes)

# Initialize the model for this run
model = EfficientnetModel(n_class=num_classes)
multilabel = 'multilabel' in weights_path
model = EfficientnetModel(n_class=num_classes, multilabel=multilabel)
# num_ftrs = model.classifier.in_features
# model.classifier = nn.Linear(num_ftrs, num_classes)

Expand Down

0 comments on commit 6a1ad02

Please sign in to comment.