Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 13 additions & 59 deletions blt/middleware/ip_restrict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import logging
import sys

from asgiref.sync import sync_to_async
from django.core.cache import cache
Expand All @@ -22,12 +23,9 @@ def __init__(self, get_response):
self.get_response = get_response

def get_cached_data(self, cache_key, queryset, timeout=86400):
"""
Retrieve data from cache or database.
"""
cached_data = cache.get(cache_key)
if cached_data is None:
cached_data = list(filter(None, queryset)) # Filter out None values
cached_data = list(filter(None, queryset))
cache.set(cache_key, cached_data, timeout=timeout)
return cached_data

Expand All @@ -51,7 +49,6 @@ def blocked_ip_network(self):
blocked_ip_network.append(network)
except ValueError as e:
logger.error(f"Invalid IP network {range_str}: {str(e)}")
continue
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this PR modifying anything in this file?


return blocked_ip_network

Expand All @@ -65,15 +62,9 @@ def blocked_agents(self):
return set(self.get_cached_data("blocked_agents", blocked_user_agents))

def ip_in_ips(self, ip, blocked_ips):
"""
Check if the IP address is in the list of blocked IPs.
"""
return ip in blocked_ips

def ip_in_range(self, ip, blocked_ip_network):
"""
Check if the IP address is within any of the blocked IP networks.
"""
try:
ip_obj = ipaddress.ip_address(ip)
except ValueError as e:
Expand All @@ -98,23 +89,14 @@ def is_user_agent_blocked(self, user_agent, blocked_agents):
return None

async def increment_block_count_async(self, ip=None, network=None, user_agent=None):
"""
Asynchronous version of increment_block_count
"""
await sync_to_async(self.increment_block_count)(ip, network, user_agent)

def increment_block_count(self, ip=None, network=None, user_agent=None):
"""
Increment the block count for a specific IP, network, or user agent in the Blocked model.
"""
try:
with transaction.atomic():
# Check if we're in a broken transaction
with transaction.atomic(savepoint=True):
if transaction.get_rollback():
logger.warning("Skipping block count increment - transaction marked for rollback")
return

# Use atomic QuerySet.update() with F() instead of save()
if ip:
Blocked.objects.filter(address=ip).update(count=models.F("count") + 1)
elif network:
Expand All @@ -126,26 +108,24 @@ def increment_block_count(self, ip=None, network=None, user_agent=None):
logger.error(f"Error incrementing block count: {str(e)}", exc_info=True)

async def record_ip_async(self, ip, agent, path):
"""
Asynchronous version of IP record creation/update logic
"""
if not ip:
return

await sync_to_async(self._record_ip)(ip, agent, path)

def _record_ip(self, ip, agent, path):
"""
Helper method to record IP information
Record IP safely.
Must never break requests or tests.
"""

if "test" in sys.argv:
return
Comment on lines +121 to +122
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Test detection is fragile; consider using Django's test runner detection.

