# test_weakset.py: unit tests for weakref.WeakSet
  1. import unittest
  2. from weakref import WeakSet
  3. import copy
  4. import string
  5. from collections import UserString as ustr
  6. from collections.abc import Set, MutableSet
  7. import gc
  8. import contextlib
  9. from test import support
  10. class Foo:
  11. pass
  12. class RefCycle:
  13. def __init__(self):
  14. self.cycle = self
  15. class WeakSetSubclass(WeakSet):
  16. pass
  17. class WeakSetWithSlots(WeakSet):
  18. __slots__ = ('x', 'y')
  19. class TestWeakSet(unittest.TestCase):
  20. def setUp(self):
  21. # need to keep references to them
  22. self.items = [ustr(c) for c in ('a', 'b', 'c')]
  23. self.items2 = [ustr(c) for c in ('x', 'y', 'z')]
  24. self.ab_items = [ustr(c) for c in 'ab']
  25. self.abcde_items = [ustr(c) for c in 'abcde']
  26. self.def_items = [ustr(c) for c in 'def']
  27. self.ab_weakset = WeakSet(self.ab_items)
  28. self.abcde_weakset = WeakSet(self.abcde_items)
  29. self.def_weakset = WeakSet(self.def_items)
  30. self.letters = [ustr(c) for c in string.ascii_letters]
  31. self.s = WeakSet(self.items)
  32. self.d = dict.fromkeys(self.items)
  33. self.obj = ustr('F')
  34. self.fs = WeakSet([self.obj])
  35. def test_methods(self):
  36. weaksetmethods = dir(WeakSet)
  37. for method in dir(set):
  38. if method == 'test_c_api' or method.startswith('_'):
  39. continue
  40. self.assertIn(method, weaksetmethods,
  41. "WeakSet missing method " + method)
  42. def test_new_or_init(self):
  43. self.assertRaises(TypeError, WeakSet, [], 2)
  44. def test_len(self):
  45. self.assertEqual(len(self.s), len(self.d))
  46. self.assertEqual(len(self.fs), 1)
  47. del self.obj
  48. support.gc_collect() # For PyPy or other GCs.
  49. self.assertEqual(len(self.fs), 0)
  50. def test_contains(self):
  51. for c in self.letters:
  52. self.assertEqual(c in self.s, c in self.d)
  53. # 1 is not weakref'able, but that TypeError is caught by __contains__
  54. self.assertNotIn(1, self.s)
  55. self.assertIn(self.obj, self.fs)
  56. del self.obj
  57. support.gc_collect() # For PyPy or other GCs.
  58. self.assertNotIn(ustr('F'), self.fs)
  59. def test_union(self):
  60. u = self.s.union(self.items2)
  61. for c in self.letters:
  62. self.assertEqual(c in u, c in self.d or c in self.items2)
  63. self.assertEqual(self.s, WeakSet(self.items))
  64. self.assertEqual(type(u), WeakSet)
  65. self.assertRaises(TypeError, self.s.union, [[]])
  66. for C in set, frozenset, dict.fromkeys, list, tuple:
  67. x = WeakSet(self.items + self.items2)
  68. c = C(self.items2)
  69. self.assertEqual(self.s.union(c), x)
  70. del c
  71. self.assertEqual(len(u), len(self.items) + len(self.items2))
  72. self.items2.pop()
  73. gc.collect()
  74. self.assertEqual(len(u), len(self.items) + len(self.items2))
  75. def test_or(self):
  76. i = self.s.union(self.items2)
  77. self.assertEqual(self.s | set(self.items2), i)
  78. self.assertEqual(self.s | frozenset(self.items2), i)
  79. def test_intersection(self):
  80. s = WeakSet(self.letters)
  81. i = s.intersection(self.items2)
  82. for c in self.letters:
  83. self.assertEqual(c in i, c in self.items2 and c in self.letters)
  84. self.assertEqual(s, WeakSet(self.letters))
  85. self.assertEqual(type(i), WeakSet)
  86. for C in set, frozenset, dict.fromkeys, list, tuple:
  87. x = WeakSet([])
  88. self.assertEqual(i.intersection(C(self.items)), x)
  89. self.assertEqual(len(i), len(self.items2))
  90. self.items2.pop()
  91. gc.collect()
  92. self.assertEqual(len(i), len(self.items2))
  93. def test_isdisjoint(self):
  94. self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
  95. self.assertTrue(not self.s.isdisjoint(WeakSet(self.letters)))
  96. def test_and(self):
  97. i = self.s.intersection(self.items2)
  98. self.assertEqual(self.s & set(self.items2), i)
  99. self.assertEqual(self.s & frozenset(self.items2), i)
  100. def test_difference(self):
  101. i = self.s.difference(self.items2)
  102. for c in self.letters:
  103. self.assertEqual(c in i, c in self.d and c not in self.items2)
  104. self.assertEqual(self.s, WeakSet(self.items))
  105. self.assertEqual(type(i), WeakSet)
  106. self.assertRaises(TypeError, self.s.difference, [[]])
  107. def test_sub(self):
  108. i = self.s.difference(self.items2)
  109. self.assertEqual(self.s - set(self.items2), i)
  110. self.assertEqual(self.s - frozenset(self.items2), i)
  111. def test_symmetric_difference(self):
  112. i = self.s.symmetric_difference(self.items2)
  113. for c in self.letters:
  114. self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
  115. self.assertEqual(self.s, WeakSet(self.items))
  116. self.assertEqual(type(i), WeakSet)
  117. self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
  118. self.assertEqual(len(i), len(self.items) + len(self.items2))
  119. self.items2.pop()
  120. gc.collect()
  121. self.assertEqual(len(i), len(self.items) + len(self.items2))
  122. def test_xor(self):
  123. i = self.s.symmetric_difference(self.items2)
  124. self.assertEqual(self.s ^ set(self.items2), i)
  125. self.assertEqual(self.s ^ frozenset(self.items2), i)
  126. def test_sub_and_super(self):
  127. self.assertTrue(self.ab_weakset <= self.abcde_weakset)
  128. self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
  129. self.assertTrue(self.abcde_weakset >= self.ab_weakset)
  130. self.assertFalse(self.abcde_weakset <= self.def_weakset)
  131. self.assertFalse(self.abcde_weakset >= self.def_weakset)
  132. self.assertTrue(set('a').issubset('abc'))
  133. self.assertTrue(set('abc').issuperset('a'))
  134. self.assertFalse(set('a').issubset('cbs'))
  135. self.assertFalse(set('cbs').issuperset('a'))
  136. def test_lt(self):
  137. self.assertTrue(self.ab_weakset < self.abcde_weakset)
  138. self.assertFalse(self.abcde_weakset < self.def_weakset)
  139. self.assertFalse(self.ab_weakset < self.ab_weakset)
  140. self.assertFalse(WeakSet() < WeakSet())
  141. def test_gt(self):
  142. self.assertTrue(self.abcde_weakset > self.ab_weakset)
  143. self.assertFalse(self.abcde_weakset > self.def_weakset)
  144. self.assertFalse(self.ab_weakset > self.ab_weakset)
  145. self.assertFalse(WeakSet() > WeakSet())
  146. def test_gc(self):
  147. # Create a nest of cycles to exercise overall ref count check
  148. s = WeakSet(Foo() for i in range(1000))
  149. for elem in s:
  150. elem.cycle = s
  151. elem.sub = elem
  152. elem.set = WeakSet([elem])
  153. def test_subclass_with_custom_hash(self):
  154. # Bug #1257731
  155. class H(WeakSet):
  156. def __hash__(self):
  157. return int(id(self) & 0x7fffffff)
  158. s=H()
  159. f=set()
  160. f.add(s)
  161. self.assertIn(s, f)
  162. f.remove(s)
  163. f.add(s)
  164. f.discard(s)
  165. def test_init(self):
  166. s = WeakSet()
  167. s.__init__(self.items)
  168. self.assertEqual(s, self.s)
  169. s.__init__(self.items2)
  170. self.assertEqual(s, WeakSet(self.items2))
  171. self.assertRaises(TypeError, s.__init__, s, 2);
  172. self.assertRaises(TypeError, s.__init__, 1);
  173. def test_constructor_identity(self):
  174. s = WeakSet(self.items)
  175. t = WeakSet(s)
  176. self.assertNotEqual(id(s), id(t))
  177. def test_hash(self):
  178. self.assertRaises(TypeError, hash, self.s)
  179. def test_clear(self):
  180. self.s.clear()
  181. self.assertEqual(self.s, WeakSet([]))
  182. self.assertEqual(len(self.s), 0)
  183. def test_copy(self):
  184. dup = self.s.copy()
  185. self.assertEqual(self.s, dup)
  186. self.assertNotEqual(id(self.s), id(dup))
  187. def test_add(self):
  188. x = ustr('Q')
  189. self.s.add(x)
  190. self.assertIn(x, self.s)
  191. dup = self.s.copy()
  192. self.s.add(x)
  193. self.assertEqual(self.s, dup)
  194. self.assertRaises(TypeError, self.s.add, [])
  195. self.fs.add(Foo())
  196. support.gc_collect() # For PyPy or other GCs.
  197. self.assertTrue(len(self.fs) == 1)
  198. self.fs.add(self.obj)
  199. self.assertTrue(len(self.fs) == 1)
  200. def test_remove(self):
  201. x = ustr('a')
  202. self.s.remove(x)
  203. self.assertNotIn(x, self.s)
  204. self.assertRaises(KeyError, self.s.remove, x)
  205. self.assertRaises(TypeError, self.s.remove, [])
  206. def test_discard(self):
  207. a, q = ustr('a'), ustr('Q')
  208. self.s.discard(a)
  209. self.assertNotIn(a, self.s)
  210. self.s.discard(q)
  211. self.assertRaises(TypeError, self.s.discard, [])
  212. def test_pop(self):
  213. for i in range(len(self.s)):
  214. elem = self.s.pop()
  215. self.assertNotIn(elem, self.s)
  216. self.assertRaises(KeyError, self.s.pop)
  217. def test_update(self):
  218. retval = self.s.update(self.items2)
  219. self.assertEqual(retval, None)
  220. for c in (self.items + self.items2):
  221. self.assertIn(c, self.s)
  222. self.assertRaises(TypeError, self.s.update, [[]])
  223. def test_update_set(self):
  224. self.s.update(set(self.items2))
  225. for c in (self.items + self.items2):
  226. self.assertIn(c, self.s)
  227. def test_ior(self):
  228. self.s |= set(self.items2)
  229. for c in (self.items + self.items2):
  230. self.assertIn(c, self.s)
  231. def test_intersection_update(self):
  232. retval = self.s.intersection_update(self.items2)
  233. self.assertEqual(retval, None)
  234. for c in (self.items + self.items2):
  235. if c in self.items2 and c in self.items:
  236. self.assertIn(c, self.s)
  237. else:
  238. self.assertNotIn(c, self.s)
  239. self.assertRaises(TypeError, self.s.intersection_update, [[]])
  240. def test_iand(self):
  241. self.s &= set(self.items2)
  242. for c in (self.items + self.items2):
  243. if c in self.items2 and c in self.items:
  244. self.assertIn(c, self.s)
  245. else:
  246. self.assertNotIn(c, self.s)
  247. def test_difference_update(self):
  248. retval = self.s.difference_update(self.items2)
  249. self.assertEqual(retval, None)
  250. for c in (self.items + self.items2):
  251. if c in self.items and c not in self.items2:
  252. self.assertIn(c, self.s)
  253. else:
  254. self.assertNotIn(c, self.s)
  255. self.assertRaises(TypeError, self.s.difference_update, [[]])
  256. self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
  257. def test_isub(self):
  258. self.s -= set(self.items2)
  259. for c in (self.items + self.items2):
  260. if c in self.items and c not in self.items2:
  261. self.assertIn(c, self.s)
  262. else:
  263. self.assertNotIn(c, self.s)
  264. def test_symmetric_difference_update(self):
  265. retval = self.s.symmetric_difference_update(self.items2)
  266. self.assertEqual(retval, None)
  267. for c in (self.items + self.items2):
  268. if (c in self.items) ^ (c in self.items2):
  269. self.assertIn(c, self.s)
  270. else:
  271. self.assertNotIn(c, self.s)
  272. self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
  273. def test_ixor(self):
  274. self.s ^= set(self.items2)
  275. for c in (self.items + self.items2):
  276. if (c in self.items) ^ (c in self.items2):
  277. self.assertIn(c, self.s)
  278. else:
  279. self.assertNotIn(c, self.s)
  280. def test_inplace_on_self(self):
  281. t = self.s.copy()
  282. t |= t
  283. self.assertEqual(t, self.s)
  284. t &= t
  285. self.assertEqual(t, self.s)
  286. t -= t
  287. self.assertEqual(t, WeakSet())
  288. t = self.s.copy()
  289. t ^= t
  290. self.assertEqual(t, WeakSet())
  291. def test_eq(self):
  292. # issue 5964
  293. self.assertTrue(self.s == self.s)
  294. self.assertTrue(self.s == WeakSet(self.items))
  295. self.assertFalse(self.s == set(self.items))
  296. self.assertFalse(self.s == list(self.items))
  297. self.assertFalse(self.s == tuple(self.items))
  298. self.assertFalse(self.s == WeakSet([Foo]))
  299. self.assertFalse(self.s == 1)
  300. def test_ne(self):
  301. self.assertTrue(self.s != set(self.items))
  302. s1 = WeakSet()
  303. s2 = WeakSet()
  304. self.assertFalse(s1 != s2)
  305. def test_weak_destroy_while_iterating(self):
  306. # Issue #7105: iterators shouldn't crash when a key is implicitly removed
  307. # Create new items to be sure no-one else holds a reference
  308. items = [ustr(c) for c in ('a', 'b', 'c')]
  309. s = WeakSet(items)
  310. it = iter(s)
  311. next(it) # Trigger internal iteration
  312. # Destroy an item
  313. del items[-1]
  314. gc.collect() # just in case
  315. # We have removed either the first consumed items, or another one
  316. self.assertIn(len(list(it)), [len(items), len(items) - 1])
  317. del it
  318. # The removal has been committed
  319. self.assertEqual(len(s), len(items))
  320. def test_weak_destroy_and_mutate_while_iterating(self):
  321. # Issue #7105: iterators shouldn't crash when a key is implicitly removed
  322. items = [ustr(c) for c in string.ascii_letters]
  323. s = WeakSet(items)
  324. @contextlib.contextmanager
  325. def testcontext():
  326. try:
  327. it = iter(s)
  328. # Start iterator
  329. yielded = ustr(str(next(it)))
  330. # Schedule an item for removal and recreate it
  331. u = ustr(str(items.pop()))
  332. if yielded == u:
  333. # The iterator still has a reference to the removed item,
  334. # advance it (issue #20006).
  335. next(it)
  336. gc.collect() # just in case
  337. yield u
  338. finally:
  339. it = None # should commit all removals
  340. with testcontext() as u:
  341. self.assertNotIn(u, s)
  342. with testcontext() as u:
  343. self.assertRaises(KeyError, s.remove, u)
  344. self.assertNotIn(u, s)
  345. with testcontext() as u:
  346. s.add(u)
  347. self.assertIn(u, s)
  348. t = s.copy()
  349. with testcontext() as u:
  350. s.update(t)
  351. self.assertEqual(len(s), len(t))
  352. with testcontext() as u:
  353. s.clear()
  354. self.assertEqual(len(s), 0)
  355. def test_len_cycles(self):
  356. N = 20
  357. items = [RefCycle() for i in range(N)]
  358. s = WeakSet(items)
  359. del items
  360. it = iter(s)
  361. try:
  362. next(it)
  363. except StopIteration:
  364. pass
  365. gc.collect()
  366. n1 = len(s)
  367. del it
  368. gc.collect()
  369. gc.collect() # For PyPy or other GCs.
  370. n2 = len(s)
  371. # one item may be kept alive inside the iterator
  372. self.assertIn(n1, (0, 1))
  373. self.assertEqual(n2, 0)
  374. def test_len_race(self):
  375. # Extended sanity checks for len() in the face of cyclic collection
  376. self.addCleanup(gc.set_threshold, *gc.get_threshold())
  377. for th in range(1, 100):
  378. N = 20
  379. gc.collect(0)
  380. gc.set_threshold(th, th, th)
  381. items = [RefCycle() for i in range(N)]
  382. s = WeakSet(items)
  383. del items
  384. # All items will be collected at next garbage collection pass
  385. it = iter(s)
  386. try:
  387. next(it)
  388. except StopIteration:
  389. pass
  390. n1 = len(s)
  391. del it
  392. n2 = len(s)
  393. self.assertGreaterEqual(n1, 0)
  394. self.assertLessEqual(n1, N)
  395. self.assertGreaterEqual(n2, 0)
  396. self.assertLessEqual(n2, n1)
  397. def test_repr(self):
  398. assert repr(self.s) == repr(self.s.data)
  399. def test_abc(self):
  400. self.assertIsInstance(self.s, Set)
  401. self.assertIsInstance(self.s, MutableSet)
  402. def test_copying(self):
  403. for cls in WeakSet, WeakSetWithSlots:
  404. s = cls(self.items)
  405. s.x = ['x']
  406. s.z = ['z']
  407. dup = copy.copy(s)
  408. self.assertIsInstance(dup, cls)
  409. self.assertEqual(dup, s)
  410. self.assertIsNot(dup, s)
  411. self.assertIs(dup.x, s.x)
  412. self.assertIs(dup.z, s.z)
  413. self.assertFalse(hasattr(dup, 'y'))
  414. dup = copy.deepcopy(s)
  415. self.assertIsInstance(dup, cls)
  416. self.assertEqual(dup, s)
  417. self.assertIsNot(dup, s)
  418. self.assertEqual(dup.x, s.x)
  419. self.assertIsNot(dup.x, s.x)
  420. self.assertEqual(dup.z, s.z)
  421. self.assertIsNot(dup.z, s.z)
  422. self.assertFalse(hasattr(dup, 'y'))
  423. if __name__ == "__main__":
  424. unittest.main()