Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds multilabel efficientnet #278

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 77 additions & 3 deletions wbia/algo/detect/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
'deer_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_deer_effnet.v0.zip',
'leopard_shark_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_leopard_shark_effnet.v0.zip',
'trout_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_trout_effnet.v0.zip',
'shark_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_shark_effnet.v0.zip'
'shark_effnet_v0': 'https://wildbookiarepository.azureedge.net/models/labeler_shark_effnet.v0.zip',
'msv2_multilabel_effnet_v0': 'https://cthulhu.dyn.wildme.io/public/models/labeler_msv2_multilabel_effnet.v2.zip',
}


Expand Down Expand Up @@ -126,13 +127,49 @@ def _init_transforms(**kwargs):


class EfficientnetModel(nn.Module):
def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False):
def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False, multilabel=False):
super().__init__()
self.model = timm.create_model(model_arch, pretrained=pretrained)
self.multilabel = multilabel

self.labels = np.array(['back', 'down', 'front', 'left', 'right', 'up'])
self.sort_weights = {
'up': 1,
'down': 1,
'front': 2,
'back': 2,
'left': 3,
'right': 3,
}

self.reverse_label_map = {
'up': 1,
'down': 2,
'front': 3,
'back': 4,
'left': 5,
'right': 6,
'upfront': 7,
'upback': 8,
'upleft': 9,
'upright': 10,
'downfront': 11,
'downback': 12,
'downleft': 13,
'downright': 14,
'frontleft': 15,
'frontright': 16,
'backleft': 17,
'backright': 18,
}

if multilabel:
n_class = len(self.sort_weights)

if n_class is not None:
n_features = self.model.classifier.in_features
self.model.classifier = nn.Linear(n_features, n_class)

else:
self.model.classifier = nn.Identity(n_features, n_class)
'''
Expand All @@ -142,8 +179,44 @@ def __init__(self, n_class, model_arch='tf_efficientnet_b4_ns', pretrained=False
nn.Linear(n_features, n_class, bias=True)
)
'''

def process_row(self, row_labels, preds, sort_weights, labels):
multi_labels = labels[row_labels.astype(bool)]
preds = preds[row_labels.astype(bool)]
# Combine the multi_labels and preds into a list of tuples
label_pred_weight = [(label, pred, sort_weights[label]) for label, pred in zip(multi_labels, preds)]
# Sort by weights first, then by prediction values in descending order
label_pred_weight.sort(key=lambda x: (x[2], -x[1]))
# Create a dictionary to keep the highest value string for each weight
best_labels = {}
for label, pred, weight in label_pred_weight:
if weight not in best_labels or pred > best_labels[weight][1]:
best_labels[weight] = (label, pred)
# Extract the labels in the order of weights - top 2 labels
sorted_labels = [best_labels[weight][0] for weight in sorted(best_labels)[:2]]
return sorted_labels

def process_multilabel_preds(self, image_preds, sort_weights, labels, reverse_label_map):
multi_label_matrix = (image_preds > 0.5).cpu().numpy()
sorted_labels = [self.process_row(row, preds, sort_weights, labels) for row, preds in zip(multi_label_matrix, image_preds)]
fused_labels = [''.join(x) for x in sorted_labels]

num_labels = len(reverse_label_map)
one_hot_matrix = np.zeros((len(fused_labels), num_labels))

for i, label in enumerate(fused_labels):
if label in reverse_label_map:
one_hot_matrix[i, reverse_label_map[label] - 1] = 1

one_hot_tensor = torch.tensor(one_hot_matrix, dtype=torch.float32)

return one_hot_tensor

def forward(self, x):
x = self.model(x)
if self.multilabel:
x = self.process_multilabel_preds(x, self.sort_weights, self.labels, self.reverse_label_map)

return x


Expand Down Expand Up @@ -553,7 +626,8 @@ def test_single(filepath_list, weights_path, batch_size=1792, multi=PARALLEL, **
num_classes = len(classes)

# Initialize the model for this run
model = EfficientnetModel(n_class=num_classes)
multilabel = 'multilabel' in weights_path
model = EfficientnetModel(n_class=num_classes, multilabel=multilabel)
# num_ftrs = model.classifier.in_features
# model.classifier = nn.Linear(num_ftrs, num_classes)

Expand Down
Loading