| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524 |
- """
- Test suite for socketserver.
- """
- import contextlib
- import io
- import os
- import select
- import signal
- import socket
- import tempfile
- import threading
- import unittest
- import socketserver
- import test.support
- from test.support import reap_children, verbose
- from test.support import os_helper
- from test.support import socket_helper
- from test.support import threading_helper
- test.support.requires("network")
- test.support.requires_working_socket(module=True)
- TEST_STR = b"hello world\n"
- HOST = socket_helper.HOST
- HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
- requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
- 'requires Unix sockets')
- HAVE_FORKING = test.support.has_fork_support
- requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
def signal_alarm(n):
    """Forward *n* to signal.alarm() if the platform has it.

    On platforms without signal.alarm (e.g. Windows) this is a no-op.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
# Keep a reference to the genuine select() so that tests which mock
# select.select do not interfere with this helper.
_real_select = select.select


def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Return up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Raises RuntimeError if the socket does not become readable in time.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
# socketserver is not asked for forking Unix-domain servers here; they are
# assembled locally, and only where both Unix sockets and fork() are
# available.
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Run the with-block while an unrelated child process exists.

    Tests that a custom child process is not waited on (Issue 1540386):
    a child is forked that exits immediately with status 72, and when the
    block ends -- normally or through an exception -- that child is waited
    for here and its exit code is checked to be 72.
    """
    child_pid = os.fork()
    if child_pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        test.support.wait_process(child_pid, exitcode=72)
