test_context.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103
  1. import concurrent.futures
  2. import contextvars
  3. import functools
  4. import gc
  5. import random
  6. import time
  7. import unittest
  8. import weakref
  9. from test.support import threading_helper
  10. try:
  11. from _testcapi import hamt
  12. except ImportError:
  13. hamt = None
  14. def isolated_context(func):
  15. """Needed to make reftracking test mode work."""
  16. @functools.wraps(func)
  17. def wrapper(*args, **kwargs):
  18. ctx = contextvars.Context()
  19. return ctx.run(func, *args, **kwargs)
  20. return wrapper
  21. class ContextTest(unittest.TestCase):
  22. def test_context_var_new_1(self):
  23. with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
  24. contextvars.ContextVar()
  25. with self.assertRaisesRegex(TypeError, 'must be a str'):
  26. contextvars.ContextVar(1)
  27. c = contextvars.ContextVar('aaa')
  28. self.assertEqual(c.name, 'aaa')
  29. with self.assertRaises(AttributeError):
  30. c.name = 'bbb'
  31. self.assertNotEqual(hash(c), hash('aaa'))
  32. @isolated_context
  33. def test_context_var_repr_1(self):
  34. c = contextvars.ContextVar('a')
  35. self.assertIn('a', repr(c))
  36. c = contextvars.ContextVar('a', default=123)
  37. self.assertIn('123', repr(c))
  38. lst = []
  39. c = contextvars.ContextVar('a', default=lst)
  40. lst.append(c)
  41. self.assertIn('...', repr(c))
  42. self.assertIn('...', repr(lst))
  43. t = c.set(1)
  44. self.assertIn(repr(c), repr(t))
  45. self.assertNotIn(' used ', repr(t))
  46. c.reset(t)
  47. self.assertIn(' used ', repr(t))
  48. def test_context_subclassing_1(self):
  49. with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
  50. class MyContextVar(contextvars.ContextVar):
  51. # Potentially we might want ContextVars to be subclassable.
  52. pass
  53. with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
  54. class MyContext(contextvars.Context):
  55. pass
  56. with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
  57. class MyToken(contextvars.Token):
  58. pass
  59. def test_context_new_1(self):
  60. with self.assertRaisesRegex(TypeError, 'any arguments'):
  61. contextvars.Context(1)
  62. with self.assertRaisesRegex(TypeError, 'any arguments'):
  63. contextvars.Context(1, a=1)
  64. with self.assertRaisesRegex(TypeError, 'any arguments'):
  65. contextvars.Context(a=1)
  66. contextvars.Context(**{})
  67. def test_context_typerrors_1(self):
  68. ctx = contextvars.Context()
  69. with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
  70. ctx[1]
  71. with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
  72. 1 in ctx
  73. with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
  74. ctx.get(1)
  75. def test_context_get_context_1(self):
  76. ctx = contextvars.copy_context()
  77. self.assertIsInstance(ctx, contextvars.Context)
  78. def test_context_run_1(self):
  79. ctx = contextvars.Context()
  80. with self.assertRaisesRegex(TypeError, 'missing 1 required'):
  81. ctx.run()
  82. def test_context_run_2(self):
  83. ctx = contextvars.Context()
  84. def func(*args, **kwargs):
  85. kwargs['spam'] = 'foo'
  86. args += ('bar',)
  87. return args, kwargs
  88. for f in (func, functools.partial(func)):
  89. # partial doesn't support FASTCALL
  90. self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
  91. self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))
  92. self.assertEqual(
  93. ctx.run(f, a=2),
  94. (('bar',), {'a': 2, 'spam': 'foo'}))
  95. self.assertEqual(
  96. ctx.run(f, 11, a=2),
  97. ((11, 'bar'), {'a': 2, 'spam': 'foo'}))
  98. a = {}
  99. self.assertEqual(
  100. ctx.run(f, 11, **a),
  101. ((11, 'bar'), {'spam': 'foo'}))
  102. self.assertEqual(a, {})
  103. def test_context_run_3(self):
  104. ctx = contextvars.Context()
  105. def func(*args, **kwargs):
  106. 1 / 0
  107. with self.assertRaises(ZeroDivisionError):
  108. ctx.run(func)
  109. with self.assertRaises(ZeroDivisionError):
  110. ctx.run(func, 1, 2)
  111. with self.assertRaises(ZeroDivisionError):
  112. ctx.run(func, 1, 2, a=123)
  113. @isolated_context
  114. def test_context_run_4(self):
  115. ctx1 = contextvars.Context()
  116. ctx2 = contextvars.Context()
  117. var = contextvars.ContextVar('var')
  118. def func2():
  119. self.assertIsNone(var.get(None))
  120. def func1():
  121. self.assertIsNone(var.get(None))
  122. var.set('spam')
  123. ctx2.run(func2)
  124. self.assertEqual(var.get(None), 'spam')
  125. cur = contextvars.copy_context()
  126. self.assertEqual(len(cur), 1)
  127. self.assertEqual(cur[var], 'spam')
  128. return cur
  129. returned_ctx = ctx1.run(func1)
  130. self.assertEqual(ctx1, returned_ctx)
  131. self.assertEqual(returned_ctx[var], 'spam')
  132. self.assertIn(var, returned_ctx)
  133. def test_context_run_5(self):
  134. ctx = contextvars.Context()
  135. var = contextvars.ContextVar('var')
  136. def func():
  137. self.assertIsNone(var.get(None))
  138. var.set('spam')
  139. 1 / 0
  140. with self.assertRaises(ZeroDivisionError):
  141. ctx.run(func)
  142. self.assertIsNone(var.get(None))
  143. def test_context_run_6(self):
  144. ctx = contextvars.Context()
  145. c = contextvars.ContextVar('a', default=0)
  146. def fun():
  147. self.assertEqual(c.get(), 0)
  148. self.assertIsNone(ctx.get(c))
  149. c.set(42)
  150. self.assertEqual(c.get(), 42)
  151. self.assertEqual(ctx.get(c), 42)
  152. ctx.run(fun)
  153. def test_context_run_7(self):
  154. ctx = contextvars.Context()
  155. def fun():
  156. with self.assertRaisesRegex(RuntimeError, 'is already entered'):
  157. ctx.run(fun)
  158. ctx.run(fun)
  159. @isolated_context
  160. def test_context_getset_1(self):
  161. c = contextvars.ContextVar('c')
  162. with self.assertRaises(LookupError):
  163. c.get()
  164. self.assertIsNone(c.get(None))
  165. t0 = c.set(42)
  166. self.assertEqual(c.get(), 42)
  167. self.assertEqual(c.get(None), 42)
  168. self.assertIs(t0.old_value, t0.MISSING)
  169. self.assertIs(t0.old_value, contextvars.Token.MISSING)
  170. self.assertIs(t0.var, c)
  171. t = c.set('spam')
  172. self.assertEqual(c.get(), 'spam')
  173. self.assertEqual(c.get(None), 'spam')
  174. self.assertEqual(t.old_value, 42)
  175. c.reset(t)
  176. self.assertEqual(c.get(), 42)
  177. self.assertEqual(c.get(None), 42)
  178. c.set('spam2')
  179. with self.assertRaisesRegex(RuntimeError, 'has already been used'):
  180. c.reset(t)
  181. self.assertEqual(c.get(), 'spam2')
  182. ctx1 = contextvars.copy_context()
  183. self.assertIn(c, ctx1)
  184. c.reset(t0)
  185. with self.assertRaisesRegex(RuntimeError, 'has already been used'):
  186. c.reset(t0)
  187. self.assertIsNone(c.get(None))
  188. self.assertIn(c, ctx1)
  189. self.assertEqual(ctx1[c], 'spam2')
  190. self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
  191. self.assertEqual(len(ctx1), 1)
  192. self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
  193. self.assertEqual(list(ctx1.values()), ['spam2'])
  194. self.assertEqual(list(ctx1.keys()), [c])
  195. self.assertEqual(list(ctx1), [c])
  196. ctx2 = contextvars.copy_context()
  197. self.assertNotIn(c, ctx2)
  198. with self.assertRaises(KeyError):
  199. ctx2[c]
  200. self.assertEqual(ctx2.get(c, 'aa'), 'aa')
  201. self.assertEqual(len(ctx2), 0)
  202. self.assertEqual(list(ctx2), [])
  203. @isolated_context
  204. def test_context_getset_2(self):
  205. v1 = contextvars.ContextVar('v1')
  206. v2 = contextvars.ContextVar('v2')
  207. t1 = v1.set(42)
  208. with self.assertRaisesRegex(ValueError, 'by a different'):
  209. v2.reset(t1)
  210. @isolated_context
  211. def test_context_getset_3(self):
  212. c = contextvars.ContextVar('c', default=42)
  213. ctx = contextvars.Context()
  214. def fun():
  215. self.assertEqual(c.get(), 42)
  216. with self.assertRaises(KeyError):
  217. ctx[c]
  218. self.assertIsNone(ctx.get(c))
  219. self.assertEqual(ctx.get(c, 'spam'), 'spam')
  220. self.assertNotIn(c, ctx)
  221. self.assertEqual(list(ctx.keys()), [])
  222. t = c.set(1)
  223. self.assertEqual(list(ctx.keys()), [c])
  224. self.assertEqual(ctx[c], 1)
  225. c.reset(t)
  226. self.assertEqual(list(ctx.keys()), [])
  227. with self.assertRaises(KeyError):
  228. ctx[c]
  229. ctx.run(fun)
  230. @isolated_context
  231. def test_context_getset_4(self):
  232. c = contextvars.ContextVar('c', default=42)
  233. ctx = contextvars.Context()
  234. tok = ctx.run(c.set, 1)
  235. with self.assertRaisesRegex(ValueError, 'different Context'):
  236. c.reset(tok)
  237. @isolated_context
  238. def test_context_getset_5(self):
  239. c = contextvars.ContextVar('c', default=42)
  240. c.set([])
  241. def fun():
  242. c.set([])
  243. c.get().append(42)
  244. self.assertEqual(c.get(), [42])
  245. contextvars.copy_context().run(fun)
  246. self.assertEqual(c.get(), [])
  247. def test_context_copy_1(self):
  248. ctx1 = contextvars.Context()
  249. c = contextvars.ContextVar('c', default=42)
  250. def ctx1_fun():
  251. c.set(10)
  252. ctx2 = ctx1.copy()
  253. self.assertEqual(ctx2[c], 10)
  254. c.set(20)
  255. self.assertEqual(ctx1[c], 20)
  256. self.assertEqual(ctx2[c], 10)
  257. ctx2.run(ctx2_fun)
  258. self.assertEqual(ctx1[c], 20)
  259. self.assertEqual(ctx2[c], 30)
  260. def ctx2_fun():
  261. self.assertEqual(c.get(), 10)
  262. c.set(30)
  263. self.assertEqual(c.get(), 30)
  264. ctx1.run(ctx1_fun)
  265. @isolated_context
  266. @threading_helper.requires_working_threading()
  267. def test_context_threads_1(self):
  268. cvar = contextvars.ContextVar('cvar')
  269. def sub(num):
  270. for i in range(10):
  271. cvar.set(num + i)
  272. time.sleep(random.uniform(0.001, 0.05))
  273. self.assertEqual(cvar.get(), num + i)
  274. return num
  275. tp = concurrent.futures.ThreadPoolExecutor(max_workers=10)
  276. try:
  277. results = list(tp.map(sub, range(10)))
  278. finally:
  279. tp.shutdown()
  280. self.assertEqual(results, list(range(10)))
  281. # HAMT Tests
  282. class HashKey:
  283. _crasher = None
  284. def __init__(self, hash, name, *, error_on_eq_to=None):
  285. assert hash != -1
  286. self.name = name
  287. self.hash = hash
  288. self.error_on_eq_to = error_on_eq_to
  289. def __repr__(self):
  290. return f'<Key name:{self.name} hash:{self.hash}>'
  291. def __hash__(self):
  292. if self._crasher is not None and self._crasher.error_on_hash:
  293. raise HashingError
  294. return self.hash
  295. def __eq__(self, other):
  296. if not isinstance(other, HashKey):
  297. return NotImplemented
  298. if self._crasher is not None and self._crasher.error_on_eq:
  299. raise EqError
  300. if self.error_on_eq_to is not None and self.error_on_eq_to is other:
  301. raise ValueError(f'cannot compare {self!r} to {other!r}')
  302. if other.error_on_eq_to is not None and other.error_on_eq_to is self:
  303. raise ValueError(f'cannot compare {other!r} to {self!r}')
  304. return (self.name, self.hash) == (other.name, other.hash)
  305. class KeyStr(str):
  306. def __hash__(self):
  307. if HashKey._crasher is not None and HashKey._crasher.error_on_hash:
  308. raise HashingError
  309. return super().__hash__()
  310. def __eq__(self, other):
  311. if HashKey._crasher is not None and HashKey._crasher.error_on_eq:
  312. raise EqError
  313. return super().__eq__(other)
  314. class HaskKeyCrasher:
  315. def __init__(self, *, error_on_hash=False, error_on_eq=False):
  316. self.error_on_hash = error_on_hash
  317. self.error_on_eq = error_on_eq
  318. def __enter__(self):
  319. if HashKey._crasher is not None:
  320. raise RuntimeError('cannot nest crashers')
  321. HashKey._crasher = self
  322. def __exit__(self, *exc):
  323. HashKey._crasher = None
  324. class HashingError(Exception):
  325. pass
  326. class EqError(Exception):
  327. pass
  328. @unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
  329. class HamtTest(unittest.TestCase):
  330. def test_hashkey_helper_1(self):
  331. k1 = HashKey(10, 'aaa')
  332. k2 = HashKey(10, 'bbb')
  333. self.assertNotEqual(k1, k2)
  334. self.assertEqual(hash(k1), hash(k2))
  335. d = dict()
  336. d[k1] = 'a'
  337. d[k2] = 'b'
  338. self.assertEqual(d[k1], 'a')
  339. self.assertEqual(d[k2], 'b')
  340. def test_hamt_basics_1(self):
  341. h = hamt()
  342. h = None # NoQA
  343. def test_hamt_basics_2(self):
  344. h = hamt()
  345. self.assertEqual(len(h), 0)
  346. h2 = h.set('a', 'b')
  347. self.assertIsNot(h, h2)
  348. self.assertEqual(len(h), 0)
  349. self.assertEqual(len(h2), 1)
  350. self.assertIsNone(h.get('a'))
  351. self.assertEqual(h.get('a', 42), 42)
  352. self.assertEqual(h2.get('a'), 'b')
  353. h3 = h2.set('b', 10)
  354. self.assertIsNot(h2, h3)
  355. self.assertEqual(len(h), 0)
  356. self.assertEqual(len(h2), 1)
  357. self.assertEqual(len(h3), 2)
  358. self.assertEqual(h3.get('a'), 'b')
  359. self.assertEqual(h3.get('b'), 10)
  360. self.assertIsNone(h.get('b'))
  361. self.assertIsNone(h2.get('b'))
  362. self.assertIsNone(h.get('a'))
  363. self.assertEqual(h2.get('a'), 'b')
  364. h = h2 = h3 = None
  365. def test_hamt_basics_3(self):
  366. h = hamt()
  367. o = object()
  368. h1 = h.set('1', o)
  369. h2 = h1.set('1', o)
  370. self.assertIs(h1, h2)
  371. def test_hamt_basics_4(self):
  372. h = hamt()
  373. h1 = h.set('key', [])
  374. h2 = h1.set('key', [])
  375. self.assertIsNot(h1, h2)
  376. self.assertEqual(len(h1), 1)
  377. self.assertEqual(len(h2), 1)
  378. self.assertIsNot(h1.get('key'), h2.get('key'))
  379. def test_hamt_collision_1(self):
  380. k1 = HashKey(10, 'aaa')
  381. k2 = HashKey(10, 'bbb')
  382. k3 = HashKey(10, 'ccc')
  383. h = hamt()
  384. h2 = h.set(k1, 'a')
  385. h3 = h2.set(k2, 'b')
  386. self.assertEqual(h.get(k1), None)
  387. self.assertEqual(h.get(k2), None)
  388. self.assertEqual(h2.get(k1), 'a')
  389. self.assertEqual(h2.get(k2), None)
  390. self.assertEqual(h3.get(k1), 'a')
  391. self.assertEqual(h3.get(k2), 'b')
  392. h4 = h3.set(k2, 'cc')
  393. h5 = h4.set(k3, 'aa')
  394. self.assertEqual(h3.get(k1), 'a')
  395. self.assertEqual(h3.get(k2), 'b')
  396. self.assertEqual(h4.get(k1), 'a')
  397. self.assertEqual(h4.get(k2), 'cc')
  398. self.assertEqual(h4.get(k3), None)
  399. self.assertEqual(h5.get(k1), 'a')
  400. self.assertEqual(h5.get(k2), 'cc')
  401. self.assertEqual(h5.get(k2), 'cc')
  402. self.assertEqual(h5.get(k3), 'aa')
  403. self.assertEqual(len(h), 0)
  404. self.assertEqual(len(h2), 1)
  405. self.assertEqual(len(h3), 2)
  406. self.assertEqual(len(h4), 2)
  407. self.assertEqual(len(h5), 3)
  408. def test_hamt_collision_3(self):
  409. # Test that iteration works with the deepest tree possible.
  410. # https://github.com/python/cpython/issues/93065
  411. C = HashKey(0b10000000_00000000_00000000_00000000, 'C')
  412. D = HashKey(0b10000000_00000000_00000000_00000000, 'D')
  413. E = HashKey(0b00000000_00000000_00000000_00000000, 'E')
  414. h = hamt()
  415. h = h.set(C, 'C')
  416. h = h.set(D, 'D')
  417. h = h.set(E, 'E')
  418. # BitmapNode(size=2 count=1 bitmap=0b1):
  419. # NULL:
  420. # BitmapNode(size=2 count=1 bitmap=0b1):
  421. # NULL:
  422. # BitmapNode(size=2 count=1 bitmap=0b1):
  423. # NULL:
  424. # BitmapNode(size=2 count=1 bitmap=0b1):
  425. # NULL:
  426. # BitmapNode(size=2 count=1 bitmap=0b1):
  427. # NULL:
  428. # BitmapNode(size=2 count=1 bitmap=0b1):
  429. # NULL:
  430. # BitmapNode(size=4 count=2 bitmap=0b101):
  431. # <Key name:E hash:0>: 'E'
  432. # NULL:
  433. # CollisionNode(size=4 id=0x107a24520):
  434. # <Key name:C hash:2147483648>: 'C'
  435. # <Key name:D hash:2147483648>: 'D'
  436. self.assertEqual({k.name for k in h.keys()}, {'C', 'D', 'E'})
  437. def test_hamt_stress(self):
  438. COLLECTION_SIZE = 7000
  439. TEST_ITERS_EVERY = 647
  440. CRASH_HASH_EVERY = 97
  441. CRASH_EQ_EVERY = 11
  442. RUN_XTIMES = 3
  443. for _ in range(RUN_XTIMES):
  444. h = hamt()
  445. d = dict()
  446. for i in range(COLLECTION_SIZE):
  447. key = KeyStr(i)
  448. if not (i % CRASH_HASH_EVERY):
  449. with HaskKeyCrasher(error_on_hash=True):
  450. with self.assertRaises(HashingError):
  451. h.set(key, i)
  452. h = h.set(key, i)
  453. if not (i % CRASH_EQ_EVERY):
  454. with HaskKeyCrasher(error_on_eq=True):
  455. with self.assertRaises(EqError):
  456. h.get(KeyStr(i)) # really trigger __eq__
  457. d[key] = i
  458. self.assertEqual(len(d), len(h))
  459. if not (i % TEST_ITERS_EVERY):
  460. self.assertEqual(set(h.items()), set(d.items()))
  461. self.assertEqual(len(h.items()), len(d.items()))
  462. self.assertEqual(len(h), COLLECTION_SIZE)
  463. for key in range(COLLECTION_SIZE):
  464. self.assertEqual(h.get(KeyStr(key), 'not found'), key)
  465. keys_to_delete = list(range(COLLECTION_SIZE))
  466. random.shuffle(keys_to_delete)
  467. for iter_i, i in enumerate(keys_to_delete):
  468. key = KeyStr(i)
  469. if not (iter_i % CRASH_HASH_EVERY):
  470. with HaskKeyCrasher(error_on_hash=True):
  471. with self.assertRaises(HashingError):
  472. h.delete(key)
  473. if not (iter_i % CRASH_EQ_EVERY):
  474. with HaskKeyCrasher(error_on_eq=True):
  475. with self.assertRaises(EqError):
  476. h.delete(KeyStr(i))
  477. h = h.delete(key)
  478. self.assertEqual(h.get(key, 'not found'), 'not found')
  479. del d[key]
  480. self.assertEqual(len(d), len(h))
  481. if iter_i == COLLECTION_SIZE // 2:
  482. hm = h
  483. dm = d.copy()
  484. if not (iter_i % TEST_ITERS_EVERY):
  485. self.assertEqual(set(h.keys()), set(d.keys()))
  486. self.assertEqual(len(h.keys()), len(d.keys()))
  487. self.assertEqual(len(d), 0)
  488. self.assertEqual(len(h), 0)
  489. # ============
  490. for key in dm:
  491. self.assertEqual(hm.get(str(key)), dm[key])
  492. self.assertEqual(len(dm), len(hm))
  493. for i, key in enumerate(keys_to_delete):
  494. hm = hm.delete(str(key))
  495. self.assertEqual(hm.get(str(key), 'not found'), 'not found')
  496. dm.pop(str(key), None)
  497. self.assertEqual(len(d), len(h))
  498. if not (i % TEST_ITERS_EVERY):
  499. self.assertEqual(set(h.values()), set(d.values()))
  500. self.assertEqual(len(h.values()), len(d.values()))
  501. self.assertEqual(len(d), 0)
  502. self.assertEqual(len(h), 0)
  503. self.assertEqual(list(h.items()), [])
  504. def test_hamt_delete_1(self):
  505. A = HashKey(100, 'A')
  506. B = HashKey(101, 'B')
  507. C = HashKey(102, 'C')
  508. D = HashKey(103, 'D')
  509. E = HashKey(104, 'E')
  510. Z = HashKey(-100, 'Z')
  511. Er = HashKey(103, 'Er', error_on_eq_to=D)
  512. h = hamt()
  513. h = h.set(A, 'a')
  514. h = h.set(B, 'b')
  515. h = h.set(C, 'c')
  516. h = h.set(D, 'd')
  517. h = h.set(E, 'e')
  518. orig_len = len(h)
  519. # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
  520. # <Key name:A hash:100>: 'a'
  521. # <Key name:B hash:101>: 'b'
  522. # <Key name:C hash:102>: 'c'
  523. # <Key name:D hash:103>: 'd'
  524. # <Key name:E hash:104>: 'e'
  525. h = h.delete(C)
  526. self.assertEqual(len(h), orig_len - 1)
  527. with self.assertRaisesRegex(ValueError, 'cannot compare'):
  528. h.delete(Er)
  529. h = h.delete(D)
  530. self.assertEqual(len(h), orig_len - 2)
  531. h2 = h.delete(Z)
  532. self.assertIs(h2, h)
  533. h = h.delete(A)
  534. self.assertEqual(len(h), orig_len - 3)
  535. self.assertEqual(h.get(A, 42), 42)
  536. self.assertEqual(h.get(B), 'b')
  537. self.assertEqual(h.get(E), 'e')
  538. def test_hamt_delete_2(self):
  539. A = HashKey(100, 'A')
  540. B = HashKey(201001, 'B')
  541. C = HashKey(101001, 'C')
  542. D = HashKey(103, 'D')
  543. E = HashKey(104, 'E')
  544. Z = HashKey(-100, 'Z')
  545. Er = HashKey(201001, 'Er', error_on_eq_to=B)
  546. h = hamt()
  547. h = h.set(A, 'a')
  548. h = h.set(B, 'b')
  549. h = h.set(C, 'c')
  550. h = h.set(D, 'd')
  551. h = h.set(E, 'e')
  552. orig_len = len(h)
  553. # BitmapNode(size=8 bitmap=0b1110010000):
  554. # <Key name:A hash:100>: 'a'
  555. # <Key name:D hash:103>: 'd'
  556. # <Key name:E hash:104>: 'e'
  557. # NULL:
  558. # BitmapNode(size=4 bitmap=0b100000000001000000000):
  559. # <Key name:B hash:201001>: 'b'
  560. # <Key name:C hash:101001>: 'c'
  561. with self.assertRaisesRegex(ValueError, 'cannot compare'):
  562. h.delete(Er)
  563. h = h.delete(Z)
  564. self.assertEqual(len(h), orig_len)
  565. h = h.delete(C)
  566. self.assertEqual(len(h), orig_len - 1)
  567. h = h.delete(B)
  568. self.assertEqual(len(h), orig_len - 2)
  569. h = h.delete(A)
  570. self.assertEqual(len(h), orig_len - 3)
  571. self.assertEqual(h.get(D), 'd')
  572. self.assertEqual(h.get(E), 'e')
  573. h = h.delete(A)
  574. h = h.delete(B)
  575. h = h.delete(D)
  576. h = h.delete(E)
  577. self.assertEqual(len(h), 0)
  578. def test_hamt_delete_3(self):
  579. A = HashKey(100, 'A')
  580. B = HashKey(101, 'B')
  581. C = HashKey(100100, 'C')
  582. D = HashKey(100100, 'D')
  583. E = HashKey(104, 'E')
  584. h = hamt()
  585. h = h.set(A, 'a')
  586. h = h.set(B, 'b')
  587. h = h.set(C, 'c')
  588. h = h.set(D, 'd')
  589. h = h.set(E, 'e')
  590. orig_len = len(h)
  591. # BitmapNode(size=6 bitmap=0b100110000):
  592. # NULL:
  593. # BitmapNode(size=4 bitmap=0b1000000000000000000001000):
  594. # <Key name:A hash:100>: 'a'
  595. # NULL:
  596. # CollisionNode(size=4 id=0x108572410):
  597. # <Key name:C hash:100100>: 'c'
  598. # <Key name:D hash:100100>: 'd'
  599. # <Key name:B hash:101>: 'b'
  600. # <Key name:E hash:104>: 'e'
  601. h = h.delete(A)
  602. self.assertEqual(len(h), orig_len - 1)
  603. h = h.delete(E)
  604. self.assertEqual(len(h), orig_len - 2)
  605. self.assertEqual(h.get(C), 'c')
  606. self.assertEqual(h.get(B), 'b')
  607. def test_hamt_delete_4(self):
  608. A = HashKey(100, 'A')
  609. B = HashKey(101, 'B')
  610. C = HashKey(100100, 'C')
  611. D = HashKey(100100, 'D')
  612. E = HashKey(100100, 'E')
  613. h = hamt()
  614. h = h.set(A, 'a')
  615. h = h.set(B, 'b')
  616. h = h.set(C, 'c')
  617. h = h.set(D, 'd')
  618. h = h.set(E, 'e')
  619. orig_len = len(h)
  620. # BitmapNode(size=4 bitmap=0b110000):
  621. # NULL:
  622. # BitmapNode(size=4 bitmap=0b1000000000000000000001000):
  623. # <Key name:A hash:100>: 'a'
  624. # NULL:
  625. # CollisionNode(size=6 id=0x10515ef30):
  626. # <Key name:C hash:100100>: 'c'
  627. # <Key name:D hash:100100>: 'd'
  628. # <Key name:E hash:100100>: 'e'
  629. # <Key name:B hash:101>: 'b'
  630. h = h.delete(D)
  631. self.assertEqual(len(h), orig_len - 1)
  632. h = h.delete(E)
  633. self.assertEqual(len(h), orig_len - 2)
  634. h = h.delete(C)
  635. self.assertEqual(len(h), orig_len - 3)
  636. h = h.delete(A)
  637. self.assertEqual(len(h), orig_len - 4)
  638. h = h.delete(B)
  639. self.assertEqual(len(h), 0)
  640. def test_hamt_delete_5(self):
  641. h = hamt()
  642. keys = []
  643. for i in range(17):
  644. key = HashKey(i, str(i))
  645. keys.append(key)
  646. h = h.set(key, f'val-{i}')
  647. collision_key16 = HashKey(16, '18')
  648. h = h.set(collision_key16, 'collision')
  649. # ArrayNode(id=0x10f8b9318):
  650. # 0::
  651. # BitmapNode(size=2 count=1 bitmap=0b1):
  652. # <Key name:0 hash:0>: 'val-0'
  653. #
  654. # ... 14 more BitmapNodes ...
  655. #
  656. # 15::
  657. # BitmapNode(size=2 count=1 bitmap=0b1):
  658. # <Key name:15 hash:15>: 'val-15'
  659. #
  660. # 16::
  661. # BitmapNode(size=2 count=1 bitmap=0b1):
  662. # NULL:
  663. # CollisionNode(size=4 id=0x10f2f5af8):
  664. # <Key name:16 hash:16>: 'val-16'
  665. # <Key name:18 hash:16>: 'collision'
  666. self.assertEqual(len(h), 18)
  667. h = h.delete(keys[2])
  668. self.assertEqual(len(h), 17)
  669. h = h.delete(collision_key16)
  670. self.assertEqual(len(h), 16)
  671. h = h.delete(keys[16])
  672. self.assertEqual(len(h), 15)
  673. h = h.delete(keys[1])
  674. self.assertEqual(len(h), 14)
  675. h = h.delete(keys[1])
  676. self.assertEqual(len(h), 14)
  677. for key in keys:
  678. h = h.delete(key)
  679. self.assertEqual(len(h), 0)
  680. def test_hamt_items_1(self):
  681. A = HashKey(100, 'A')
  682. B = HashKey(201001, 'B')
  683. C = HashKey(101001, 'C')
  684. D = HashKey(103, 'D')
  685. E = HashKey(104, 'E')
  686. F = HashKey(110, 'F')
  687. h = hamt()
  688. h = h.set(A, 'a')
  689. h = h.set(B, 'b')
  690. h = h.set(C, 'c')
  691. h = h.set(D, 'd')
  692. h = h.set(E, 'e')
  693. h = h.set(F, 'f')
  694. it = h.items()
  695. self.assertEqual(
  696. set(list(it)),
  697. {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})
  698. def test_hamt_items_2(self):
  699. A = HashKey(100, 'A')
  700. B = HashKey(101, 'B')
  701. C = HashKey(100100, 'C')
  702. D = HashKey(100100, 'D')
  703. E = HashKey(100100, 'E')
  704. F = HashKey(110, 'F')
  705. h = hamt()
  706. h = h.set(A, 'a')
  707. h = h.set(B, 'b')
  708. h = h.set(C, 'c')
  709. h = h.set(D, 'd')
  710. h = h.set(E, 'e')
  711. h = h.set(F, 'f')
  712. it = h.items()
  713. self.assertEqual(
  714. set(list(it)),
  715. {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})
  716. def test_hamt_keys_1(self):
  717. A = HashKey(100, 'A')
  718. B = HashKey(101, 'B')
  719. C = HashKey(100100, 'C')
  720. D = HashKey(100100, 'D')
  721. E = HashKey(100100, 'E')
  722. F = HashKey(110, 'F')
  723. h = hamt()
  724. h = h.set(A, 'a')
  725. h = h.set(B, 'b')
  726. h = h.set(C, 'c')
  727. h = h.set(D, 'd')
  728. h = h.set(E, 'e')
  729. h = h.set(F, 'f')
  730. self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
  731. self.assertEqual(set(list(h)), {A, B, C, D, E, F})
  732. def test_hamt_items_3(self):
  733. h = hamt()
  734. self.assertEqual(len(h.items()), 0)
  735. self.assertEqual(list(h.items()), [])
  736. def test_hamt_eq_1(self):
  737. A = HashKey(100, 'A')
  738. B = HashKey(101, 'B')
  739. C = HashKey(100100, 'C')
  740. D = HashKey(100100, 'D')
  741. E = HashKey(120, 'E')
  742. h1 = hamt()
  743. h1 = h1.set(A, 'a')
  744. h1 = h1.set(B, 'b')
  745. h1 = h1.set(C, 'c')
  746. h1 = h1.set(D, 'd')
  747. h2 = hamt()
  748. h2 = h2.set(A, 'a')
  749. self.assertFalse(h1 == h2)
  750. self.assertTrue(h1 != h2)
  751. h2 = h2.set(B, 'b')
  752. self.assertFalse(h1 == h2)
  753. self.assertTrue(h1 != h2)
  754. h2 = h2.set(C, 'c')
  755. self.assertFalse(h1 == h2)
  756. self.assertTrue(h1 != h2)
  757. h2 = h2.set(D, 'd2')
  758. self.assertFalse(h1 == h2)
  759. self.assertTrue(h1 != h2)
  760. h2 = h2.set(D, 'd')
  761. self.assertTrue(h1 == h2)
  762. self.assertFalse(h1 != h2)
  763. h2 = h2.set(E, 'e')
  764. self.assertFalse(h1 == h2)
  765. self.assertTrue(h1 != h2)
  766. h2 = h2.delete(D)
  767. self.assertFalse(h1 == h2)
  768. self.assertTrue(h1 != h2)
  769. h2 = h2.set(E, 'd')
  770. self.assertFalse(h1 == h2)
  771. self.assertTrue(h1 != h2)
  772. def test_hamt_eq_2(self):
  773. A = HashKey(100, 'A')
  774. Er = HashKey(100, 'Er', error_on_eq_to=A)
  775. h1 = hamt()
  776. h1 = h1.set(A, 'a')
  777. h2 = hamt()
  778. h2 = h2.set(Er, 'a')
  779. with self.assertRaisesRegex(ValueError, 'cannot compare'):
  780. h1 == h2
  781. with self.assertRaisesRegex(ValueError, 'cannot compare'):
  782. h1 != h2
  783. def test_hamt_gc_1(self):
  784. A = HashKey(100, 'A')
  785. h = hamt()
  786. h = h.set(0, 0) # empty HAMT node is memoized in hamt.c
  787. ref = weakref.ref(h)
  788. a = []
  789. a.append(a)
  790. a.append(h)
  791. b = []
  792. a.append(b)
  793. b.append(a)
  794. h = h.set(A, b)
  795. del h, a, b
  796. gc.collect()
  797. gc.collect()
  798. gc.collect()
  799. self.assertIsNone(ref())
  800. def test_hamt_gc_2(self):
  801. A = HashKey(100, 'A')
  802. B = HashKey(101, 'B')
  803. h = hamt()
  804. h = h.set(A, 'a')
  805. h = h.set(A, h)
  806. ref = weakref.ref(h)
  807. hi = h.items()
  808. next(hi)
  809. del h, hi
  810. gc.collect()
  811. gc.collect()
  812. gc.collect()
  813. self.assertIsNone(ref())
  814. def test_hamt_in_1(self):
  815. A = HashKey(100, 'A')
  816. AA = HashKey(100, 'A')
  817. B = HashKey(101, 'B')
  818. h = hamt()
  819. h = h.set(A, 1)
  820. self.assertTrue(A in h)
  821. self.assertFalse(B in h)
  822. with self.assertRaises(EqError):
  823. with HaskKeyCrasher(error_on_eq=True):
  824. AA in h
  825. with self.assertRaises(HashingError):
  826. with HaskKeyCrasher(error_on_hash=True):
  827. AA in h
  828. def test_hamt_getitem_1(self):
  829. A = HashKey(100, 'A')
  830. AA = HashKey(100, 'A')
  831. B = HashKey(101, 'B')
  832. h = hamt()
  833. h = h.set(A, 1)
  834. self.assertEqual(h[A], 1)
  835. self.assertEqual(h[AA], 1)
  836. with self.assertRaises(KeyError):
  837. h[B]
  838. with self.assertRaises(EqError):
  839. with HaskKeyCrasher(error_on_eq=True):
  840. h[AA]
  841. with self.assertRaises(HashingError):
  842. with HaskKeyCrasher(error_on_hash=True):
  843. h[AA]
  844. if __name__ == "__main__":
  845. unittest.main()