functional.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. import asyncio
  2. import asyncio.events
  3. import contextlib
  4. import os
  5. import pprint
  6. import select
  7. import socket
  8. import tempfile
  9. import threading
  10. from test import support
  11. class FunctionalTestCaseMixin:
  12. def new_loop(self):
  13. return asyncio.new_event_loop()
  14. def run_loop_briefly(self, *, delay=0.01):
  15. self.loop.run_until_complete(asyncio.sleep(delay))
  16. def loop_exception_handler(self, loop, context):
  17. self.__unhandled_exceptions.append(context)
  18. self.loop.default_exception_handler(context)
  19. def setUp(self):
  20. self.loop = self.new_loop()
  21. asyncio.set_event_loop(None)
  22. self.loop.set_exception_handler(self.loop_exception_handler)
  23. self.__unhandled_exceptions = []
  24. def tearDown(self):
  25. try:
  26. self.loop.close()
  27. if self.__unhandled_exceptions:
  28. print('Unexpected calls to loop.call_exception_handler():')
  29. pprint.pprint(self.__unhandled_exceptions)
  30. self.fail('unexpected calls to loop.call_exception_handler()')
  31. finally:
  32. asyncio.set_event_loop(None)
  33. self.loop = None
  34. def tcp_server(self, server_prog, *,
  35. family=socket.AF_INET,
  36. addr=None,
  37. timeout=support.LOOPBACK_TIMEOUT,
  38. backlog=1,
  39. max_clients=10):
  40. if addr is None:
  41. if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
  42. with tempfile.NamedTemporaryFile() as tmp:
  43. addr = tmp.name
  44. else:
  45. addr = ('127.0.0.1', 0)
  46. sock = socket.create_server(addr, family=family, backlog=backlog)
  47. if timeout is None:
  48. raise RuntimeError('timeout is required')
  49. if timeout <= 0:
  50. raise RuntimeError('only blocking sockets are supported')
  51. sock.settimeout(timeout)
  52. return TestThreadedServer(
  53. self, sock, server_prog, timeout, max_clients)
  54. def tcp_client(self, client_prog,
  55. family=socket.AF_INET,
  56. timeout=support.LOOPBACK_TIMEOUT):
  57. sock = socket.socket(family, socket.SOCK_STREAM)
  58. if timeout is None:
  59. raise RuntimeError('timeout is required')
  60. if timeout <= 0:
  61. raise RuntimeError('only blocking sockets are supported')
  62. sock.settimeout(timeout)
  63. return TestThreadedClient(
  64. self, sock, client_prog, timeout)
  65. def unix_server(self, *args, **kwargs):
  66. if not hasattr(socket, 'AF_UNIX'):
  67. raise NotImplementedError
  68. return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
  69. def unix_client(self, *args, **kwargs):
  70. if not hasattr(socket, 'AF_UNIX'):
  71. raise NotImplementedError
  72. return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
  73. @contextlib.contextmanager
  74. def unix_sock_name(self):
  75. with tempfile.TemporaryDirectory() as td:
  76. fn = os.path.join(td, 'sock')
  77. try:
  78. yield fn
  79. finally:
  80. try:
  81. os.unlink(fn)
  82. except OSError:
  83. pass
  84. def _abort_socket_test(self, ex):
  85. try:
  86. self.loop.stop()
  87. finally:
  88. self.fail(ex)
  89. ##############################################################################
  90. # Socket Testing Utilities
  91. ##############################################################################
  92. class TestSocketWrapper:
  93. def __init__(self, sock):
  94. self.__sock = sock
  95. def recv_all(self, n):
  96. buf = b''
  97. while len(buf) < n:
  98. data = self.recv(n - len(buf))
  99. if data == b'':
  100. raise ConnectionAbortedError
  101. buf += data
  102. return buf
  103. def start_tls(self, ssl_context, *,
  104. server_side=False,
  105. server_hostname=None):
  106. ssl_sock = ssl_context.wrap_socket(
  107. self.__sock, server_side=server_side,
  108. server_hostname=server_hostname,
  109. do_handshake_on_connect=False)
  110. try:
  111. ssl_sock.do_handshake()
  112. except:
  113. ssl_sock.close()
  114. raise
  115. finally:
  116. self.__sock.close()
  117. self.__sock = ssl_sock
  118. def __getattr__(self, name):
  119. return getattr(self.__sock, name)
  120. def __repr__(self):
  121. return '<{} {!r}>'.format(type(self).__name__, self.__sock)
  122. class SocketThread(threading.Thread):
  123. def stop(self):
  124. self._active = False
  125. self.join()
  126. def __enter__(self):
  127. self.start()
  128. return self
  129. def __exit__(self, *exc):
  130. self.stop()
  131. class TestThreadedClient(SocketThread):
  132. def __init__(self, test, sock, prog, timeout):
  133. threading.Thread.__init__(self, None, None, 'test-client')
  134. self.daemon = True
  135. self._timeout = timeout
  136. self._sock = sock
  137. self._active = True
  138. self._prog = prog
  139. self._test = test
  140. def run(self):
  141. try:
  142. self._prog(TestSocketWrapper(self._sock))
  143. except Exception as ex:
  144. self._test._abort_socket_test(ex)
  145. class TestThreadedServer(SocketThread):
  146. def __init__(self, test, sock, prog, timeout, max_clients):
  147. threading.Thread.__init__(self, None, None, 'test-server')
  148. self.daemon = True
  149. self._clients = 0
  150. self._finished_clients = 0
  151. self._max_clients = max_clients
  152. self._timeout = timeout
  153. self._sock = sock
  154. self._active = True
  155. self._prog = prog
  156. self._s1, self._s2 = socket.socketpair()
  157. self._s1.setblocking(False)
  158. self._test = test
  159. def stop(self):
  160. try:
  161. if self._s2 and self._s2.fileno() != -1:
  162. try:
  163. self._s2.send(b'stop')
  164. except OSError:
  165. pass
  166. finally:
  167. super().stop()
  168. def run(self):
  169. try:
  170. with self._sock:
  171. self._sock.setblocking(False)
  172. self._run()
  173. finally:
  174. self._s1.close()
  175. self._s2.close()
  176. def _run(self):
  177. while self._active:
  178. if self._clients >= self._max_clients:
  179. return
  180. r, w, x = select.select(
  181. [self._sock, self._s1], [], [], self._timeout)
  182. if self._s1 in r:
  183. return
  184. if self._sock in r:
  185. try:
  186. conn, addr = self._sock.accept()
  187. except BlockingIOError:
  188. continue
  189. except TimeoutError:
  190. if not self._active:
  191. return
  192. else:
  193. raise
  194. else:
  195. self._clients += 1
  196. conn.settimeout(self._timeout)
  197. try:
  198. with conn:
  199. self._handle_client(conn)
  200. except Exception as ex:
  201. self._active = False
  202. try:
  203. raise
  204. finally:
  205. self._test._abort_socket_test(ex)
  206. def _handle_client(self, sock):
  207. self._prog(TestSocketWrapper(sock))
  208. @property
  209. def addr(self):
  210. return self._sock.getsockname()