class SocketServerTest(unittest.TestCase):
    """Round-trip tests against each concrete socketserver server class."""

    def setUp(self):
        # Watchdog: kill deadlocks after 60 seconds.
        signal_alarm(60)
        self.port_seed = 0
        # Unix socket paths handed out by pickaddr(); removed in tearDown().
        self.test_files = []

    def tearDown(self):
        # Didn't deadlock: cancel the watchdog.
        signal_alarm(0)
        reap_children()

        # Best-effort removal; a path may never have been bound.
        for path in self.test_files:
            with contextlib.suppress(OSError):
                os.remove(path)
        self.test_files.clear()

    def pickaddr(self, proto):
        # Return an address suitable for binding a socket of family *proto*.
        if proto == socket.AF_INET:
            # Port 0 lets the OS choose a free port.
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        path = tempfile.mktemp(prefix='unix_socket.', dir=None)
        self.test_files.append(path)
        return path

    def make_server(self, addr, svrcls, hdlrbase):
        # Build a *svrcls* server bound to *addr* whose handler echoes one
        # line back to the client.
        class MyHandler(hdlrbase):
            def handle(self):
                self.wfile.write(self.rfile.readline())

        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                # Close the request, then re-raise the handler's exception.
                self.close_request(request)
                raise

        if verbose:
            print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as exc:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, exc))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @threading_helper.reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        # Serve in a background thread, run *testfunc* as a client three
        # times, then shut the server down and check it cleaned up.
        family = svrcls.address_family
        server = self.make_server(self.pickaddr(family), svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        serving = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01},
            # Daemonic in case this function raises.
            daemon=True)
        serving.start()
        if verbose:
            print("server running")

        for i in range(3):
            if verbose:
                print("test client", i)
            testfunc(family, addr)

        if verbose:
            print("waiting for server")
        server.shutdown()
        serving.join()
        server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose:
            print("done")

    def _read_reply(self, sock):
        # Accumulate data from *sock* until a newline has arrived or the
        # peer sends nothing more.
        buf = chunk = receive(sock, 100)
        while chunk and b'\n' not in buf:
            chunk = receive(sock, 100)
            buf += chunk
        return buf

    def stream_examine(self, proto, addr):
        # Stream client: connect, send TEST_STR, expect it echoed back.
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            reply = self._read_reply(s)
            self.assertEqual(reply, TEST_STR)

    def dgram_examine(self, proto, addr):
        # Datagram client: send TEST_STR to *addr*, expect it echoed back.
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Bind the Unix client socket to a path of its own.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            reply = self._read_reply(s)
            self.assertEqual(reply, TEST_STR)

    # --- stream servers -------------------------------------------------

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    # --- datagram servers -----------------------------------------------

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    # --- lifecycle ------------------------------------------------------

    @threading_helper.reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        pairs = []
        for _ in range(20):
            srv = MyServer((HOST, 0), MyHandler)
            thread = threading.Thread(
                name='MyServer serving',
                target=srv.serve_forever,
                kwargs={'poll_interval': 0.01},
                # Daemonic in case this function raises.
                daemon=True)
            pairs.append((thread, srv))

        for thread, srv in pairs:
            thread.start()
            srv.shutdown()
        for thread, srv in pairs:
            thread.join()
            srv.server_close()

    def test_close_immediately(self):
        # A threading server can be closed without ever serving a request.
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), lambda: None)
        server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this
        # results in FD exhaustion.
        for _ in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        # Leaving the with-block must close the listening socket.
        with socketserver.TCPServer(
                (HOST, 0), socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
class ErrorHandlerTest(unittest.TestCase):
    """Check which handler exceptions reach handle_error().

    Normal exceptions raised by the request handler must be passed to the
    server's handle_error(); exiting exceptions like SystemExit and
    KeyboardInterrupt must not be.  The handler and handle_error() each
    append a line to TESTFN, which check_result() then inspects.
    """

    def tearDown(self):
        os_helper.unlink(os_helper.TESTFN)

    def check_result(self, handled):
        # The handler always logs; handle_error() logs only when *handled*.
        expected = 'Handler called\n' + 'Error handled\n' * handled
        with open(os_helper.TESTFN) as log:
            logged = log.read()
        self.assertEqual(logged, expected)

    # --- synchronous server ---------------------------------------------

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        # SystemExit propagates out of the synchronous server.
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    # --- threading server -----------------------------------------------

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        # The exception is raised in the request thread, so it is observed
        # through the threading exception hook rather than propagating here.
        with threading_helper.catch_threading_exception() as cm:
            ThreadingErrorTestServer(SystemExit)
            self.check_result(handled=False)

            self.assertIs(cm.exc_type, SystemExit)

    # --- forking server -------------------------------------------------

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that handles exactly one failing request on construction.

    The constructor binds to an OS-chosen port, makes one client connection
    to itself, handles that single request with BadHandler (which raises
    *exception*), closes the server, and finally calls wait_done().
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Connect and immediately close: enough to queue one request.
        socket.create_connection(self.server_address).close()
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        # Record that the server's error hook ran.
        with open(os_helper.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        # Hook for subclasses that must wait for the request to finish;
        # nothing to wait for here.
        pass
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs a line to TESTFN and then raises."""

    def handle(self):
        with open(os_helper.TESTFN, 'a') as log:
            log.write('Handler called\n')
        # The exception type to raise is configured on the server.
        exc_type = self.server.exception
        raise exc_type('Test error')
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant that waits until the request has been shut down."""

    def __init__(self, *args, **kwargs):
        # Must exist before the base constructor runs, because that
        # constructor already handles a request and calls wait_done().
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        # Signal wait_done() that the request has been shut down.
        self.done.set()

    def wait_done(self):
        self.done.wait()
# Only defined where fork() is available; the tests that use it are
# guarded by @requires_forking.
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn,
                                 BaseErrorTestServer):
        """BaseErrorTestServer combined with ForkingMixIn."""
class SocketWriterTest(unittest.TestCase):
    """Tests for the wfile object of StreamRequestHandler."""

    def test_basics(self):
        # wfile must be a BufferedIOBase sharing the request's descriptor.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Stash what the assertions below need on the server.
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)

        client = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with client:
            client.connect(server.server_address)
        server.handle_request()

        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)

        # Set by the SIGUSR1 handler once the main thread got the signal.
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)

        # Filled in by the client thread.
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                response1 = reader.readline()
                s.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                pthread_kill(main_thread, signal.SIGUSR1)
                while not interrupted.wait(timeout=1.0):
                    pthread_kill(main_thread, signal.SIGUSR1)
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()

        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        # The client consumed 100 bytes of the big chunk before signalling.
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
class MiscTestCase(unittest.TestCase):
    """Assorted regression tests for the socketserver module."""

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = [
            name for name in dir(socketserver)
            if not name.startswith('_')
            and getattr(getattr(socketserver, name),
                        '__module__', None) == 'socketserver'
        ]
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False
        class MyServer(socketserver.TCPServer):
            # Class-level default; the first += creates the instance value.
            shutdown_called = 0

            def verify_request(self, request, client_address):
                return False

            def shutdown_request(self, request):
                self.shutdown_called += 1
                super().shutdown_request(request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        client = socket.socket(server.address_family, socket.SOCK_STREAM)
        client.connect(server.server_address)
        client.close()
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)
        server.server_close()

    def test_threads_reaped(self):
        """
        In #37193, users reported a memory leak
        due to the saving of every request thread. Ensure that
        not all threads are kept forever.
        """
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        for _ in range(10):
            with socket.create_connection(server.server_address):
                server.handle_request()
        # Fewer than one retained thread per request handled.
        self.assertLess(len(server._threads), 10)
        server.server_close()
# Run the whole suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|