audit-tests.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. """This script contains the actual auditing tests.
  2. It should not be imported directly, but should be run by the test_audit
  3. module with arguments identifying each test.
  4. """
  5. import contextlib
  6. import os
  7. import sys
  8. class TestHook:
  9. """Used in standard hook tests to collect any logged events.
  10. Should be used in a with block to ensure that it has no impact
  11. after the test completes.
  12. """
  13. def __init__(self, raise_on_events=None, exc_type=RuntimeError):
  14. self.raise_on_events = raise_on_events or ()
  15. self.exc_type = exc_type
  16. self.seen = []
  17. self.closed = False
  18. def __enter__(self, *a):
  19. sys.addaudithook(self)
  20. return self
  21. def __exit__(self, *a):
  22. self.close()
  23. def close(self):
  24. self.closed = True
  25. @property
  26. def seen_events(self):
  27. return [i[0] for i in self.seen]
  28. def __call__(self, event, args):
  29. if self.closed:
  30. return
  31. self.seen.append((event, args))
  32. if event in self.raise_on_events:
  33. raise self.exc_type("saw event " + event)
  34. # Simple helpers, since we are not in unittest here
  35. def assertEqual(x, y):
  36. if x != y:
  37. raise AssertionError(f"{x!r} should equal {y!r}")
  38. def assertIn(el, series):
  39. if el not in series:
  40. raise AssertionError(f"{el!r} should be in {series!r}")
  41. def assertNotIn(el, series):
  42. if el in series:
  43. raise AssertionError(f"{el!r} should not be in {series!r}")
  44. def assertSequenceEqual(x, y):
  45. if len(x) != len(y):
  46. raise AssertionError(f"{x!r} should equal {y!r}")
  47. if any(ix != iy for ix, iy in zip(x, y)):
  48. raise AssertionError(f"{x!r} should equal {y!r}")
  49. @contextlib.contextmanager
  50. def assertRaises(ex_type):
  51. try:
  52. yield
  53. assert False, f"expected {ex_type}"
  54. except BaseException as ex:
  55. if isinstance(ex, AssertionError):
  56. raise
  57. assert type(ex) is ex_type, f"{ex} should be {ex_type}"
  58. def test_basic():
  59. with TestHook() as hook:
  60. sys.audit("test_event", 1, 2, 3)
  61. assertEqual(hook.seen[0][0], "test_event")
  62. assertEqual(hook.seen[0][1], (1, 2, 3))
  63. def test_block_add_hook():
  64. # Raising an exception should prevent a new hook from being added,
  65. # but will not propagate out.
  66. with TestHook(raise_on_events="sys.addaudithook") as hook1:
  67. with TestHook() as hook2:
  68. sys.audit("test_event")
  69. assertIn("test_event", hook1.seen_events)
  70. assertNotIn("test_event", hook2.seen_events)
  71. def test_block_add_hook_baseexception():
  72. # Raising BaseException will propagate out when adding a hook
  73. with assertRaises(BaseException):
  74. with TestHook(
  75. raise_on_events="sys.addaudithook", exc_type=BaseException
  76. ) as hook1:
  77. # Adding this next hook should raise BaseException
  78. with TestHook() as hook2:
  79. pass
  80. def test_marshal():
  81. import marshal
  82. o = ("a", "b", "c", 1, 2, 3)
  83. payload = marshal.dumps(o)
  84. with TestHook() as hook:
  85. assertEqual(o, marshal.loads(marshal.dumps(o)))
  86. try:
  87. with open("test-marshal.bin", "wb") as f:
  88. marshal.dump(o, f)
  89. with open("test-marshal.bin", "rb") as f:
  90. assertEqual(o, marshal.load(f))
  91. finally:
  92. os.unlink("test-marshal.bin")
  93. actual = [(a[0], a[1]) for e, a in hook.seen if e == "marshal.dumps"]
  94. assertSequenceEqual(actual, [(o, marshal.version)] * 2)
  95. actual = [a[0] for e, a in hook.seen if e == "marshal.loads"]
  96. assertSequenceEqual(actual, [payload])
  97. actual = [e for e, a in hook.seen if e == "marshal.load"]
  98. assertSequenceEqual(actual, ["marshal.load"])
  99. def test_pickle():
  100. import pickle
  101. class PicklePrint:
  102. def __reduce_ex__(self, p):
  103. return str, ("Pwned!",)
  104. payload_1 = pickle.dumps(PicklePrint())
  105. payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))
  106. # Before we add the hook, ensure our malicious pickle loads
  107. assertEqual("Pwned!", pickle.loads(payload_1))
  108. with TestHook(raise_on_events="pickle.find_class") as hook:
  109. with assertRaises(RuntimeError):
  110. # With the hook enabled, loading globals is not allowed
  111. pickle.loads(payload_1)
  112. # pickles with no globals are okay
  113. pickle.loads(payload_2)
  114. def test_monkeypatch():
  115. class A:
  116. pass
  117. class B:
  118. pass
  119. class C(A):
  120. pass
  121. a = A()
  122. with TestHook() as hook:
  123. # Catch name changes
  124. C.__name__ = "X"
  125. # Catch type changes
  126. C.__bases__ = (B,)
  127. # Ensure bypassing __setattr__ is still caught
  128. type.__dict__["__bases__"].__set__(C, (B,))
  129. # Catch attribute replacement
  130. C.__init__ = B.__init__
  131. # Catch attribute addition
  132. C.new_attr = 123
  133. # Catch class changes
  134. a.__class__ = B
  135. actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"]
  136. assertSequenceEqual(
  137. [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
  138. )
  139. def test_open():
  140. # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
  141. try:
  142. import ssl
  143. load_dh_params = ssl.create_default_context().load_dh_params
  144. except ImportError:
  145. load_dh_params = None
  146. # Try a range of "open" functions.
  147. # All of them should fail
  148. with TestHook(raise_on_events={"open"}) as hook:
  149. for fn, *args in [
  150. (open, sys.argv[2], "r"),
  151. (open, sys.executable, "rb"),
  152. (open, 3, "wb"),
  153. (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
  154. (load_dh_params, sys.argv[2]),
  155. ]:
  156. if not fn:
  157. continue
  158. with assertRaises(RuntimeError):
  159. fn(*args)
  160. actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
  161. actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
  162. assertSequenceEqual(
  163. [
  164. i
  165. for i in [
  166. (sys.argv[2], "r"),
  167. (sys.executable, "r"),
  168. (3, "w"),
  169. (sys.argv[2], "w"),
  170. (sys.argv[2], "rb") if load_dh_params else None,
  171. ]
  172. if i is not None
  173. ],
  174. actual_mode,
  175. )
  176. assertSequenceEqual([], actual_flag)
  177. def test_cantrace():
  178. traced = []
  179. def trace(frame, event, *args):
  180. if frame.f_code == TestHook.__call__.__code__:
  181. traced.append(event)
  182. old = sys.settrace(trace)
  183. try:
  184. with TestHook() as hook:
  185. # No traced call
  186. eval("1")
  187. # No traced call
  188. hook.__cantrace__ = False
  189. eval("2")
  190. # One traced call
  191. hook.__cantrace__ = True
  192. eval("3")
  193. # Two traced calls (writing to private member, eval)
  194. hook.__cantrace__ = 1
  195. eval("4")
  196. # One traced call (writing to private member)
  197. hook.__cantrace__ = 0
  198. finally:
  199. sys.settrace(old)
  200. assertSequenceEqual(["call"] * 4, traced)
  201. def test_mmap():
  202. import mmap
  203. with TestHook() as hook:
  204. mmap.mmap(-1, 8)
  205. assertEqual(hook.seen[0][1][:2], (-1, 8))
  206. def test_excepthook():
  207. def excepthook(exc_type, exc_value, exc_tb):
  208. if exc_type is not RuntimeError:
  209. sys.__excepthook__(exc_type, exc_value, exc_tb)
  210. def hook(event, args):
  211. if event == "sys.excepthook":
  212. if not isinstance(args[2], args[1]):
  213. raise TypeError(f"Expected isinstance({args[2]!r}, " f"{args[1]!r})")
  214. if args[0] != excepthook:
  215. raise ValueError(f"Expected {args[0]} == {excepthook}")
  216. print(event, repr(args[2]))
  217. sys.addaudithook(hook)
  218. sys.excepthook = excepthook
  219. raise RuntimeError("fatal-error")
  220. def test_unraisablehook():
  221. from _testcapi import write_unraisable_exc
  222. def unraisablehook(hookargs):
  223. pass
  224. def hook(event, args):
  225. if event == "sys.unraisablehook":
  226. if args[0] != unraisablehook:
  227. raise ValueError(f"Expected {args[0]} == {unraisablehook}")
  228. print(event, repr(args[1].exc_value), args[1].err_msg)
  229. sys.addaudithook(hook)
  230. sys.unraisablehook = unraisablehook
  231. write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
def test_winreg():
    """Print the winreg.* audit events raised by basic registry operations.

    Windows-only (imports winreg). The printed lines are this test's
    output; presumably the test_audit caller compares them, so the order
    of the registry calls below matters — confirm there.
    """
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        # Only registry events are reported; all others are ignored.
        if not event.startswith("winreg."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    k = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(k, 0)
    try:
        # Index 10000 is expected to be out of range and raise OSError.
        EnumKey(k, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    # Detach the raw handle from the key object, then close it directly.
    kv = k.Detach()
    CloseKey(kv)
def test_socket():
    """Print the socket.* audit events raised by a few socket operations.

    The printed lines are this test's output; presumably the test_audit
    caller compares them — confirm there.
    """
    import socket

    def hook(event, args):
        if event.startswith("socket."):
            print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    # Socket creation is not guarded: if it fails, the test fails.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Don't care if this fails, we just want the audit message
        sock.bind(('127.0.0.1', 8080))
    except Exception:
        pass
    finally:
        sock.close()
  265. def test_gc():
  266. import gc
  267. def hook(event, args):
  268. if event.startswith("gc."):
  269. print(event, *args)
  270. sys.addaudithook(hook)
  271. gc.get_objects(generation=1)
  272. x = object()
  273. y = [x]
  274. gc.get_referrers(x)
  275. gc.get_referents(y)
def test_http_client():
    """Print the http.client.* audit events raised by a simple GET request.

    Needs network access to www.python.org for the real send event. On
    OSError, a placeholder "http.client.send" line is printed instead,
    presumably so the output has the same shape either way — confirm
    against the test_audit caller.
    """
    import http.client

    def hook(event, args):
        if event.startswith("http.client."):
            # args[0] is skipped; presumably the connection object, whose
            # repr would differ between runs — confirm.
            print(event, *args[1:])

    sys.addaudithook(hook)

    conn = http.client.HTTPConnection('www.python.org')
    try:
        conn.request('GET', '/')
    except OSError:
        print('http.client.send', '[cannot send]')
    finally:
        conn.close()
def test_sqlite3():
    """Print the sqlite3.* audit events raised by opening connections and,
    when available, by a refused load_extension() call.

    The printed lines are this test's output; presumably the test_audit
    caller checks them — confirm there.
    """
    import sqlite3

    # NOTE(review): unlike the other hooks in this file, this one collects
    # its arguments with *args. ``args`` is therefore a 1-tuple wrapping the
    # audit-args tuple, and print() shows that tuple as a single item. It is
    # left as-is because changing it would change the printed output.
    def hook(event, *args):
        if event.startswith("sqlite3."):
            print(event, *args)

    sys.addaudithook(hook)
    # Open one connection via connect() and one via the Connection class.
    cx1 = sqlite3.connect(":memory:")
    cx2 = sqlite3.Connection(":memory:")

    # Configured without --enable-loadable-sqlite-extensions, the
    # enable_load_extension method is absent; skip the extension checks then.
    if hasattr(sqlite3.Connection, "enable_load_extension"):
        cx1.enable_load_extension(False)
        try:
            # Extension loading was just disabled, so this must fail.
            cx1.load_extension("test")
        except sqlite3.OperationalError:
            pass
        else:
            raise RuntimeError("Expected sqlite3.load_extension to fail")
  306. def test_sys_getframe():
  307. import sys
  308. def hook(event, args):
  309. if event.startswith("sys."):
  310. print(event, args[0].f_code.co_name)
  311. sys.addaudithook(hook)
  312. sys._getframe()
def test_syslog():
    """Print the syslog.* audit events raised by a sequence of syslog calls.

    The order of calls below is the test's expected event order. Note that
    this replaces sys.argv for the rest of the process. That is acceptable
    only because each test function is run separately by name.
    """
    import syslog

    def hook(event, args):
        if event.startswith("syslog."):
            print(event, *args)

    sys.addaudithook(hook)
    syslog.openlog('python')
    syslog.syslog('test')
    syslog.setlogmask(syslog.LOG_DEBUG)
    syslog.closelog()
    # implicit open
    syslog.syslog('test2')
    # open with default ident
    syslog.openlog(logoption=syslog.LOG_NDELAY, facility=syslog.LOG_LOCAL0)
    # NOTE(review): presumably exercises openlog()'s default-ident handling
    # when sys.argv is unusable — confirm.
    sys.argv = None
    syslog.openlog()
    syslog.closelog()
  330. def test_not_in_gc():
  331. import gc
  332. hook = lambda *a: None
  333. sys.addaudithook(hook)
  334. for o in gc.get_objects():
  335. if isinstance(o, list):
  336. assert hook not in o
  337. if __name__ == "__main__":
  338. from test.support import suppress_msvcrt_asserts
  339. suppress_msvcrt_asserts()
  340. test = sys.argv[1]
  341. globals()[test]()