Skip to content

Commit 70e2f53

Browse files
committed
feat: add ml training within Django admin
1 parent 4596744 commit 70e2f53

32 files changed

+1134
-5
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ docs/_build/
5757
target/
5858
Python.gitignore
5959
venv/
60+
.venv
6061

6162
# Notepad++ backups #
6263
*.bak

Dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ FROM python:${PYTHON_VERSION}-slim-bookworm AS prod
88

99
ENV PYTHONUNBUFFERED 1
1010
ENV DJANGO_SETTINGS_MODULE=signals.settings
11+
ENV NLTK_DOWNLOAD_DIR=/tmp/nltk_data
1112
ARG DJANGO_SECRET_KEY=insecure_docker_build_key
1213

1314
WORKDIR /app
@@ -44,9 +45,11 @@ RUN set -eux; \
4445
rm -rf /var/lib/apt/lists/*
4546

4647
COPY app/requirements /app/requirements
48+
COPY app/signals/apps/classification/requirements.txt /app/signals/apps/classification/requirements.txt
4749

4850
RUN set -eux; \
4951
pip install --no-cache -r /app/requirements/requirements.txt; \
52+
pip install --no-cache -r /app/signals/apps/classification/requirements.txt; \
5053
pip install --no-cache tox; \
5154
chgrp signals /app; \
5255
chmod g+w /app; \

app/signals/apps/api/urls.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@
9393
# Status message search
9494
re_path(r'v1/private/status-messages/search/?$', StatusMessageSearchView.as_view(), name='status-message-search'),
9595

96-
# Legacy prediction proxy endpoint, still needed
97-
path('category/prediction', LegacyMlPredictCategoryView.as_view(), name='ml-tool-predict-proxy'),
96+
# # Legacy prediction proxy endpoint, still needed
97+
# path('category/prediction', LegacyMlPredictCategoryView.as_view(), name='ml-tool-predict-proxy'),
98+
path('', include('signals.apps.classification.urls')),
9899

99100
# The base routes of the API
100101
path('v1/', include(base_router.urls)),

app/signals/apps/classification/__init__.py

Whitespace-only changes.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from django.contrib import admin
2+
3+
from signals.apps.classification.admin.admins import TrainingSetAdmin, ClassifierAdmin
4+
from signals.apps.classification.models import TrainingSet
5+
from signals.apps.classification.models.classifier import Classifier
6+
7+
admin.site.register(TrainingSet, TrainingSetAdmin)
8+
admin.site.register(Classifier, ClassifierAdmin)
9+
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
import pandas as pd
2+
from django.contrib import admin, messages
3+
from django import forms
4+
from django.db.models import F
5+
from django.http import FileResponse, HttpResponse
6+
from django.urls import reverse, path
7+
from django.utils.html import format_html
8+
9+
from signals.apps.classification.models import Classifier
10+
from signals.apps.classification.tasks import train_classifier
11+
import openpyxl
12+
13+
from signals.apps.signals import workflow
14+
from signals.apps.signals.models import Category, Signal
15+
16+
17+
class RunTrainingForm(admin.helpers.ActionForm):
18+
use_signals_in_database_for_training = forms.ChoiceField(
19+
choices=((False, "Nee"), (True, "Ja")),
20+
label='Neem meldingen uit Signalen mee',
21+
required=False
22+
)
23+
24+
25+
class TrainingSetAdmin(admin.ModelAdmin):
26+
list_display = ('name', 'file', )
27+
actions = ["run_training_with_training_set"]
28+
action_form = RunTrainingForm
29+
30+
@admin.action(description="Train model met geselecteerde dataset")
31+
def run_training_with_training_set(self, request, queryset):
32+
"""
33+
Run validation, if validation fails show an error message.
34+
35+
First we validate if there are no missing columns (Main, Sub and Text column are required), after this we check if there is atleast one row of data (next
36+
to the headers)
37+
"""
38+
training_set_ids = []
39+
use_signals_in_database_for_training = request.POST['use_signals_in_database_for_training']
40+
41+
for training_set in queryset:
42+
file = training_set.file
43+
44+
wb = openpyxl.load_workbook(file)
45+
first_sheet = wb.active
46+
47+
# Check if there are any missing columns
48+
headers = [cell.value for cell in first_sheet[1]]
49+
required_columns = ["Main", "Sub", "Text"]
50+
missing_columns = [col for col in required_columns if col not in headers]
51+
52+
if missing_columns:
53+
self.message_user(
54+
request,
55+
f"Training set { training_set.name } is missing required columns: {', '.join(missing_columns)}",
56+
messages.ERROR,
57+
)
58+
59+
return
60+
61+
# Check if the training set contains any data rows
62+
data_rows = list(first_sheet.iter_rows(min_row=2, values_only=True))
63+
if not any(data_rows):
64+
self.message_user(
65+
request,
66+
f"The training set { training_set.name } does not contain any data rows.",
67+
messages.ERROR
68+
)
69+
return
70+
71+
# Check if there are no sub categories present in the training set that are not present in the database
72+
sub_col_index = headers.index("Sub")
73+
subcategory_values = {row[sub_col_index] for row in data_rows if row[sub_col_index]}
74+
existing_subcategories = set(Category.objects.filter(name__in=subcategory_values).values_list('name', flat=True))
75+
missing_subcategories = subcategory_values - existing_subcategories
76+
77+
if missing_subcategories:
78+
self.message_user(
79+
request,
80+
f"The training set {training_set.name} contains unknown sub categories: {', '.join(missing_subcategories)}. Add these to Signalen before continuing.",
81+
messages.ERROR
82+
)
83+
return
84+
85+
# Check if there are no main categories present in the training set that are not present in the database
86+
main_col_index = headers.index("Main")
87+
maincategory_values = {row[main_col_index] for row in data_rows if row[main_col_index]}
88+
existing_maincategories = set(
89+
Category.objects.filter(name__in=maincategory_values).values_list('name', flat=True))
90+
missing_maincategories = maincategory_values - existing_maincategories
91+
92+
if missing_maincategories:
93+
self.message_user(
94+
request,
95+
f"The training set {training_set.name} contains unknown main categories: {', '.join(missing_maincategories)}. Add these to Signalen before continuing.",
96+
messages.ERROR
97+
)
98+
return
99+
100+
training_set_ids.append(training_set.id)
101+
102+
# Training will fail if any subcategory or main category appears in only one signal,
103+
# when use_signals_in_database_for_training is set to True
104+
if use_signals_in_database_for_training and use_signals_in_database_for_training != "False":
105+
signals = Signal.objects.filter(
106+
status__state=workflow.AFGEHANDELD,
107+
category_assignment__category__is_active=True,
108+
category_assignment__category__parent__is_active=True
109+
).exclude(
110+
category_assignment__category__slug="overig",
111+
category_assignment__category__parent__slug="overig"
112+
).values(
113+
'text',
114+
sub_category=F('category_assignment__category__name'),
115+
main_category=F('category_assignment__category__parent__name'),
116+
)
117+
118+
data = [{
119+
"Sub": signal["sub_category"],
120+
"Main": signal["main_category"],
121+
"Text": signal["text"]
122+
} for signal in signals]
123+
124+
signals_df = pd.DataFrame(data)
125+
126+
sub_counts = signals_df['Sub'].value_counts()
127+
main_counts = signals_df['Main'].value_counts()
128+
129+
sub_issues = sub_counts[sub_counts == 1]
130+
main_issues = main_counts[main_counts == 1]
131+
132+
if sub_issues.any() or main_issues.any():
133+
parts = []
134+
135+
if not sub_issues.empty:
136+
sub_list = ", ".join(map(str, sub_issues.index))
137+
parts.append(f"sub categories: {sub_list}")
138+
139+
if not main_issues.empty:
140+
main_list = ", ".join(map(str, main_issues.index))
141+
parts.append(f"main categories: {main_list}")
142+
143+
message = (
144+
"The database contains not the minimum of two signals with " +
145+
" and ".join(parts)
146+
)
147+
148+
self.message_user(
149+
request,
150+
message,
151+
messages.ERROR
152+
)
153+
return
154+
155+
train_classifier.delay(training_set_ids, use_signals_in_database_for_training)
156+
157+
self.message_user(
158+
request,
159+
"Training of the model has been initiated. This can take a few minutes.",
160+
messages.SUCCESS,
161+
)
162+
163+
164+
class ClassifierAdmin(admin.ModelAdmin):
165+
"""
166+
Creating or disabling classifiers by hand in the Admin interface is disabled,
167+
168+
a successful training job should create his own classifier object.
169+
"""
170+
list_display = ('name', 'precision', 'recall', 'accuracy', 'is_active', )
171+
actions = ["activate_classifier"]
172+
readonly_fields = ('training_status', 'training_error', 'download_main_confusion_matrix', 'download_sub_confusion_matrix',)
173+
174+
@admin.action(description="Maak deze classifier actief")
175+
def activate_classifier(self, request, queryset):
176+
"""
177+
Make the chosen classifier active, disable other classifiers
178+
"""
179+
180+
if queryset.count() > 1:
181+
self.message_user(
182+
request,
183+
"You can only make one classifier active.",
184+
messages.ERROR
185+
)
186+
return
187+
188+
try:
189+
Classifier.objects.update(is_active=False)
190+
Classifier.objects.filter(id=queryset.first().id).update(is_active=True)
191+
192+
self.message_user(
193+
request,
194+
f"Classifier { queryset.first().name } has been activated.",
195+
messages.SUCCESS
196+
)
197+
except Exception:
198+
self.message_user(
199+
request,
200+
f"Classifier { queryset.first().name } has not been activated.",
201+
messages.ERROR
202+
)
203+
204+
fieldsets = (
205+
(None, {
206+
'fields': (
207+
'name',
208+
'download_main_confusion_matrix',
209+
'download_sub_confusion_matrix',
210+
'precision',
211+
'recall',
212+
'accuracy',
213+
'is_active',
214+
'training_status',
215+
'training_error',
216+
)
217+
}),
218+
)
219+
220+
def download_main_confusion_matrix(self, obj):
221+
if obj.main_confusion_matrix:
222+
url = reverse('admin:classification_classifier_download', args=[obj.pk, 'main_confusion_matrix'])
223+
224+
return format_html(
225+
'<a href="{}" class="button" style="padding:6px 12px; background:#007bff; color:white; border-radius:4px;">Download main confusion matrix</a>',
226+
url
227+
)
228+
return "No file found"
229+
230+
def download_sub_confusion_matrix(self, obj):
231+
if obj.sub_confusion_matrix:
232+
url = reverse('admin:classification_classifier_download', args=[obj.pk, 'sub_confusion_matrix'])
233+
234+
return format_html(
235+
'<a href="{}" class="button" style="padding:6px 12px; background:#007bff; color:white; border-radius:4px;">Download sub confusion matrix</a>',
236+
url
237+
)
238+
return "No file found"
239+
240+
def get_urls(self):
241+
urls = super().get_urls()
242+
my_urls = [
243+
path('<path:object_id>/download/<str:field_name>/',
244+
self.admin_site.admin_view(self.download_file),
245+
name='classification_classifier_download'),
246+
]
247+
return my_urls + urls
248+
249+
def download_file(self, request, object_id, field_name):
250+
obj = self.get_object(request, object_id)
251+
file_field = getattr(obj, field_name)
252+
if file_field:
253+
response = FileResponse(file_field.open('rb'))
254+
response['Content-Disposition'] = f'attachment; filename="{file_field.name.split("/")[-1]}"'
255+
return response
256+
return HttpResponse("File not found", status=404)
257+
258+
def has_add_permission(self, request):
259+
return False
260+
261+
def has_change_permission(self, request, obj=None):
262+
return False
263+
264+
def has_delete_permission(self, request, obj=None):
265+
return True
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from django.apps import AppConfig
2+
3+
4+
class ClassificationConfig(AppConfig):
5+
name = 'signals.apps.classification'
6+
verbose_name = 'Classificatie management'

app/signals/apps/classification/management/__init__.py

Whitespace-only changes.

app/signals/apps/classification/management/commands/__init__.py

Whitespace-only changes.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from django.core.management.base import BaseCommand, CommandError
2+
3+
from signals.apps.classification.models import TrainingSet
4+
from signals.apps.classification.tasks import train_classifier
5+
6+
class Command(BaseCommand):
7+
help = "Train specific model"
8+
9+
def add_arguments(self, parser):
10+
parser.add_argument("training_set_id", type=int)
11+
12+
def handle(self, *args, **options):
13+
try:
14+
training_set = TrainingSet.objects.get(pk=options["training_set_id"])
15+
except TrainingSet.DoesNotExist:
16+
raise CommandError('Training Set "%s" does not exist' % options["training_set_id"])
17+
18+
train_classifier(training_set.id)
19+

0 commit comments

Comments
 (0)