Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.8] bpo-37193: remove thread objects which finished process its request (GH-13893) #23088

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions Lib/socketserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class will essentially render the service "deaf" while one request is
import os
import sys
import threading
import contextlib
from io import BufferedIOBase
from time import monotonic as time

Expand Down Expand Up @@ -628,6 +629,55 @@ def server_close(self):
self.collect_children(blocking=self.block_on_close)


class _Threads(list):
"""
Joinable list of all non-daemon threads.
"""
def __init__(self):
self._lock = threading.Lock()

def append(self, thread):
if thread.daemon:
return
with self._lock:
super().append(thread)

def remove(self, thread):
with self._lock:
# should not happen, but safe to ignore
with contextlib.suppress(ValueError):
super().remove(thread)

def remove_current(self):
"""Remove a current non-daemon thread."""
thread = threading.current_thread()
if not thread.daemon:
self.remove(thread)

def pop_all(self):
with self._lock:
self[:], result = [], self[:]
return result

def join(self):
for thread in self.pop_all():
thread.join()


class _NoThreads:
"""
Degenerate version of _Threads.
"""
def append(self, thread):
pass

def join(self):
pass

def remove_current(self):
pass


class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread."""

Expand All @@ -636,9 +686,9 @@ class ThreadingMixIn:
daemon_threads = False
# If true, server_close() waits until all non-daemonic threads terminate.
block_on_close = True
# For non-daemonic threads, list of threading.Threading objects
# Threads object
# used by server_close() to wait for all threads completion.
_threads = None
_threads = _NoThreads()

def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread.
Expand All @@ -651,27 +701,24 @@ def process_request_thread(self, request, client_address):
except Exception:
self.handle_error(request, client_address)
finally:
self.shutdown_request(request)
try:
self.shutdown_request(request)
finally:
self._threads.remove_current()

def process_request(self, request, client_address):
"""Start a new thread to process the request."""
if self.block_on_close:
vars(self).setdefault('_threads', _Threads())
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
t.daemon = self.daemon_threads
if not t.daemon and self.block_on_close:
if self._threads is None:
self._threads = []
self._threads.append(t)
self._threads.append(t)
t.start()

def server_close(self):
super().server_close()
if self.block_on_close:
threads = self._threads
self._threads = None
if threads:
for thread in threads:
thread.join()
self._threads.join()


if hasattr(os, "fork"):
Expand Down
24 changes: 24 additions & 0 deletions Lib/test/test_socketserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ class MyHandler(socketserver.StreamRequestHandler):
t.join()
s.server_close()

def test_close_immediately(self):
class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass

server = MyServer((HOST, 0), lambda: None)
server.server_close()

def test_tcpserver_bind_leak(self):
# Issue #22435: the server socket wouldn't be closed if bind()/listen()
# failed.
Expand Down Expand Up @@ -490,6 +497,23 @@ def shutdown_request(self, request):
self.assertEqual(server.shutdown_called, 1)
server.server_close()

def test_threads_reaped(self):
"""
In #37193, users reported a memory leak
due to the saving of every request thread. Ensure that the
threads are cleaned up after the requests complete.
"""
class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass

server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
for n in range(10):
with socket.create_connection(server.server_address):
server.handle_request()
[thread.join() for thread in server._threads]
self.assertEqual(len(server._threads), 0)
server.server_close()


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed memory leak in ``socketserver.ThreadingMixIn`` introduced in Python
3.7.