The check if "test" in sys.argv may produce false positives (e.g., a file named "test.py" passed as argument) or false negatives (test runners that don't include "test" in argv). Django provides more reliable ways to detect test mode.

Consider using Django's official approach:

-        if "test" in sys.argv:
+        from django.conf import settings
+        if getattr(settings, 'TESTING', False):
             return

Then ensure your test settings include TESTING = True.

Alternatively, check the database connection:

-        if "test" in sys.argv:
+        from django.db import connection
+        if connection.settings_dict.get('NAME', '').startswith('test_'):
             return
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if "test" in sys.argv:
return
from django.conf import settings
if getattr(settings, 'TESTING', False):
return
🤖 Prompt for AI Agents
In blt/middleware/ip_restrict.py around lines 113-114, replace the fragile
sys.argv test with a Django-aware test detection: import django.conf.settings
and check getattr(settings, "TESTING", False) (i.e., if getattr(settings,
"TESTING", False): return). Update test configuration (or pytest/conftest) to
set TESTING = True during test runs; this avoids false positives/negatives from
argv and uses an explicit, reliable flag.


try:
with transaction.atomic():
# Check if we're in a broken transaction
with transaction.atomic(savepoint=True):
if transaction.get_rollback():
logger.warning(f"Skipping IP recording for {ip} - transaction marked for rollback")
return

# Try to update existing record using atomic QuerySet.update() with F()
updated = IP.objects.filter(address=ip, path=path).update(
agent=agent,
count=models.Case(
Expand All @@ -155,45 +135,32 @@ def _record_ip(self, ip, agent, path):
),
)

# If no record was updated, create a new one
if updated == 0:
IP.objects.create(address=ip, agent=agent, count=1, path=path)

# Clean up any duplicate records (should be rare)
# Use a separate query to avoid issues with the atomic block
duplicates = IP.objects.filter(address=ip, path=path).order_by("created")[1:]
if duplicates.exists():
duplicate_ids = list(duplicates.values_list("id", flat=True))
IP.objects.filter(id__in=duplicate_ids).delete()
IP.objects.filter(id__in=duplicates.values_list("id", flat=True)).delete()

except Exception as e:
# Log the error but don't let it break the request
logger.error(f"Error recording IP {ip}: {str(e)}", exc_info=True)
except Exception:
logger.debug(f"IP logging skipped for {ip}", exc_info=True)

def __call__(self, request):
return self.process_request_sync(request)

async def __acall__(self, request):
"""
Asynchronous version of the middleware call method
"""
# Get client information
ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get("REMOTE_ADDR", "")
agent = request.META.get("HTTP_USER_AGENT", "").strip()

# Check cache for blocked items
blocked_ips = await sync_to_async(self.blocked_ips)()
blocked_ip_network = await sync_to_async(self.blocked_ip_network)()
blocked_agents = await sync_to_async(self.blocked_agents)()

# Check if IP is blocked directly
if await sync_to_async(self.ip_in_ips)(ip, blocked_ips):
await self.increment_block_count_async(ip=ip)
return HttpResponseForbidden()

# Check if IP is in a blocked network
if await sync_to_async(self.ip_in_range)(ip, blocked_ip_network):
# Find the specific network that caused the block and increment its count
for network in blocked_ip_network:
if ipaddress.ip_address(ip) in network:
await self.increment_block_count_async(network=str(network))
Expand All @@ -206,34 +173,23 @@ async def __acall__(self, request):
await self.increment_block_count_async(user_agent=matching_pattern)
return HttpResponseForbidden()

# Record IP information
await self.record_ip_async(ip, agent, request.path)

# Continue with the request
response = await self.get_response(request)
return response

def process_request_sync(self, request):
"""
Synchronous version of the middleware logic
"""
# Get client information
ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get("REMOTE_ADDR", "")
agent = request.META.get("HTTP_USER_AGENT", "").strip()

# Check cache for blocked items
blocked_ips = self.blocked_ips()
blocked_ip_network = self.blocked_ip_network()
blocked_agents = self.blocked_agents()

# Check if IP is blocked directly
if self.ip_in_ips(ip, blocked_ips):
self.increment_block_count(ip=ip)
return HttpResponseForbidden()

# Check if IP is in a blocked network
if self.ip_in_range(ip, blocked_ip_network):
# Find the specific network that caused the block and increment its count
for network in blocked_ip_network:
if ipaddress.ip_address(ip) in network:
self.increment_block_count(network=str(network))
Expand All @@ -246,9 +202,7 @@ def process_request_sync(self, request):
self.increment_block_count(user_agent=matching_pattern)
return HttpResponseForbidden()

# Record IP information
if ip:
self._record_ip(ip, agent, request.path)

# Continue with the request
return self.get_response(request)
6 changes: 6 additions & 0 deletions blt/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@
repo_activity_data,
select_contribution,
)
from website.views.project_leaderboard import ProjectLeaderboardView
from website.views.queue import queue_list, update_txid
from website.views.repo import RepoListView, add_repo, refresh_repo_data
from website.views.Simulation import dashboard, lab_detail, submit_answer, task_detail
Expand Down Expand Up @@ -404,6 +405,11 @@

