-
Notifications
You must be signed in to change notification settings - Fork 40
Feat/Add graceful shutdown support for grpc server #53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,8 @@ | ||
| import datetime | ||
| import asyncio | ||
| import signal | ||
| import threading | ||
| import time | ||
|
|
||
| from django.core.management.base import BaseCommand | ||
| from django.utils import autoreload | ||
|
|
@@ -13,6 +16,13 @@ class Command(BaseCommand): | |
| help = "Run gRPC server" | ||
| config = getattr(settings, "GRPCSERVER", dict()) | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| # State management for graceful shutdown | ||
| self._shutdown_event = threading.Event() | ||
| self._server = None | ||
| self._original_sigterm_handler = None | ||
|
|
||
| def add_arguments(self, parser): | ||
| parser.add_argument("--max_workers", type=int, help="Number of workers") | ||
| parser.add_argument("--port", type=int, default=50051, help="Port number to listen") | ||
|
|
@@ -40,49 +50,141 @@ def handle(self, *args, **options): | |
| else: | ||
| self._serve(**options) | ||
|
|
||
| def _setup_signal_handlers(self): | ||
| """Setup signal handlers (inspired by Gunicorn arbiter.py)""" | ||
| # Store SIGTERM handler | ||
| self._original_sigterm_handler = signal.signal(signal.SIGTERM, self._handle_sigterm) | ||
|
|
||
| # Also set SIGINT handler (Ctrl+C) | ||
| signal.signal(signal.SIGINT, self._handle_sigterm) | ||
|
|
||
| self.stdout.write("Signal handlers registered for graceful shutdown") | ||
|
|
||
| def _handle_sigterm(self, signum, frame): | ||
| """Handle SIGTERM signal to start graceful shutdown""" | ||
| self.stdout.write(f"Received signal {signum}. Starting graceful shutdown...") | ||
| self._shutdown_event.set() | ||
|
Comment on lines
+63
to
+66
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added Sets the shutdown event when SIGTERM or SIGINT is received, triggering the graceful shutdown process. |
||
|
|
||
| def _graceful_shutdown(self, server): | ||
| """Gracefully shutdown the server""" | ||
| try: | ||
| # Stop accepting new connections | ||
| self.stdout.write("Stopping server from accepting new connections...") | ||
|
|
||
| # Stop gRPC server (with grace=True to wait for ongoing requests to complete) | ||
| if hasattr(server, 'stop'): | ||
| # For synchronous server | ||
| server.stop(grace=True) | ||
| else: | ||
| # For asynchronous server | ||
| asyncio.create_task(server.stop(grace=True)) | ||
|
|
||
| # Send Django signal | ||
| grpc_shutdown.send(None) | ||
|
|
||
| self.stdout.write("Graceful shutdown completed") | ||
|
|
||
| except Exception as e: | ||
| self.stderr.write(f"Error during graceful shutdown: {e}") | ||
|
Comment on lines
+68
to
+88
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added Handles the actual server shutdown process, stopping new connections, waiting for ongoing requests to complete, and sending Django signals. |
||
|
|
||
| async def _graceful_shutdown_async(self, server): | ||
| """Gracefully shutdown the async server""" | ||
| try: | ||
| # Stop accepting new connections | ||
| self.stdout.write("Stopping async server from accepting new connections...") | ||
|
|
||
| # Stop gRPC async server | ||
| await server.stop(grace=True) | ||
|
|
||
| # Send Django signal | ||
| grpc_shutdown.send(None) | ||
|
|
||
| self.stdout.write("Async graceful shutdown completed") | ||
|
|
||
| except Exception as e: | ||
| self.stderr.write(f"Error during async graceful shutdown: {e}") | ||
|
|
||
| def _serve(self, max_workers, port, *args, **kwargs): | ||
| """ | ||
| Run gRPC server | ||
| """ | ||
| autoreload.raise_last_exception() | ||
| self.stdout.write("gRPC server starting at %s" % datetime.datetime.now()) | ||
|
|
||
| # Only setup signal handlers when not in autoreload mode | ||
| # autoreload runs in a separate thread, not the main thread, so signal handlers cannot be registered | ||
| if not kwargs.get("autoreload", False): | ||
| self._setup_signal_handlers() | ||
|
|
||
| server = create_server(max_workers, port) | ||
| self._server = server | ||
|
|
||
| server.start() | ||
|
|
||
| self.stdout.write("gRPC server is listening port %s" % port) | ||
|
|
||
| if kwargs["list_handlers"] is True: | ||
| # Print handler list if list_handlers option is enabled (default: False) | ||
| if kwargs.get("list_handlers", False): | ||
| self.stdout.write("Registered handlers:") | ||
| for handler in extract_handlers(server): | ||
| self.stdout.write("* %s" % handler) | ||
|
|
||
| server.wait_for_termination() | ||
| # Send shutdown signal to all connected receivers | ||
| grpc_shutdown.send(None) | ||
| # Only execute graceful shutdown logic when not in autoreload mode | ||
| if not kwargs.get("autoreload", False): | ||
| # Wait loop for graceful shutdown | ||
| try: | ||
| while not self._shutdown_event.is_set(): | ||
| time.sleep(0.1) | ||
| except KeyboardInterrupt: | ||
| self.stdout.write("Received keyboard interrupt, starting graceful shutdown...") | ||
| self._shutdown_event.set() | ||
|
|
||
| # Perform graceful shutdown | ||
| self._graceful_shutdown(server) | ||
| else: | ||
| # Use original wait_for_termination for autoreload mode | ||
| server.wait_for_termination() | ||
| # Send shutdown signal to all connected receivers | ||
| grpc_shutdown.send(None) | ||
|
Comment on lines
107
to
+148
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modified Integrates signal handling and graceful shutdown logic into the main server serving methods |
||
|
|
||
| def _serve_async(self, max_workers, port, *args, **kwargs): | ||
| """ | ||
| Run gRPC server in async mode | ||
| """ | ||
| self.stdout.write("gRPC async server starting at %s" % datetime.datetime.now()) | ||
|
|
||
| # Only setup signal handlers when not in autoreload mode | ||
| # autoreload runs in a separate thread, not the main thread, so signal handlers cannot be registered | ||
| if not kwargs.get("autoreload", False): | ||
| self._setup_signal_handlers() | ||
|
|
||
| # Coroutines to be invoked when the event loop is shutting down. | ||
| _cleanup_coroutines = [] | ||
|
|
||
| server = create_server(max_workers, port) | ||
| self._server = server | ||
|
|
||
| async def _main_routine(): | ||
| await server.start() | ||
| self.stdout.write("gRPC async server is listening port %s" % port) | ||
|
|
||
| if kwargs["list_handlers"] is True: | ||
| # Print handler list if list_handlers option is enabled (default: False) | ||
| if kwargs.get("list_handlers", False): | ||
| self.stdout.write("Registered handlers:") | ||
| for handler in extract_handlers(server): | ||
| self.stdout.write("* %s" % handler) | ||
|
|
||
| await server.wait_for_termination() | ||
| # Only execute graceful shutdown logic when not in autoreload mode | ||
| if not kwargs.get("autoreload", False): | ||
| # Wait for graceful shutdown | ||
| while not self._shutdown_event.is_set(): | ||
| await asyncio.sleep(0.1) | ||
|
|
||
| # Perform graceful shutdown | ||
| await self._graceful_shutdown_async(server) | ||
| else: | ||
| # Use original wait_for_termination for autoreload mode | ||
| await server.wait_for_termination() | ||
|
|
||
| async def _graceful_shutdown(): | ||
| # Send the signal to all connected receivers on server shutdown. | ||
|
|
@@ -92,6 +194,14 @@ async def _graceful_shutdown(): | |
| loop = asyncio.get_event_loop() | ||
| try: | ||
| loop.run_until_complete(_main_routine()) | ||
| except KeyboardInterrupt: | ||
| if not kwargs.get("autoreload", False): | ||
| self.stdout.write("Received keyboard interrupt, starting graceful shutdown...") | ||
| self._shutdown_event.set() | ||
| loop.run_until_complete(_main_routine()) | ||
| else: | ||
| # Ignore KeyboardInterrupt in autoreload mode and exit normally | ||
| pass | ||
| finally: | ||
| loop.run_until_complete(*_cleanup_coroutines) | ||
| loop.close() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| import os | ||
| import signal | ||
| import subprocess | ||
| import time | ||
| import pytest | ||
| from django.test import TestCase | ||
| from django.core.management import call_command | ||
| from django.core.management.base import CommandError | ||
| from unittest.mock import patch, MagicMock | ||
|
|
||
|
|
||
| class GracefulShutdownTestCase(TestCase): | ||
| """Test class for graceful shutdown functionality""" | ||
|
|
||
| def setUp(self): | ||
| """Test setup""" | ||
| super().setUp() | ||
| self.port = 50052 # Test port | ||
|
|
||
| def test_signal_handler_registration(self): | ||
| """Test that signal handlers are properly registered""" | ||
| from django_grpc.management.commands.grpcserver import Command | ||
|
|
||
| command = Command() | ||
|
|
||
| # Check state before signal handler setup | ||
| self.assertIsNone(command._original_sigterm_handler) | ||
|
|
||
| # Setup signal handlers | ||
| with patch('signal.signal') as mock_signal: | ||
| command._setup_signal_handlers() | ||
|
|
||
| # Verify signal.signal was called twice (SIGTERM, SIGINT) | ||
| self.assertEqual(mock_signal.call_count, 2) | ||
|
|
||
| # Verify SIGTERM handler was saved | ||
| self.assertIsNotNone(command._original_sigterm_handler) | ||
|
|
||
| def test_sigterm_handler(self): | ||
| """Test that SIGTERM handler works correctly""" | ||
| from django_grpc.management.commands.grpcserver import Command | ||
|
|
||
| command = Command() | ||
|
|
||
| # Check initial state | ||
| self.assertFalse(command._shutdown_event.is_set()) | ||
|
|
||
| # Call SIGTERM handler | ||
| command._handle_sigterm(signal.SIGTERM, None) | ||
|
|
||
| # Verify shutdown event was set | ||
| self.assertTrue(command._shutdown_event.is_set()) | ||
|
|
||
| @patch('django_grpc.management.commands.grpcserver.create_server') | ||
| def test_graceful_shutdown_sync_server(self, mock_create_server): | ||
| """Test graceful shutdown for synchronous server""" | ||
| from django_grpc.management.commands.grpcserver import Command | ||
|
|
||
| # Create mock server | ||
| mock_server = MagicMock() | ||
| mock_create_server.return_value = mock_server | ||
|
|
||
| command = Command() | ||
|
|
||
| # Call graceful shutdown | ||
| command._graceful_shutdown(mock_server) | ||
|
|
||
| # Verify server's stop method was called with grace=True | ||
| mock_server.stop.assert_called_once_with(grace=True) | ||
|
|
||
| @patch('django_grpc.management.commands.grpcserver.create_server') | ||
| def test_graceful_shutdown_async_server(self, mock_create_server): | ||
| """Test graceful shutdown for asynchronous server""" | ||
| from django_grpc.management.commands.grpcserver import Command | ||
|
|
||
| # Create mock server (without stop method) | ||
| mock_server = MagicMock() | ||
| del mock_server.stop | ||
| mock_create_server.return_value = mock_server | ||
|
|
||
| command = Command() | ||
|
|
||
| # Call graceful shutdown | ||
| command._graceful_shutdown(mock_server) | ||
|
|
||
| # Verify asyncio.create_task was called | ||
| # (Actually difficult to verify through mock, so only check exception handling) | ||
|
|
||
| def test_command_initialization(self): | ||
| """Test that Command initialization is correct""" | ||
| from django_grpc.management.commands.grpcserver import Command | ||
|
|
||
| command = Command() | ||
|
|
||
| # Check initial state | ||
| self.assertIsNotNone(command._shutdown_event) | ||
| self.assertIsNone(command._server) | ||
| self.assertIsNone(command._original_sigterm_handler) | ||
|
|
||
| @patch('django_grpc.management.commands.grpcserver.create_server') | ||
| @patch('django_grpc.management.commands.grpcserver.signal.signal') | ||
| def test_serve_method_signal_setup(self, mock_signal, mock_create_server): | ||
| """Test that signal handlers are set up in _serve method""" | ||
| from django_grpc.management.commands.grpcserver import Command | ||
|
|
||
| # Create mock server | ||
| mock_server = MagicMock() | ||
| mock_create_server.return_value = mock_server | ||
|
|
||
| command = Command() | ||
|
|
||
| # Call _serve method (actually infinite loop, so test partially) | ||
| with patch.object(command, '_setup_signal_handlers') as mock_setup: | ||
| with patch.object(command, '_graceful_shutdown'): | ||
| # Set shutdown event in advance to exit loop | ||
| command._shutdown_event.set() | ||
| command._serve(max_workers=1, port=self.port) | ||
|
|
||
| # Verify signal handler setup was called | ||
| mock_setup.assert_called_once() | ||
|
|
||
|
|
||
| class GracefulShutdownIntegrationTestCase(TestCase): | ||
| """Integration test: graceful shutdown test with actual process""" | ||
|
|
||
| def setUp(self): | ||
| """Test setup""" | ||
| super().setUp() | ||
| self.port = 50053 # Integration test port | ||
|
|
||
| @pytest.mark.skipif( | ||
| os.name == 'nt', # Skip on Windows due to different signal handling | ||
| reason="Skip on Windows due to different signal handling" | ||
| ) | ||
| def test_sigterm_integration(self): | ||
| """Test graceful shutdown by sending actual SIGTERM signal""" | ||
| # This test is an integration test that starts an actual process and sends SIGTERM. | ||
| # Should only run in actual environment. | ||
|
|
||
| # Django settings for testing | ||
| os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.settings') | ||
|
|
||
| # Start process | ||
| process = subprocess.Popen([ | ||
| 'python', 'manage.py', 'grpcserver', | ||
| '--port', str(self.port), | ||
| '--max_workers', '1' | ||
| ], stdout=subprocess.PIPE, stderr=subprocess.PIPE) | ||
|
|
||
| try: | ||
| # Wait for server to start | ||
| time.sleep(2) | ||
|
|
||
| # Send SIGTERM signal | ||
| process.send_signal(signal.SIGTERM) | ||
|
|
||
| # Wait for graceful shutdown | ||
| process.wait(timeout=10) | ||
|
|
||
| # Verify process terminated normally | ||
| self.assertEqual(process.returncode, 0) | ||
|
|
||
| except subprocess.TimeoutExpired: | ||
| # Force kill process if timeout occurs | ||
| process.kill() | ||
| process.wait() | ||
| self.fail("Graceful shutdown timed out") | ||
|
|
||
| finally: | ||
| # Force kill if process is still running | ||
| if process.poll() is None: | ||
| process.kill() | ||
| process.wait() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
_setup_signal_handlers()method to register signal handlers