-
Notifications
You must be signed in to change notification settings - Fork 2
/
clf_job.py
177 lines (157 loc) · 7.53 KB
/
clf_job.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from typing import Any, Optional
import pandas as pd
import torch
import numpy as np
import os
from transformers import pipeline, DistilBertTokenizer
import logging
import constants
from multi_model_inference import MultiModalModel, get_inference_data, run_inference
class CLFJob:
    """Classify the rows of parquet files stored in MinIO and save the results.

    ``perform_clf`` reads every object listed under ``date_folder`` in ``bucket``,
    runs the configured ``task`` over a text column and writes the enriched
    DataFrame to ``final_bucket``.

    Task values handled by ``perform_clf``: ``"zero-shot-classification"``,
    ``"text-classification"`` and ``"both"`` (runs the two in sequence).
    ``"multi-model"`` is only handled by ``maybe_load_classifier`` /
    ``get_inference_for_mmm``.
    """

    # Default Hugging Face model ids used when no explicit ``model`` is configured
    # (and always used when task == "both").
    DEFAULT_TEXT_MODEL = 'julesbarbosa/wildlife-classification'
    DEFAULT_ZERO_SHOT_MODEL = 'facebook/bart-large-mnli'
    # Maps the text-classification pipeline's raw label names to integer classes;
    # any other label becomes NaN.
    TEXT_LABELS_MAP = {"LABEL_0": 0, "LABEL_1": 1}

    def __init__(
            self,
            bucket: str,
            final_bucket: str,
            minio_client: Any,
            date_folder: Optional[str],
            column: Optional[str],
            model: Optional[str],
            task: str):
        self.bucket = bucket
        self.final_bucket = final_bucket
        self.minio_client = minio_client
        # Text column to classify. When falsy, a "product" column is derived from
        # title/description/name for every DataFrame; this attribute is never mutated.
        self.column = column
        self.date = date_folder
        # Cached model for the single-task modes (text / zero-shot / multi-model).
        self.classifier = None
        # Cached pipelines for task == "both", keyed by the requested sub-task.
        self._pipelines = {}
        self.task = task
        self.labels = constants.classifier_target_labels
        self.hypothesis_template = constants.classifier_hypothesis_template
        # Optional model id overriding the per-task default (ignored for "both").
        self.model = model

    def perform_clf(self):
        """Run the configured task on every file of the date folder and save the output.

        Files already present in ``final_bucket`` (when it differs from ``bucket``)
        and empty DataFrames are skipped with a warning.

        Raises:
            ValueError: if ``self.task`` is not a task this method can run.
        """
        files = self.minio_client.list_objects_names(self.bucket, self.date)
        for file in files:
            # NOTE(review): split(".")[0] cuts at the first dot anywhere in the object
            # key (including dotted folder names) -- confirm keys never contain one.
            filename = file.split(".")[0]
            if self.final_bucket != self.bucket and self.minio_client.check_obj_exists(self.final_bucket, file):
                logging.warning(f"File {file} ignored as it already present in destination")
                continue
            logging.info(f"Starting inference for {file}")
            df = self.minio_client.read_df_parquet(bucket=self.bucket, file_name=file)
            if df.empty:
                logging.warning(f"File {file} is empty")
                continue
            if self.task == "zero-shot-classification":
                df = self.get_inference(df=df)
            elif self.task == "text-classification":
                df = self.get_inference_for_column(df=df)
            elif self.task == "both":
                df = self.get_inference(df=df)
                df = self.get_inference_for_column(df=df)
            else:
                raise ValueError(f"Task \"{self.task}\" is not available")
            self.minio_client.save_df_parquet(self.final_bucket, filename, df)

    def _prepare_text_column(self, df: pd.DataFrame) -> str:
        """Make sure the text column exists in ``df`` (NaN -> "") and return its name.

        Without a configured column, "product" is built from the first non-null of
        ``title``, ``description`` and ``name``. A local name is used instead of
        assigning ``self.column`` so every DataFrame gets the derivation.
        """
        column = self.column or "product"
        if not self.column:
            df[column] = df['title'].fillna(df['description']).fillna(df['name'])
        df[column] = df[column].fillna("")
        return column

    @staticmethod
    def _ensure_columns(df: pd.DataFrame, columns) -> None:
        """Add each of ``columns`` missing from ``df`` as an all-NaN column."""
        for name in columns:
            if name not in df.columns:
                df[name] = np.nan

    @staticmethod
    def _as_result_list(results: Any) -> list:
        """Return pipeline output as a list; a single dict result is wrapped in one."""
        if isinstance(results, dict):
            return [results]
        return list(results)

    def get_inference(self, df: pd.DataFrame) -> pd.DataFrame:
        """Zero-shot classify the text column of ``df`` against ``self.labels``.

        Writes the top label to ``label_product`` and its score to
        ``score_product`` for rows with non-empty text; other rows are not written.
        Returns the same (mutated) DataFrame.
        """
        classifier = self.maybe_load_classifier(task="zero-shot-classification")
        logging.info("classifier loaded")
        column = self._prepare_text_column(df)
        not_empty_filter = (df[column] != "")
        inputs = df.loc[not_empty_filter, column].to_list()
        if not inputs:
            # Nothing to classify: skip the model call but keep the output schema.
            self._ensure_columns(df, ("label_product", "score_product"))
            return df
        results = self._as_result_list(
            classifier(inputs, self.labels, hypothesis_template=self.hypothesis_template))
        labels = [result['labels'][0] if result is not None else np.nan for result in results]
        scores = [result['scores'][0] if result is not None else np.nan for result in results]
        df.loc[not_empty_filter, "label_product"] = labels
        df.loc[not_empty_filter, "score_product"] = scores
        return df

    def get_inference_for_mmm(self, df: pd.DataFrame, bucket_name) -> pd.DataFrame:
        """Run the multi-modal model on rows that have an ``image_path``.

        Predictions are written to ``predicted_label`` for those rows. This relies
        on ``maybe_load_classifier`` returning the multi-modal model, i.e.
        presumably ``self.task == "multi-model"`` -- TODO confirm with callers.
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        classifier = self.maybe_load_classifier(task="zero-shot-classification")
        indices = df["image_path"].notna()
        inference_data = get_inference_data(df[indices], self.minio_client, bucket_name)
        predictions = run_inference(classifier, inference_data, device)
        df.loc[indices, 'predicted_label'] = predictions
        return df

    def get_inference_for_column(self, df: pd.DataFrame) -> pd.DataFrame:
        """Run the fine-tuned text-classification model over the text column.

        Writes ``label`` (``TEXT_LABELS_MAP`` class, NaN for unknown labels) and
        ``score`` for rows with non-empty text; other rows are not written.
        Returns the same (mutated) DataFrame.
        """
        # This case cannot be zero-shot.
        classifier = self.maybe_load_classifier(task="text-classification")
        logging.info("classifier loaded")
        column = self._prepare_text_column(df)
        not_empty_filter = (df[column] != "")
        inputs = df.loc[not_empty_filter, column].to_list()
        if not inputs:
            self._ensure_columns(df, ("label", "score"))
            return df
        results = self._as_result_list(classifier(inputs))
        labels_map = self.TEXT_LABELS_MAP
        labels = [labels_map[result['label']] if result is not None and result['label'] in labels_map else np.nan
                  for result in results]
        scores = [result['score'] if result is not None else np.nan for result in results]
        df.loc[not_empty_filter, "label"] = labels
        df.loc[not_empty_filter, "score"] = scores
        logging.info(f"Inference completed for {len(labels)} rows")
        return df

    @staticmethod
    def _build_pipeline(task: Optional[str], model: str):
        """Create a Hugging Face pipeline on the first GPU if available, else CPU.

        The auth token is read from ``HUGGINGFACE_API_KEY``. It is None when the
        variable is unset rather than raising ``KeyError``.
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return pipeline(task,
                        model=model,
                        device=device,
                        use_auth_token=os.environ.get("HUGGINGFACE_API_KEY"))

    def _load_multimodal_model(self):
        """Build ``MultiModalModel`` and load its weights (from MinIO when a client is set)."""
        loaded_model = MultiModalModel(num_labels=2)
        if self.minio_client:
            model_load_path = './model.pth'
            self.minio_client.get_model("multimodal", "model.pth", model_load_path)
        else:
            model_load_path = './model/model.pth'
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
        # map_location places tensors directly on the target device.
        state_dict = torch.load(model_load_path, map_location=device)
        loaded_model.load_state_dict(state_dict, strict=False)
        # Inference only: evaluation mode, then move to the device.
        loaded_model.eval()
        return loaded_model.to(device)

    def maybe_load_classifier(self, task: Optional[str]):
        """Return the classifier for the configured task, loading and caching it on first use.

        When ``self.task == "both"``, ``task`` selects the pipeline:
        "text-classification" uses ``DEFAULT_TEXT_MODEL`` and anything else uses
        ``DEFAULT_ZERO_SHOT_MODEL``. Each is cached separately. For all other
        modes ``task`` is ignored and ``self.task`` decides (``self.model``
        overrides the default model id).

        Raises:
            ValueError: if ``self.task`` is not a supported value.
        """
        if self.task == "both":
            if task not in self._pipelines:
                model = self.DEFAULT_TEXT_MODEL if task == "text-classification" else self.DEFAULT_ZERO_SHOT_MODEL
                self._pipelines[task] = self._build_pipeline(task, model)
            return self._pipelines[task]
        if self.classifier is not None:
            return self.classifier
        if self.task == "text-classification":
            self.classifier = self._build_pipeline(self.task, self.model or self.DEFAULT_TEXT_MODEL)
        elif self.task == "zero-shot-classification":
            self.classifier = self._build_pipeline(self.task, self.model or self.DEFAULT_ZERO_SHOT_MODEL)
        elif self.task == "multi-model":
            self.classifier = self._load_multimodal_model()
        else:
            raise ValueError(f"Task \"{self.task}\" is not available")
        return self.classifier

    @staticmethod
    def get_label(x):
        """Pick one label from a row with ``label_product`` / ``label_description``.

        Returns None when both labels are truthy. When exactly one is truthy, the
        *other* one is returned. Otherwise the label with the higher score wins,
        ties going to ``label_product``.
        """
        # NOTE(review): the single-truthy branches return the opposite column's
        # label, which looks inverted -- confirm the intended semantics before use.
        if x["label_product"] and x["label_description"]:
            return None
        elif x["label_product"]:
            return x["label_description"]
        elif x["label_description"]:
            return x["label_product"]
        else:
            if x["score_description"] > x["score_product"]:
                return x["label_description"]
            else:
                return x["label_product"]