Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 116 additions & 6 deletions django_grpc/management/commands/grpcserver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import datetime
import asyncio
import signal
import threading
import time

from django.core.management.base import BaseCommand
from django.utils import autoreload
Expand All @@ -13,6 +16,13 @@ class Command(BaseCommand):
help = "Run gRPC server"
config = getattr(settings, "GRPCSERVER", dict())

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# State management for graceful shutdown
self._shutdown_event = threading.Event()
self._server = None
self._original_sigterm_handler = None

def add_arguments(self, parser):
parser.add_argument("--max_workers", type=int, help="Number of workers")
parser.add_argument("--port", type=int, default=50051, help="Port number to listen")
Expand Down Expand Up @@ -40,49 +50,141 @@ def handle(self, *args, **options):
else:
self._serve(**options)

def _setup_signal_handlers(self):
"""Setup signal handlers (inspired by Gunicorn arbiter.py)"""
# Store SIGTERM handler
self._original_sigterm_handler = signal.signal(signal.SIGTERM, self._handle_sigterm)

# Also set SIGINT handler (Ctrl+C)
signal.signal(signal.SIGINT, self._handle_sigterm)

self.stdout.write("Signal handlers registered for graceful shutdown")
Comment on lines +53 to +61
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added _setup_signal_handlers() method to register signal handlers


def _handle_sigterm(self, signum, frame):
"""Handle SIGTERM signal to start graceful shutdown"""
self.stdout.write(f"Received signal {signum}. Starting graceful shutdown...")
self._shutdown_event.set()
Comment on lines +63 to +66
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added _handle_sigterm() method to handle shutdown signals

Sets the shutdown event when SIGTERM or SIGINT is received, triggering the graceful shutdown process.


def _graceful_shutdown(self, server):
"""Gracefully shutdown the server"""
try:
# Stop accepting new connections
self.stdout.write("Stopping server from accepting new connections...")

# Stop gRPC server (with grace=True to wait for ongoing requests to complete)
if hasattr(server, 'stop'):
# For synchronous server
server.stop(grace=True)
else:
# For asynchronous server
asyncio.create_task(server.stop(grace=True))

# Send Django signal
grpc_shutdown.send(None)

self.stdout.write("Graceful shutdown completed")

except Exception as e:
self.stderr.write(f"Error during graceful shutdown: {e}")
Comment on lines +68 to +88
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added _graceful_shutdown() and _graceful_shutdown_async() methods for server cleanup

Handles the actual server shutdown process, stopping new connections, waiting for ongoing requests to complete, and sending Django signals.


async def _graceful_shutdown_async(self, server):
"""Gracefully shutdown the async server"""
try:
# Stop accepting new connections
self.stdout.write("Stopping async server from accepting new connections...")

# Stop gRPC async server
await server.stop(grace=True)

# Send Django signal
grpc_shutdown.send(None)

self.stdout.write("Async graceful shutdown completed")

except Exception as e:
self.stderr.write(f"Error during async graceful shutdown: {e}")

def _serve(self, max_workers, port, *args, **kwargs):
"""
Run gRPC server
"""
autoreload.raise_last_exception()
self.stdout.write("gRPC server starting at %s" % datetime.datetime.now())

# Only setup signal handlers when not in autoreload mode
# autoreload runs in a separate thread, not the main thread, so signal handlers cannot be registered
if not kwargs.get("autoreload", False):
self._setup_signal_handlers()

server = create_server(max_workers, port)
self._server = server

server.start()

self.stdout.write("gRPC server is listening port %s" % port)

if kwargs["list_handlers"] is True:
# Print handler list if list_handlers option is enabled (default: False)
if kwargs.get("list_handlers", False):
self.stdout.write("Registered handlers:")
for handler in extract_handlers(server):
self.stdout.write("* %s" % handler)

server.wait_for_termination()
# Send shutdown signal to all connected receivers
grpc_shutdown.send(None)
# Only execute graceful shutdown logic when not in autoreload mode
if not kwargs.get("autoreload", False):
# Wait loop for graceful shutdown
try:
while not self._shutdown_event.is_set():
time.sleep(0.1)
except KeyboardInterrupt:
self.stdout.write("Received keyboard interrupt, starting graceful shutdown...")
self._shutdown_event.set()

# Perform graceful shutdown
self._graceful_shutdown(server)
else:
# Use original wait_for_termination for autoreload mode
server.wait_for_termination()
# Send shutdown signal to all connected receivers
grpc_shutdown.send(None)
Comment on lines 107 to +148
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified _serve() and _serve_async() methods to include graceful shutdown logic

Integrates signal handling and graceful shutdown logic into the main server serving methods


def _serve_async(self, max_workers, port, *args, **kwargs):
"""
Run gRPC server in async mode
"""
self.stdout.write("gRPC async server starting at %s" % datetime.datetime.now())

# Only setup signal handlers when not in autoreload mode
# autoreload runs in a separate thread, not the main thread, so signal handlers cannot be registered
if not kwargs.get("autoreload", False):
self._setup_signal_handlers()

# Coroutines to be invoked when the event loop is shutting down.
_cleanup_coroutines = []

server = create_server(max_workers, port)
self._server = server

async def _main_routine():
await server.start()
self.stdout.write("gRPC async server is listening port %s" % port)

if kwargs["list_handlers"] is True:
# Print handler list if list_handlers option is enabled (default: False)
if kwargs.get("list_handlers", False):
self.stdout.write("Registered handlers:")
for handler in extract_handlers(server):
self.stdout.write("* %s" % handler)

await server.wait_for_termination()
# Only execute graceful shutdown logic when not in autoreload mode
if not kwargs.get("autoreload", False):
# Wait for graceful shutdown
while not self._shutdown_event.is_set():
await asyncio.sleep(0.1)

