utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. """Utilities shared by tests."""
  2. import asyncio
  3. import collections
  4. import contextlib
  5. import io
  6. import logging
  7. import os
  8. import re
  9. import selectors
  10. import socket
  11. import socketserver
  12. import sys
  13. import tempfile
  14. import threading
  15. import time
  16. import unittest
  17. import weakref
  18. from unittest import mock
  19. from http.server import HTTPServer
  20. from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
  21. try:
  22. import ssl
  23. except ImportError: # pragma: no cover
  24. ssl = None
  25. from asyncio import base_events
  26. from asyncio import events
  27. from asyncio import format_helpers
  28. from asyncio import futures
  29. from asyncio import tasks
  30. from asyncio.log import logger
  31. from test import support
  32. from test.support import threading_helper
  33. def data_file(filename):
  34. if hasattr(support, 'TEST_HOME_DIR'):
  35. fullname = os.path.join(support.TEST_HOME_DIR, filename)
  36. if os.path.isfile(fullname):
  37. return fullname
  38. fullname = os.path.join(os.path.dirname(__file__), '..', filename)
  39. if os.path.isfile(fullname):
  40. return fullname
  41. raise FileNotFoundError(filename)
# TLS fixture files, located with data_file(); importing this module raises
# FileNotFoundError if any of them cannot be found.
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# Expected decoded peer-certificate fields; the layout appears to follow the
# dict returned by ssl.SSLSocket.getpeercert().
# NOTE(review): presumably describes the certificate in SIGNED_CERTFILE as
# issued by the SIGNING_CA ("our-ca-server") -- confirm before editing.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
               (('organizationName', 'Python Software Foundation CA'),),
               (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
                (('localityName', 'Castle Anthrax'),),
                (('organizationName', 'Python Software Foundation'),),
                (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
  63. def simple_server_sslcontext():
  64. server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  65. server_context.load_cert_chain(ONLYCERT, ONLYKEY)
  66. server_context.check_hostname = False
  67. server_context.verify_mode = ssl.CERT_NONE
  68. return server_context
  69. def simple_client_sslcontext(*, disable_verify=True):
  70. client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  71. client_context.check_hostname = False
  72. if disable_verify:
  73. client_context.verify_mode = ssl.CERT_NONE
  74. return client_context
  75. def dummy_ssl_context():
  76. if ssl is None:
  77. return None
  78. else:
  79. return simple_client_sslcontext(disable_verify=True)
  80. def run_briefly(loop):
  81. async def once():
  82. pass
  83. gen = once()
  84. t = loop.create_task(gen)
  85. # Don't log a warning if the task is not done after run_until_complete().
  86. # It occurs if the loop is stopped or if a task raises a BaseException.
  87. t._log_destroy_pending = False
  88. try:
  89. loop.run_until_complete(t)
  90. finally:
  91. gen.close()
  92. def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
  93. deadline = time.monotonic() + timeout
  94. while not pred():
  95. if timeout is not None:
  96. timeout = deadline - time.monotonic()
  97. if timeout <= 0:
  98. raise futures.TimeoutError()
  99. loop.run_until_complete(tasks.sleep(0.001))
  100. def run_once(loop):
  101. """Legacy API to run once through the event loop.
  102. This is the recommended pattern for test code. It will poll the
  103. selector once and run all callbacks scheduled in response to I/O
  104. events.
  105. """
  106. loop.call_soon(loop.stop)
  107. loop.run_forever()
  108. class SilentWSGIRequestHandler(WSGIRequestHandler):
  109. def get_stderr(self):
  110. return io.StringIO()
  111. def log_message(self, format, *args):
  112. pass
  113. class SilentWSGIServer(WSGIServer):
  114. request_timeout = support.LOOPBACK_TIMEOUT
  115. def get_request(self):
  116. request, client_addr = super().get_request()
  117. request.settimeout(self.request_timeout)
  118. return request, client_addr
  119. def handle_error(self, request, client_address):
  120. pass
  121. class SSLWSGIServerMixin:
  122. def finish_request(self, request, client_address):
  123. # The relative location of our test directory (which
  124. # contains the ssl key and certificate files) differs
  125. # between the stdlib and stand-alone asyncio.
  126. # Prefer our own if we can find it.
  127. context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  128. context.load_cert_chain(ONLYCERT, ONLYKEY)
  129. ssock = context.wrap_socket(request, server_side=True)
  130. try:
  131. self.RequestHandlerClass(ssock, client_address, self)
  132. ssock.close()
  133. except OSError:
  134. # maybe socket has been closed by peer
  135. pass
  136. class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
  137. pass
  138. def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
  139. def loop(environ):
  140. size = int(environ['CONTENT_LENGTH'])
  141. while size:
  142. data = environ['wsgi.input'].read(min(size, 0x10000))
  143. yield data
  144. size -= len(data)
  145. def app(environ, start_response):
  146. status = '200 OK'
  147. headers = [('Content-type', 'text/plain')]
  148. start_response(status, headers)
  149. if environ['PATH_INFO'] == '/loop':
  150. return loop(environ)
  151. else:
  152. return [b'Test message']
  153. # Run the test WSGI server in a separate thread in order not to
  154. # interfere with event handling in the main thread
  155. server_class = server_ssl_cls if use_ssl else server_cls
  156. httpd = server_class(address, SilentWSGIRequestHandler)
  157. httpd.set_app(app)
  158. httpd.address = httpd.server_address
  159. server_thread = threading.Thread(
  160. target=lambda: httpd.serve_forever(poll_interval=0.05))
  161. server_thread.start()
  162. try:
  163. yield httpd
  164. finally:
  165. httpd.shutdown()
  166. httpd.server_close()
  167. server_thread.join()
  168. if hasattr(socket, 'AF_UNIX'):
  169. class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
  170. def server_bind(self):
  171. socketserver.UnixStreamServer.server_bind(self)
  172. self.server_name = '127.0.0.1'
  173. self.server_port = 80
  174. class UnixWSGIServer(UnixHTTPServer, WSGIServer):
  175. request_timeout = support.LOOPBACK_TIMEOUT
  176. def server_bind(self):
  177. UnixHTTPServer.server_bind(self)
  178. self.setup_environ()
  179. def get_request(self):
  180. request, client_addr = super().get_request()
  181. request.settimeout(self.request_timeout)
  182. # Code in the stdlib expects that get_request
  183. # will return a socket and a tuple (host, port).
  184. # However, this isn't true for UNIX sockets,
  185. # as the second return value will be a path;
  186. # hence we return some fake data sufficient
  187. # to get the tests going
  188. return request, ('127.0.0.1', '')
  189. class SilentUnixWSGIServer(UnixWSGIServer):
  190. def handle_error(self, request, client_address):
  191. pass
  192. class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
  193. pass
  194. def gen_unix_socket_path():
  195. with tempfile.NamedTemporaryFile() as file:
  196. return file.name
  197. @contextlib.contextmanager
  198. def unix_socket_path():
  199. path = gen_unix_socket_path()
  200. try:
  201. yield path
  202. finally:
  203. try:
  204. os.unlink(path)
  205. except OSError:
  206. pass
  207. @contextlib.contextmanager
  208. def run_test_unix_server(*, use_ssl=False):
  209. with unix_socket_path() as path:
  210. yield from _run_test_server(address=path, use_ssl=use_ssl,
  211. server_cls=SilentUnixWSGIServer,
  212. server_ssl_cls=UnixSSLWSGIServer)
  213. @contextlib.contextmanager
  214. def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
  215. yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
  216. server_cls=SilentWSGIServer,
  217. server_ssl_cls=SSLWSGIServer)
  218. def echo_datagrams(sock):
  219. while True:
  220. data, addr = sock.recvfrom(4096)
  221. if data == b'STOP':
  222. sock.close()
  223. break
  224. else:
  225. sock.sendto(data, addr)
  226. @contextlib.contextmanager
  227. def run_udp_echo_server(*, host='127.0.0.1', port=0):
  228. addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
  229. family, type, proto, _, sockaddr = addr_info[0]
  230. sock = socket.socket(family, type, proto)
  231. sock.bind((host, port))
  232. thread = threading.Thread(target=lambda: echo_datagrams(sock))
  233. thread.start()
  234. try:
  235. yield sock.getsockname()
  236. finally:
  237. sock.sendto(b'STOP', sock.getsockname())
  238. thread.join()
  239. def make_test_protocol(base):
  240. dct = {}
  241. for name in dir(base):
  242. if name.startswith('__') and name.endswith('__'):
  243. # skip magic names
  244. continue
  245. dct[name] = MockCallback(return_value=None)
  246. return type('TestProtocol', (base,) + base.__bases__, dct)()
  247. class TestSelector(selectors.BaseSelector):
  248. def __init__(self):
  249. self.keys = {}
  250. def register(self, fileobj, events, data=None):
  251. key = selectors.SelectorKey(fileobj, 0, events, data)
  252. self.keys[fileobj] = key
  253. return key
  254. def unregister(self, fileobj):
  255. return self.keys.pop(fileobj)
  256. def select(self, timeout):
  257. return []
  258. def get_map(self):
  259. return self.keys
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that manages its own virtual clock.

    ``time()`` returns a counter that only moves when ``advance_time()`` is
    called; it is independent of the wall clock.

    Every ``call_at()`` records its deadline.  After each loop iteration,
    ``_run_once()`` sends each recorded deadline into the generator passed
    to ``__init__`` and advances the clock by the value the generator
    yields back.  (NOTE(review): presumably ``call_later()`` goes through
    ``call_at()`` in the base class too -- confirm for the Python version.)

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handle; the value yielded back is how far to move the loop's time
    forward (a falsy value leaves the clock alone).  When a generator is
    supplied, ``close()`` asserts that it has finished.

    Reader/writer registrations are only recorded in ``self.readers`` /
    ``self.writers`` (see ``assert_reader``/``assert_writer``), and the
    removal calls are counted per fd.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # No time script: this generator only absorbs the priming next().
            # NOTE: it is exhausted by the first send() in _run_once(), so a
            # loop built without *gen* raises StopIteration from an iteration
            # that follows any call_at().
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so it is paused at its first yield and ready
        # to receive deadlines via send().
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Deadlines recorded by call_at() since the last _run_once().
        self._timers = []
        # NOTE(review): presumably polled by the base class's _run_once();
        # TestSelector.select() never reports any events.
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        # fd -> transport map consulted by _ensure_fd_no_transport().
        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the virtual time; only advance_time() changes it."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (falsy values are ignored)."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and, if a time generator was given, check it ended."""
        super().close()

        if self._check_on_close:
            # The final send(0) must finish the generator; if it yields
            # again instead, the test's time script was not fully consumed.
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        # Only record the handle; nothing is registered with the selector.
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        """Drop the reader for *fd*; return True if one was registered.

        Every call is counted in remove_reader_count, even for an fd that
        had no reader.
        """
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this exact reader callback/args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if a reader is registered for *fd*."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        # Only record the handle; nothing is registered with the selector.
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        """Drop the writer for *fd*; return True if one was registered.

        Every call is counted in remove_writer_count, even for an fd that
        had no writer.
        """
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this exact writer callback/args."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        """Reject an fd that belongs to a transport tracked in _transports.

        *fd* may be an int or an object with a ``fileno()`` method.

        Raises:
            ValueError: if *fd* cannot be converted to an integer fd.
            RuntimeError: if a transport in ``self._transports`` owns the fd.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback for *fd* after checking no transport owns it."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove the reader callback for *fd* after the transport check."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback for *fd* after checking no transport owns it."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove the writer callback for *fd* after the transport check."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd remove_reader/remove_writer call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed every deadline recorded since the last iteration to the time
        # generator, advancing the clock by what it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        # Record the deadline for _run_once(), then schedule normally.
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        # Nothing to dispatch: TestSelector.select() always returns [].
        return

    def _write_to_self(self):
        # Intentionally a no-op; no wakeup mechanism is set up for this loop.
        pass
  394. def MockCallback(**kwargs):
  395. return mock.Mock(spec=['__call__'], **kwargs)
  396. class MockPattern(str):
  397. """A regex based str with a fuzzy __eq__.
  398. Use this helper with 'mock.assert_called_with', or anywhere
  399. where a regex comparison between strings is needed.
  400. For instance:
  401. mock_call.assert_called_with(MockPattern('spam.*ham'))
  402. """
  403. def __eq__(self, other):
  404. return bool(re.search(str(self), other, re.S))
  405. class MockInstanceOf:
  406. def __init__(self, type):
  407. self._type = type
  408. def __eq__(self, other):
  409. return isinstance(other, self._type)
  410. def get_function_source(func):
  411. source = format_helpers._get_function_source(func)
  412. if source is None:
  413. raise ValueError("unable to get the source of %r" % (func,))
  414. return source
  415. class TestCase(unittest.TestCase):
  416. @staticmethod
  417. def close_loop(loop):
  418. if loop._default_executor is not None:
  419. if not loop.is_closed():
  420. loop.run_until_complete(loop.shutdown_default_executor())
  421. else:
  422. loop._default_executor.shutdown(wait=True)
  423. loop.close()
  424. policy = support.maybe_get_event_loop_policy()
  425. if policy is not None:
  426. try:
  427. watcher = policy.get_child_watcher()
  428. except NotImplementedError:
  429. # watcher is not implemented by EventLoopPolicy, e.g. Windows
  430. pass
  431. else:
  432. if isinstance(watcher, asyncio.ThreadedChildWatcher):
  433. threads = list(watcher._threads.values())
  434. for thread in threads:
  435. thread.join()
  436. def set_event_loop(self, loop, *, cleanup=True):
  437. if loop is None:
  438. raise AssertionError('loop is None')
  439. # ensure that the event loop is passed explicitly in asyncio
  440. events.set_event_loop(None)
  441. if cleanup:
  442. self.addCleanup(self.close_loop, loop)
  443. def new_test_loop(self, gen=None):
  444. loop = TestLoop(gen)
  445. self.set_event_loop(loop)
  446. return loop
  447. def setUp(self):
  448. self._thread_cleanup = threading_helper.threading_setup()
  449. def tearDown(self):
  450. events.set_event_loop(None)
  451. # Detect CPython bug #23353: ensure that yield/yield-from is not used
  452. # in an except block of a generator
  453. self.assertEqual(sys.exc_info(), (None, None, None))
  454. self.doCleanups()
  455. threading_helper.threading_cleanup(*self._thread_cleanup)
  456. support.reap_children()
  457. @contextlib.contextmanager
  458. def disable_logger():
  459. """Context manager to disable asyncio logger.
  460. For example, it can be used to ignore warnings in debug mode.
  461. """
  462. old_level = logger.level
  463. try:
  464. logger.setLevel(logging.CRITICAL+1)
  465. yield
  466. finally:
  467. logger.setLevel(old_level)
  468. def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
  469. family=socket.AF_INET):
  470. """Create a mock of a non-blocking socket."""
  471. sock = mock.MagicMock(socket.socket)
  472. sock.proto = proto
  473. sock.type = type
  474. sock.family = family
  475. sock.gettimeout.return_value = 0.0
  476. return sock