sslproto.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921
  1. import collections
  2. import enum
  3. import warnings
  4. try:
  5. import ssl
  6. except ImportError: # pragma: no cover
  7. ssl = None
  8. from . import constants
  9. from . import exceptions
  10. from . import protocols
  11. from . import transports
  12. from .log import logger
  13. if ssl is not None:
  14. SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
  15. class SSLProtocolState(enum.Enum):
  16. UNWRAPPED = "UNWRAPPED"
  17. DO_HANDSHAKE = "DO_HANDSHAKE"
  18. WRAPPED = "WRAPPED"
  19. FLUSHING = "FLUSHING"
  20. SHUTDOWN = "SHUTDOWN"
  21. class AppProtocolState(enum.Enum):
  22. # This tracks the state of app protocol (https://git.io/fj59P):
  23. #
  24. # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
  25. #
  26. # * cm: connection_made()
  27. # * dr: data_received()
  28. # * er: eof_received()
  29. # * cl: connection_lost()
  30. STATE_INIT = "STATE_INIT"
  31. STATE_CON_MADE = "STATE_CON_MADE"
  32. STATE_EOF = "STATE_EOF"
  33. STATE_CON_LOST = "STATE_CON_LOST"
  34. def _create_transport_context(server_side, server_hostname):
  35. if server_side:
  36. raise ValueError('Server side SSL needs a valid SSLContext')
  37. # Client side may pass ssl=True to use a default
  38. # context; in that case the sslcontext passed is None.
  39. # The default is secure for client connections.
  40. # Python 3.4+: use up-to-date strong settings.
  41. sslcontext = ssl.create_default_context()
  42. if not server_hostname:
  43. sslcontext.check_hostname = False
  44. return sslcontext
  45. def add_flowcontrol_defaults(high, low, kb):
  46. if high is None:
  47. if low is None:
  48. hi = kb * 1024
  49. else:
  50. lo = low
  51. hi = 4 * lo
  52. else:
  53. hi = high
  54. if low is None:
  55. lo = hi // 4
  56. else:
  57. lo = low
  58. if not hi >= lo >= 0:
  59. raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
  60. (hi, lo))
  61. return hi, lo
  62. class _SSLProtocolTransport(transports._FlowControlMixin,
  63. transports.Transport):
  64. _start_tls_compatible = True
  65. _sendfile_compatible = constants._SendfileMode.FALLBACK
  66. def __init__(self, loop, ssl_protocol):
  67. self._loop = loop
  68. self._ssl_protocol = ssl_protocol
  69. self._closed = False
  70. def get_extra_info(self, name, default=None):
  71. """Get optional transport information."""
  72. return self._ssl_protocol._get_extra_info(name, default)
  73. def set_protocol(self, protocol):
  74. self._ssl_protocol._set_app_protocol(protocol)
  75. def get_protocol(self):
  76. return self._ssl_protocol._app_protocol
  77. def is_closing(self):
  78. return self._closed
  79. def close(self):
  80. """Close the transport.
  81. Buffered data will be flushed asynchronously. No more data
  82. will be received. After all buffered data is flushed, the
  83. protocol's connection_lost() method will (eventually) called
  84. with None as its argument.
  85. """
  86. if not self._closed:
  87. self._closed = True
  88. self._ssl_protocol._start_shutdown()
  89. else:
  90. self._ssl_protocol = None
  91. def __del__(self, _warnings=warnings):
  92. if not self._closed:
  93. self._closed = True
  94. _warnings.warn(
  95. "unclosed transport <asyncio._SSLProtocolTransport "
  96. "object>", ResourceWarning)
  97. def is_reading(self):
  98. return not self._ssl_protocol._app_reading_paused
  99. def pause_reading(self):
  100. """Pause the receiving end.
  101. No data will be passed to the protocol's data_received()
  102. method until resume_reading() is called.
  103. """
  104. self._ssl_protocol._pause_reading()
  105. def resume_reading(self):
  106. """Resume the receiving end.
  107. Data received will once again be passed to the protocol's
  108. data_received() method.
  109. """
  110. self._ssl_protocol._resume_reading()
  111. def set_write_buffer_limits(self, high=None, low=None):
  112. """Set the high- and low-water limits for write flow control.
  113. These two values control when to call the protocol's
  114. pause_writing() and resume_writing() methods. If specified,
  115. the low-water limit must be less than or equal to the
  116. high-water limit. Neither value can be negative.
  117. The defaults are implementation-specific. If only the
  118. high-water limit is given, the low-water limit defaults to an
  119. implementation-specific value less than or equal to the
  120. high-water limit. Setting high to zero forces low to zero as
  121. well, and causes pause_writing() to be called whenever the
  122. buffer becomes non-empty. Setting low to zero causes
  123. resume_writing() to be called only once the buffer is empty.
  124. Use of zero for either limit is generally sub-optimal as it
  125. reduces opportunities for doing I/O and computation
  126. concurrently.
  127. """
  128. self._ssl_protocol._set_write_buffer_limits(high, low)
  129. self._ssl_protocol._control_app_writing()
  130. def get_write_buffer_limits(self):
  131. return (self._ssl_protocol._outgoing_low_water,
  132. self._ssl_protocol._outgoing_high_water)
  133. def get_write_buffer_size(self):
  134. """Return the current size of the write buffers."""
  135. return self._ssl_protocol._get_write_buffer_size()
  136. def set_read_buffer_limits(self, high=None, low=None):
  137. """Set the high- and low-water limits for read flow control.
  138. These two values control when to call the upstream transport's
  139. pause_reading() and resume_reading() methods. If specified,
  140. the low-water limit must be less than or equal to the
  141. high-water limit. Neither value can be negative.
  142. The defaults are implementation-specific. If only the
  143. high-water limit is given, the low-water limit defaults to an
  144. implementation-specific value less than or equal to the
  145. high-water limit. Setting high to zero forces low to zero as
  146. well, and causes pause_reading() to be called whenever the
  147. buffer becomes non-empty. Setting low to zero causes
  148. resume_reading() to be called only once the buffer is empty.
  149. Use of zero for either limit is generally sub-optimal as it
  150. reduces opportunities for doing I/O and computation
  151. concurrently.
  152. """
  153. self._ssl_protocol._set_read_buffer_limits(high, low)
  154. self._ssl_protocol._control_ssl_reading()
  155. def get_read_buffer_limits(self):
  156. return (self._ssl_protocol._incoming_low_water,
  157. self._ssl_protocol._incoming_high_water)
  158. def get_read_buffer_size(self):
  159. """Return the current size of the read buffer."""
  160. return self._ssl_protocol._get_read_buffer_size()
  161. @property
  162. def _protocol_paused(self):
  163. # Required for sendfile fallback pause_writing/resume_writing logic
  164. return self._ssl_protocol._app_writing_paused
  165. def write(self, data):
  166. """Write some data bytes to the transport.
  167. This does not block; it buffers the data and arranges for it
  168. to be sent out asynchronously.
  169. """
  170. if not isinstance(data, (bytes, bytearray, memoryview)):
  171. raise TypeError(f"data: expecting a bytes-like instance, "
  172. f"got {type(data).__name__}")
  173. if not data:
  174. return
  175. self._ssl_protocol._write_appdata((data,))
  176. def writelines(self, list_of_data):
  177. """Write a list (or any iterable) of data bytes to the transport.
  178. The default implementation concatenates the arguments and
  179. calls write() on the result.
  180. """
  181. self._ssl_protocol._write_appdata(list_of_data)
  182. def write_eof(self):
  183. """Close the write end after flushing buffered data.
  184. This raises :exc:`NotImplementedError` right now.
  185. """
  186. raise NotImplementedError
  187. def can_write_eof(self):
  188. """Return True if this transport supports write_eof(), False if not."""
  189. return False
  190. def abort(self):
  191. """Close the transport immediately.
  192. Buffered data will be lost. No more data will be received.
  193. The protocol's connection_lost() method will (eventually) be
  194. called with None as its argument.
  195. """
  196. self._closed = True
  197. self._ssl_protocol._abort()
  198. def _force_close(self, exc):
  199. self._closed = True
  200. self._ssl_protocol._abort(exc)
  201. def _test__append_write_backlog(self, data):
  202. # for test only
  203. self._ssl_protocol._write_backlog.append(data)
  204. self._ssl_protocol._write_buffer_size += len(data)
  205. class SSLProtocol(protocols.BufferedProtocol):
  206. max_size = 256 * 1024 # Buffer size passed to read()
  207. _handshake_start_time = None
  208. _handshake_timeout_handle = None
  209. _shutdown_timeout_handle = None
  210. def __init__(self, loop, app_protocol, sslcontext, waiter,
  211. server_side=False, server_hostname=None,
  212. call_connection_made=True,
  213. ssl_handshake_timeout=None,
  214. ssl_shutdown_timeout=None):
  215. if ssl is None:
  216. raise RuntimeError("stdlib ssl module not available")
  217. self._ssl_buffer = bytearray(self.max_size)
  218. self._ssl_buffer_view = memoryview(self._ssl_buffer)
  219. if ssl_handshake_timeout is None:
  220. ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
  221. elif ssl_handshake_timeout <= 0:
  222. raise ValueError(
  223. f"ssl_handshake_timeout should be a positive number, "
  224. f"got {ssl_handshake_timeout}")
  225. if ssl_shutdown_timeout is None:
  226. ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
  227. elif ssl_shutdown_timeout <= 0:
  228. raise ValueError(
  229. f"ssl_shutdown_timeout should be a positive number, "
  230. f"got {ssl_shutdown_timeout}")
  231. if not sslcontext:
  232. sslcontext = _create_transport_context(
  233. server_side, server_hostname)
  234. self._server_side = server_side
  235. if server_hostname and not server_side:
  236. self._server_hostname = server_hostname
  237. else:
  238. self._server_hostname = None
  239. self._sslcontext = sslcontext
  240. # SSL-specific extra info. More info are set when the handshake
  241. # completes.
  242. self._extra = dict(sslcontext=sslcontext)
  243. # App data write buffering
  244. self._write_backlog = collections.deque()
  245. self._write_buffer_size = 0
  246. self._waiter = waiter
  247. self._loop = loop
  248. self._set_app_protocol(app_protocol)
  249. self._app_transport = None
  250. self._app_transport_created = False
  251. # transport, ex: SelectorSocketTransport
  252. self._transport = None
  253. self._ssl_handshake_timeout = ssl_handshake_timeout
  254. self._ssl_shutdown_timeout = ssl_shutdown_timeout
  255. # SSL and state machine
  256. self._incoming = ssl.MemoryBIO()
  257. self._outgoing = ssl.MemoryBIO()
  258. self._state = SSLProtocolState.UNWRAPPED
  259. self._conn_lost = 0 # Set when connection_lost called
  260. if call_connection_made:
  261. self._app_state = AppProtocolState.STATE_INIT
  262. else:
  263. self._app_state = AppProtocolState.STATE_CON_MADE
  264. self._sslobj = self._sslcontext.wrap_bio(
  265. self._incoming, self._outgoing,
  266. server_side=self._server_side,
  267. server_hostname=self._server_hostname)
  268. # Flow Control
  269. self._ssl_writing_paused = False
  270. self._app_reading_paused = False
  271. self._ssl_reading_paused = False
  272. self._incoming_high_water = 0
  273. self._incoming_low_water = 0
  274. self._set_read_buffer_limits()
  275. self._eof_received = False
  276. self._app_writing_paused = False
  277. self._outgoing_high_water = 0
  278. self._outgoing_low_water = 0
  279. self._set_write_buffer_limits()
  280. self._get_app_transport()
  281. def _set_app_protocol(self, app_protocol):
  282. self._app_protocol = app_protocol
  283. # Make fast hasattr check first
  284. if (hasattr(app_protocol, 'get_buffer') and
  285. isinstance(app_protocol, protocols.BufferedProtocol)):
  286. self._app_protocol_get_buffer = app_protocol.get_buffer
  287. self._app_protocol_buffer_updated = app_protocol.buffer_updated
  288. self._app_protocol_is_buffer = True
  289. else:
  290. self._app_protocol_is_buffer = False
  291. def _wakeup_waiter(self, exc=None):
  292. if self._waiter is None:
  293. return
  294. if not self._waiter.cancelled():
  295. if exc is not None:
  296. self._waiter.set_exception(exc)
  297. else:
  298. self._waiter.set_result(None)
  299. self._waiter = None
  300. def _get_app_transport(self):
  301. if self._app_transport is None:
  302. if self._app_transport_created:
  303. raise RuntimeError('Creating _SSLProtocolTransport twice')
  304. self._app_transport = _SSLProtocolTransport(self._loop, self)
  305. self._app_transport_created = True
  306. return self._app_transport
  307. def connection_made(self, transport):
  308. """Called when the low-level connection is made.
  309. Start the SSL handshake.
  310. """
  311. self._transport = transport
  312. self._start_handshake()
  313. def connection_lost(self, exc):
  314. """Called when the low-level connection is lost or closed.
  315. The argument is an exception object or None (the latter
  316. meaning a regular EOF is received or the connection was
  317. aborted or closed).
  318. """
  319. self._write_backlog.clear()
  320. self._outgoing.read()
  321. self._conn_lost += 1
  322. # Just mark the app transport as closed so that its __dealloc__
  323. # doesn't complain.
  324. if self._app_transport is not None:
  325. self._app_transport._closed = True
  326. if self._state != SSLProtocolState.DO_HANDSHAKE:
  327. if (
  328. self._app_state == AppProtocolState.STATE_CON_MADE or
  329. self._app_state == AppProtocolState.STATE_EOF
  330. ):
  331. self._app_state = AppProtocolState.STATE_CON_LOST
  332. self._loop.call_soon(self._app_protocol.connection_lost, exc)
  333. self._set_state(SSLProtocolState.UNWRAPPED)
  334. self._transport = None
  335. self._app_transport = None
  336. self._app_protocol = None
  337. self._wakeup_waiter(exc)
  338. if self._shutdown_timeout_handle:
  339. self._shutdown_timeout_handle.cancel()
  340. self._shutdown_timeout_handle = None
  341. if self._handshake_timeout_handle:
  342. self._handshake_timeout_handle.cancel()
  343. self._handshake_timeout_handle = None
  344. def get_buffer(self, n):
  345. want = n
  346. if want <= 0 or want > self.max_size:
  347. want = self.max_size
  348. if len(self._ssl_buffer) < want:
  349. self._ssl_buffer = bytearray(want)
  350. self._ssl_buffer_view = memoryview(self._ssl_buffer)
  351. return self._ssl_buffer_view
  352. def buffer_updated(self, nbytes):
  353. self._incoming.write(self._ssl_buffer_view[:nbytes])
  354. if self._state == SSLProtocolState.DO_HANDSHAKE:
  355. self._do_handshake()
  356. elif self._state == SSLProtocolState.WRAPPED:
  357. self._do_read()
  358. elif self._state == SSLProtocolState.FLUSHING:
  359. self._do_flush()
  360. elif self._state == SSLProtocolState.SHUTDOWN:
  361. self._do_shutdown()
  362. def eof_received(self):
  363. """Called when the other end of the low-level stream
  364. is half-closed.
  365. If this returns a false value (including None), the transport
  366. will close itself. If it returns a true value, closing the
  367. transport is up to the protocol.
  368. """
  369. self._eof_received = True
  370. try:
  371. if self._loop.get_debug():
  372. logger.debug("%r received EOF", self)
  373. if self._state == SSLProtocolState.DO_HANDSHAKE:
  374. self._on_handshake_complete(ConnectionResetError)
  375. elif self._state == SSLProtocolState.WRAPPED:
  376. self._set_state(SSLProtocolState.FLUSHING)
  377. if self._app_reading_paused:
  378. return True
  379. else:
  380. self._do_flush()
  381. elif self._state == SSLProtocolState.FLUSHING:
  382. self._do_write()
  383. self._set_state(SSLProtocolState.SHUTDOWN)
  384. self._do_shutdown()
  385. elif self._state == SSLProtocolState.SHUTDOWN:
  386. self._do_shutdown()
  387. except Exception:
  388. self._transport.close()
  389. raise
  390. def _get_extra_info(self, name, default=None):
  391. if name in self._extra:
  392. return self._extra[name]
  393. elif self._transport is not None:
  394. return self._transport.get_extra_info(name, default)
  395. else:
  396. return default
  397. def _set_state(self, new_state):
  398. allowed = False
  399. if new_state == SSLProtocolState.UNWRAPPED:
  400. allowed = True
  401. elif (
  402. self._state == SSLProtocolState.UNWRAPPED and
  403. new_state == SSLProtocolState.DO_HANDSHAKE
  404. ):
  405. allowed = True
  406. elif (
  407. self._state == SSLProtocolState.DO_HANDSHAKE and
  408. new_state == SSLProtocolState.WRAPPED
  409. ):
  410. allowed = True
  411. elif (
  412. self._state == SSLProtocolState.WRAPPED and
  413. new_state == SSLProtocolState.FLUSHING
  414. ):
  415. allowed = True
  416. elif (
  417. self._state == SSLProtocolState.FLUSHING and
  418. new_state == SSLProtocolState.SHUTDOWN
  419. ):
  420. allowed = True
  421. if allowed:
  422. self._state = new_state
  423. else:
  424. raise RuntimeError(
  425. 'cannot switch state from {} to {}'.format(
  426. self._state, new_state))
  427. # Handshake flow
  428. def _start_handshake(self):
  429. if self._loop.get_debug():
  430. logger.debug("%r starts SSL handshake", self)
  431. self._handshake_start_time = self._loop.time()
  432. else:
  433. self._handshake_start_time = None
  434. self._set_state(SSLProtocolState.DO_HANDSHAKE)
  435. # start handshake timeout count down
  436. self._handshake_timeout_handle = \
  437. self._loop.call_later(self._ssl_handshake_timeout,
  438. lambda: self._check_handshake_timeout())
  439. self._do_handshake()
  440. def _check_handshake_timeout(self):
  441. if self._state == SSLProtocolState.DO_HANDSHAKE:
  442. msg = (
  443. f"SSL handshake is taking longer than "
  444. f"{self._ssl_handshake_timeout} seconds: "
  445. f"aborting the connection"
  446. )
  447. self._fatal_error(ConnectionAbortedError(msg))
  448. def _do_handshake(self):
  449. try:
  450. self._sslobj.do_handshake()
  451. except SSLAgainErrors:
  452. self._process_outgoing()
  453. except ssl.SSLError as exc:
  454. self._on_handshake_complete(exc)
  455. else:
  456. self._on_handshake_complete(None)
  457. def _on_handshake_complete(self, handshake_exc):
  458. if self._handshake_timeout_handle is not None:
  459. self._handshake_timeout_handle.cancel()
  460. self._handshake_timeout_handle = None
  461. sslobj = self._sslobj
  462. try:
  463. if handshake_exc is None:
  464. self._set_state(SSLProtocolState.WRAPPED)
  465. else:
  466. raise handshake_exc
  467. peercert = sslobj.getpeercert()
  468. except Exception as exc:
  469. self._set_state(SSLProtocolState.UNWRAPPED)
  470. if isinstance(exc, ssl.CertificateError):
  471. msg = 'SSL handshake failed on verifying the certificate'
  472. else:
  473. msg = 'SSL handshake failed'
  474. self._fatal_error(exc, msg)
  475. self._wakeup_waiter(exc)
  476. return
  477. if self._loop.get_debug():
  478. dt = self._loop.time() - self._handshake_start_time
  479. logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)
  480. # Add extra info that becomes available after handshake.
  481. self._extra.update(peercert=peercert,
  482. cipher=sslobj.cipher(),
  483. compression=sslobj.compression(),
  484. ssl_object=sslobj)
  485. if self._app_state == AppProtocolState.STATE_INIT:
  486. self._app_state = AppProtocolState.STATE_CON_MADE
  487. self._app_protocol.connection_made(self._get_app_transport())
  488. self._wakeup_waiter()
  489. self._do_read()
  490. # Shutdown flow
  491. def _start_shutdown(self):
  492. if (
  493. self._state in (
  494. SSLProtocolState.FLUSHING,
  495. SSLProtocolState.SHUTDOWN,
  496. SSLProtocolState.UNWRAPPED
  497. )
  498. ):
  499. return
  500. if self._app_transport is not None:
  501. self._app_transport._closed = True
  502. if self._state == SSLProtocolState.DO_HANDSHAKE:
  503. self._abort()
  504. else:
  505. self._set_state(SSLProtocolState.FLUSHING)
  506. self._shutdown_timeout_handle = self._loop.call_later(
  507. self._ssl_shutdown_timeout,
  508. lambda: self._check_shutdown_timeout()
  509. )
  510. self._do_flush()
  511. def _check_shutdown_timeout(self):
  512. if (
  513. self._state in (
  514. SSLProtocolState.FLUSHING,
  515. SSLProtocolState.SHUTDOWN
  516. )
  517. ):
  518. self._transport._force_close(
  519. exceptions.TimeoutError('SSL shutdown timed out'))
  520. def _do_flush(self):
  521. self._do_read()
  522. self._set_state(SSLProtocolState.SHUTDOWN)
  523. self._do_shutdown()
  524. def _do_shutdown(self):
  525. try:
  526. if not self._eof_received:
  527. self._sslobj.unwrap()
  528. except SSLAgainErrors:
  529. self._process_outgoing()
  530. except ssl.SSLError as exc:
  531. self._on_shutdown_complete(exc)
  532. else:
  533. self._process_outgoing()
  534. self._call_eof_received()
  535. self._on_shutdown_complete(None)
  536. def _on_shutdown_complete(self, shutdown_exc):
  537. if self._shutdown_timeout_handle is not None:
  538. self._shutdown_timeout_handle.cancel()
  539. self._shutdown_timeout_handle = None
  540. if shutdown_exc:
  541. self._fatal_error(shutdown_exc)
  542. else:
  543. self._loop.call_soon(self._transport.close)
  544. def _abort(self):
  545. self._set_state(SSLProtocolState.UNWRAPPED)
  546. if self._transport is not None:
  547. self._transport.abort()
  548. # Outgoing flow
  549. def _write_appdata(self, list_of_data):
  550. if (
  551. self._state in (
  552. SSLProtocolState.FLUSHING,
  553. SSLProtocolState.SHUTDOWN,
  554. SSLProtocolState.UNWRAPPED
  555. )
  556. ):
  557. if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
  558. logger.warning('SSL connection is closed')
  559. self._conn_lost += 1
  560. return
  561. for data in list_of_data:
  562. self._write_backlog.append(data)
  563. self._write_buffer_size += len(data)
  564. try:
  565. if self._state == SSLProtocolState.WRAPPED:
  566. self._do_write()
  567. except Exception as ex:
  568. self._fatal_error(ex, 'Fatal error on SSL protocol')
  569. def _do_write(self):
  570. try:
  571. while self._write_backlog:
  572. data = self._write_backlog[0]
  573. count = self._sslobj.write(data)
  574. data_len = len(data)
  575. if count < data_len:
  576. self._write_backlog[0] = data[count:]
  577. self._write_buffer_size -= count
  578. else:
  579. del self._write_backlog[0]
  580. self._write_buffer_size -= data_len
  581. except SSLAgainErrors:
  582. pass
  583. self._process_outgoing()
  584. def _process_outgoing(self):
  585. if not self._ssl_writing_paused:
  586. data = self._outgoing.read()
  587. if len(data):
  588. self._transport.write(data)
  589. self._control_app_writing()
  590. # Incoming flow
  591. def _do_read(self):
  592. if (
  593. self._state not in (
  594. SSLProtocolState.WRAPPED,
  595. SSLProtocolState.FLUSHING,
  596. )
  597. ):
  598. return
  599. try:
  600. if not self._app_reading_paused:
  601. if self._app_protocol_is_buffer:
  602. self._do_read__buffered()
  603. else:
  604. self._do_read__copied()
  605. if self._write_backlog:
  606. self._do_write()
  607. else:
  608. self._process_outgoing()
  609. self._control_ssl_reading()
  610. except Exception as ex:
  611. self._fatal_error(ex, 'Fatal error on SSL protocol')
  612. def _do_read__buffered(self):
  613. offset = 0
  614. count = 1
  615. buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
  616. wants = len(buf)
  617. try:
  618. count = self._sslobj.read(wants, buf)
  619. if count > 0:
  620. offset = count
  621. while offset < wants:
  622. count = self._sslobj.read(wants - offset, buf[offset:])
  623. if count > 0:
  624. offset += count
  625. else:
  626. break
  627. else:
  628. self._loop.call_soon(lambda: self._do_read())
  629. except SSLAgainErrors:
  630. pass
  631. if offset > 0:
  632. self._app_protocol_buffer_updated(offset)
  633. if not count:
  634. # close_notify
  635. self._call_eof_received()
  636. self._start_shutdown()
  637. def _do_read__copied(self):
  638. chunk = b'1'
  639. zero = True
  640. one = False
  641. try:
  642. while True:
  643. chunk = self._sslobj.read(self.max_size)
  644. if not chunk:
  645. break
  646. if zero:
  647. zero = False
  648. one = True
  649. first = chunk
  650. elif one:
  651. one = False
  652. data = [first, chunk]
  653. else:
  654. data.append(chunk)
  655. except SSLAgainErrors:
  656. pass
  657. if one:
  658. self._app_protocol.data_received(first)
  659. elif not zero:
  660. self._app_protocol.data_received(b''.join(data))
  661. if not chunk:
  662. # close_notify
  663. self._call_eof_received()
  664. self._start_shutdown()
  665. def _call_eof_received(self):
  666. try:
  667. if self._app_state == AppProtocolState.STATE_CON_MADE:
  668. self._app_state = AppProtocolState.STATE_EOF
  669. keep_open = self._app_protocol.eof_received()
  670. if keep_open:
  671. logger.warning('returning true from eof_received() '
  672. 'has no effect when using ssl')
  673. except (KeyboardInterrupt, SystemExit):
  674. raise
  675. except BaseException as ex:
  676. self._fatal_error(ex, 'Error calling eof_received()')
  677. # Flow control for writes from APP socket
  678. def _control_app_writing(self):
  679. size = self._get_write_buffer_size()
  680. if size >= self._outgoing_high_water and not self._app_writing_paused:
  681. self._app_writing_paused = True
  682. try:
  683. self._app_protocol.pause_writing()
  684. except (KeyboardInterrupt, SystemExit):
  685. raise
  686. except BaseException as exc:
  687. self._loop.call_exception_handler({
  688. 'message': 'protocol.pause_writing() failed',
  689. 'exception': exc,
  690. 'transport': self._app_transport,
  691. 'protocol': self,
  692. })
  693. elif size <= self._outgoing_low_water and self._app_writing_paused:
  694. self._app_writing_paused = False
  695. try:
  696. self._app_protocol.resume_writing()
  697. except (KeyboardInterrupt, SystemExit):
  698. raise
  699. except BaseException as exc:
  700. self._loop.call_exception_handler({
  701. 'message': 'protocol.resume_writing() failed',
  702. 'exception': exc,
  703. 'transport': self._app_transport,
  704. 'protocol': self,
  705. })
  706. def _get_write_buffer_size(self):
  707. return self._outgoing.pending + self._write_buffer_size
  708. def _set_write_buffer_limits(self, high=None, low=None):
  709. high, low = add_flowcontrol_defaults(
  710. high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
  711. self._outgoing_high_water = high
  712. self._outgoing_low_water = low
  713. # Flow control for reads to APP socket
  714. def _pause_reading(self):
  715. self._app_reading_paused = True
  716. def _resume_reading(self):
  717. if self._app_reading_paused:
  718. self._app_reading_paused = False
  719. def resume():
  720. if self._state == SSLProtocolState.WRAPPED:
  721. self._do_read()
  722. elif self._state == SSLProtocolState.FLUSHING:
  723. self._do_flush()
  724. elif self._state == SSLProtocolState.SHUTDOWN:
  725. self._do_shutdown()
  726. self._loop.call_soon(resume)
  727. # Flow control for reads from SSL socket
  728. def _control_ssl_reading(self):
  729. size = self._get_read_buffer_size()
  730. if size >= self._incoming_high_water and not self._ssl_reading_paused:
  731. self._ssl_reading_paused = True
  732. self._transport.pause_reading()
  733. elif size <= self._incoming_low_water and self._ssl_reading_paused:
  734. self._ssl_reading_paused = False
  735. self._transport.resume_reading()
  736. def _set_read_buffer_limits(self, high=None, low=None):
  737. high, low = add_flowcontrol_defaults(
  738. high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
  739. self._incoming_high_water = high
  740. self._incoming_low_water = low
  741. def _get_read_buffer_size(self):
  742. return self._incoming.pending
  743. # Flow control for writes to SSL socket
  744. def pause_writing(self):
  745. """Called when the low-level transport's buffer goes over
  746. the high-water mark.
  747. """
  748. assert not self._ssl_writing_paused
  749. self._ssl_writing_paused = True
  750. def resume_writing(self):
  751. """Called when the low-level transport's buffer drains below
  752. the low-water mark.
  753. """
  754. assert self._ssl_writing_paused
  755. self._ssl_writing_paused = False
  756. self._process_outgoing()
  757. def _fatal_error(self, exc, message='Fatal error on transport'):
  758. if self._transport:
  759. self._transport._force_close(exc)
  760. if isinstance(exc, OSError):
  761. if self._loop.get_debug():
  762. logger.debug("%r: %s", self, message, exc_info=True)
  763. elif not isinstance(exc, exceptions.CancelledError):
  764. self._loop.call_exception_handler({
  765. 'message': message,
  766. 'exception': exc,
  767. 'transport': self._transport,
  768. 'protocol': self,
  769. })