# Perform graceful shutdown
await self._graceful_shutdown_async(server)
else:
# Use original wait_for_termination for autoreload mode
await server.wait_for_termination()

async def _graceful_shutdown():
# Send the signal to all connected receivers on server shutdown.
Expand All @@ -92,6 +194,14 @@ async def _graceful_shutdown():
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(_main_routine())
except KeyboardInterrupt:
if not kwargs.get("autoreload", False):
self.stdout.write("Received keyboard interrupt, starting graceful shutdown...")
self._shutdown_event.set()
loop.run_until_complete(_main_routine())
else:
# Ignore KeyboardInterrupt in autoreload mode and exit normally
pass
finally:
loop.run_until_complete(*_cleanup_coroutines)
loop.close()
173 changes: 173 additions & 0 deletions tests/test_graceful_shutdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import os
import signal
import subprocess
import time
import pytest
from django.test import TestCase
from django.core.management import call_command
from django.core.management.base import CommandError
from unittest.mock import patch, MagicMock


class GracefulShutdownTestCase(TestCase):
"""Test class for graceful shutdown functionality"""

def setUp(self):
"""Test setup"""
super().setUp()
self.port = 50052 # Test port

def test_signal_handler_registration(self):
"""Test that signal handlers are properly registered"""
from django_grpc.management.commands.grpcserver import Command

command = Command()

# Check state before signal handler setup
self.assertIsNone(command._original_sigterm_handler)

# Setup signal handlers
with patch('signal.signal') as mock_signal:
command._setup_signal_handlers()

# Verify signal.signal was called twice (SIGTERM, SIGINT)
self.assertEqual(mock_signal.call_count, 2)

# Verify SIGTERM handler was saved
self.assertIsNotNone(command._original_sigterm_handler)

def test_sigterm_handler(self):
"""Test that SIGTERM handler works correctly"""
from django_grpc.management.commands.grpcserver import Command

command = Command()

# Check initial state
self.assertFalse(command._shutdown_event.is_set())

# Call SIGTERM handler
command._handle_sigterm(signal.SIGTERM, None)

# Verify shutdown event was set
self.assertTrue(command._shutdown_event.is_set())

@patch('django_grpc.management.commands.grpcserver.create_server')
def test_graceful_shutdown_sync_server(self, mock_create_server):
"""Test graceful shutdown for synchronous server"""
from django_grpc.management.commands.grpcserver import Command

# Create mock server
mock_server = MagicMock()
mock_create_server.return_value = mock_server

command = Command()

# Call graceful shutdown
command._graceful_shutdown(mock_server)

# Verify server's stop method was called with grace=True
mock_server.stop.assert_called_once_with(grace=True)

@patch('django_grpc.management.commands.grpcserver.create_server')
def test_graceful_shutdown_async_server(self, mock_create_server):
"""Test graceful shutdown for asynchronous server"""
from django_grpc.management.commands.grpcserver import Command

# Create mock server (without stop method)
mock_server = MagicMock()
del mock_server.stop
mock_create_server.return_value = mock_server

command = Command()

# Call graceful shutdown
command._graceful_shutdown(mock_server)

# Verify asyncio.create_task was called
# (Actually difficult to verify through mock, so only check exception handling)

def test_command_initialization(self):
"""Test that Command initialization is correct"""
from django_grpc.management.commands.grpcserver import Command

command = Command()

# Check initial state
self.assertIsNotNone(command._shutdown_event)
self.assertIsNone(command._server)
self.assertIsNone(command._original_sigterm_handler)

@patch('django_grpc.management.commands.grpcserver.create_server')
@patch('django_grpc.management.commands.grpcserver.signal.signal')
def test_serve_method_signal_setup(self, mock_signal, mock_create_server):
"""Test that signal handlers are set up in _serve method"""
from django_grpc.management.commands.grpcserver import Command

# Create mock server
mock_server = MagicMock()
mock_create_server.return_value = mock_server

command = Command()

# Call _serve method (actually infinite loop, so test partially)
with patch.object(command, '_setup_signal_handlers') as mock_setup:
with patch.object(command, '_graceful_shutdown'):
# Set shutdown event in advance to exit loop
command._shutdown_event.set()
command._serve(max_workers=1, port=self.port)

# Verify signal handler setup was called
mock_setup.assert_called_once()


class GracefulShutdownIntegrationTestCase(TestCase):
"""Integration test: graceful shutdown test with actual process"""

def setUp(self):
"""Test setup"""
super().setUp()
self.port = 50053 # Integration test port

@pytest.mark.skipif(
os.name == 'nt', # Skip on Windows due to different signal handling
reason="Skip on Windows due to different signal handling"
)
def test_sigterm_integration(self):
"""Test graceful shutdown by sending actual SIGTERM signal"""
# This test is an integration test that starts an actual process and sends SIGTERM.
# Should only run in actual environment.

# Django settings for testing
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.settings')

# Start process
process = subprocess.Popen([
'python', 'manage.py', 'grpcserver',
'--port', str(self.port),
'--max_workers', '1'
], stdout=subprocess.PIPE, stderr=subprocess.PIPE)

try:
# Wait for server to start
time.sleep(2)

# Send SIGTERM signal
process.send_signal(signal.SIGTERM)

# Wait for graceful shutdown
process.wait(timeout=10)

# Verify process terminated normally
self.assertEqual(process.returncode, 0)

except subprocess.TimeoutExpired:
# Force kill process if timeout occurs
process.kill()
process.wait()
self.fail("Graceful shutdown timed out")

finally:
# Force kill if process is still running
if process.poll() is None:
process.kill()
process.wait()