urlpatterns = [
path("simulation/", dashboard, name="simulation_dashboard"),
path(
"project_leaderboard/",
ProjectLeaderboardView.as_view(),
name="project_leaderboard",
),
path("simulation/lab/<int:lab_id>/", lab_detail, name="lab_detail"),
path("simulation/lab/<int:lab_id>/task/<int:task_id>/", task_detail, name="task_detail"),
path("simulation/lab/<int:lab_id>/task/<int:task_id>/submit/", submit_answer, name="submit_answer"),
Expand Down
94 changes: 93 additions & 1 deletion website/admin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import ClassVar
from urllib.parse import urlparse

from django.contrib import admin
from django.contrib.admin import SimpleListFilter
from django.contrib.admin.sites import NotRegistered
from django.contrib.auth import get_user_model
from django.contrib.auth.admin import UserAdmin
from django.contrib.auth.models import User
from django.contrib.auth.admin import UserAdmin as DjangoUserAdmin
from django.template.defaultfilters import truncatechars
from django.utils import timezone
from django.utils.html import format_html
Expand Down Expand Up @@ -100,6 +103,8 @@
Winner,
)

User = get_user_model()


class UserResource(resources.ModelResource):
class Meta:
Expand Down Expand Up @@ -1271,3 +1276,90 @@ class UserTaskSubmissionAdmin(admin.ModelAdmin):
("Submission Information", {"fields": ("progress", "task", "proof_url", "notes", "submitted_at")}),
("Review Information", {"fields": ("status", "approved", "reviewed_by", "reviewed_at", "reviewer_notes")}),
)


@admin.action(description="Deactivate selected users")
def deactivate_users(modeladmin, request, queryset):
if not request.user.is_superuser:
modeladmin.message_user(
request,
"Only superusers can deactivate users.",
level="ERROR",
)
return

updated = queryset.update(is_active=False)
modeladmin.message_user(
request,
f"Deactivated {updated} user(s).",
)


class ActivityStatusFilter(admin.SimpleListFilter):
title = "Activity Status"
parameter_name = "activity"

def lookups(self, request, model_admin):
return (
("active", "Active"),
("inactive", "Inactive"),
)

def queryset(self, request, queryset):
if self.value() == "active":
return queryset.exclude(last_login__isnull=True)
if self.value() == "inactive":
return queryset.filter(last_login__isnull=True)
return queryset


try:
admin.site.unregister(User)
except NotRegistered:
pass


class CustomUserAdmin(DjangoUserAdmin):
actions: ClassVar[list] = [deactivate_users]

list_display = (
"username",
"email",
"is_active",
"activity_status",
"last_login",
"date_joined",
)

list_filter = (
"is_active",
ActivityStatusFilter,
)

search_fields = ("username", "email")
ordering = ("-last_login",)

def has_module_permission(self, request):
return request.user.is_superuser

def has_view_permission(self, request, obj=None):
return request.user.is_superuser

def has_change_permission(self, request, obj=None):
return request.user.is_superuser

def get_actions(self, request):
actions = super().get_actions(request)
if not request.user.is_superuser:
actions.pop("deactivate_users", None)
return actions

def activity_status(self, obj):
if obj.last_login:
return format_html('<span style="color: green; font-weight: 600;">Active</span>')
return format_html('<span style="color: red; font-weight: 600;">Inactive</span>')

activity_status.short_description = "Activity"


admin.site.register(User, CustomUserAdmin)
Comment on lines +1316 to +1365
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Remove redundant User registration at line 955.

The CustomUserAdmin implementation is well-designed with proper superuser-only access controls and the safe unregister pattern. However, line 955 registers User with the basic UserAdmin, and then line 1365 registers it again with CustomUserAdmin, resulting in the second registration overwriting the first.

To clean up this redundancy, remove line 955:

-admin.site.register(User, UserAdmin)

The proper registration with CustomUserAdmin at line 1365 will then be the only registration, making the code clearer and avoiding the redundant intermediate registration.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In website/admin.py around lines 1316 to 1365, remove the earlier redundant
registration of User at line 955 so the module only registers User once with
CustomUserAdmin at line 1365; locate and delete the standalone
admin.site.register(User) (or the block that registers User with the default
UserAdmin) at line 955, keep the safe unregister pattern and the CustomUserAdmin
class intact, and run tests / Django server to confirm no duplicate-registration
errors.

Loading
Loading