test_heapq.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. """Unittests for heapq."""
  2. import random
  3. import unittest
  4. import doctest
  5. from test import support
  6. from test.support import import_helper
  7. from unittest import TestCase, skipUnless
  8. from operator import itemgetter
  9. py_heapq = import_helper.import_fresh_module('heapq', blocked=['_heapq'])
  10. c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
  11. # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
  12. # _heapq is imported, so check them there
  13. func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
  14. '_heappop_max', '_heapreplace_max', '_heapify_max']
  15. class TestModules(TestCase):
  16. def test_py_functions(self):
  17. for fname in func_names:
  18. self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
  19. @skipUnless(c_heapq, 'requires _heapq')
  20. def test_c_functions(self):
  21. for fname in func_names:
  22. self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
  23. def load_tests(loader, tests, ignore):
  24. # The 'merge' function has examples in its docstring which we should test
  25. # with 'doctest'.
  26. #
  27. # However, doctest can't easily find all docstrings in the module (loading
  28. # it through import_fresh_module seems to confuse it), so we specifically
  29. # create a finder which returns the doctests from the merge method.
  30. class HeapqMergeDocTestFinder:
  31. def find(self, *args, **kwargs):
  32. dtf = doctest.DocTestFinder()
  33. return dtf.find(py_heapq.merge)
  34. tests.addTests(doctest.DocTestSuite(py_heapq,
  35. test_finder=HeapqMergeDocTestFinder()))
  36. return tests
  37. class TestHeap:
  38. def test_push_pop(self):
  39. # 1) Push 256 random numbers and pop them off, verifying all's OK.
  40. heap = []
  41. data = []
  42. self.check_invariant(heap)
  43. for i in range(256):
  44. item = random.random()
  45. data.append(item)
  46. self.module.heappush(heap, item)
  47. self.check_invariant(heap)
  48. results = []
  49. while heap:
  50. item = self.module.heappop(heap)
  51. self.check_invariant(heap)
  52. results.append(item)
  53. data_sorted = data[:]
  54. data_sorted.sort()
  55. self.assertEqual(data_sorted, results)
  56. # 2) Check that the invariant holds for a sorted array
  57. self.check_invariant(results)
  58. self.assertRaises(TypeError, self.module.heappush, [])
  59. try:
  60. self.assertRaises(TypeError, self.module.heappush, None, None)
  61. self.assertRaises(TypeError, self.module.heappop, None)
  62. except AttributeError:
  63. pass
  64. def check_invariant(self, heap):
  65. # Check the heap invariant.
  66. for pos, item in enumerate(heap):
  67. if pos: # pos 0 has no parent
  68. parentpos = (pos-1) >> 1
  69. self.assertTrue(heap[parentpos] <= item)
  70. def test_heapify(self):
  71. for size in list(range(30)) + [20000]:
  72. heap = [random.random() for dummy in range(size)]
  73. self.module.heapify(heap)
  74. self.check_invariant(heap)
  75. self.assertRaises(TypeError, self.module.heapify, None)
  76. def test_naive_nbest(self):
  77. data = [random.randrange(2000) for i in range(1000)]
  78. heap = []
  79. for item in data:
  80. self.module.heappush(heap, item)
  81. if len(heap) > 10:
  82. self.module.heappop(heap)
  83. heap.sort()
  84. self.assertEqual(heap, sorted(data)[-10:])
  85. def heapiter(self, heap):
  86. # An iterator returning a heap's elements, smallest-first.
  87. try:
  88. while 1:
  89. yield self.module.heappop(heap)
  90. except IndexError:
  91. pass
  92. def test_nbest(self):
  93. # Less-naive "N-best" algorithm, much faster (if len(data) is big
  94. # enough <wink>) than sorting all of data. However, if we had a max
  95. # heap instead of a min heap, it could go faster still via
  96. # heapify'ing all of data (linear time), then doing 10 heappops
  97. # (10 log-time steps).
  98. data = [random.randrange(2000) for i in range(1000)]
  99. heap = data[:10]
  100. self.module.heapify(heap)
  101. for item in data[10:]:
  102. if item > heap[0]: # this gets rarer the longer we run
  103. self.module.heapreplace(heap, item)
  104. self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
  105. self.assertRaises(TypeError, self.module.heapreplace, None)
  106. self.assertRaises(TypeError, self.module.heapreplace, None, None)
  107. self.assertRaises(IndexError, self.module.heapreplace, [], None)
  108. def test_nbest_with_pushpop(self):
  109. data = [random.randrange(2000) for i in range(1000)]
  110. heap = data[:10]
  111. self.module.heapify(heap)
  112. for item in data[10:]:
  113. self.module.heappushpop(heap, item)
  114. self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
  115. self.assertEqual(self.module.heappushpop([], 'x'), 'x')
  116. def test_heappushpop(self):
  117. h = []
  118. x = self.module.heappushpop(h, 10)
  119. self.assertEqual((h, x), ([], 10))
  120. h = [10]
  121. x = self.module.heappushpop(h, 10.0)
  122. self.assertEqual((h, x), ([10], 10.0))
  123. self.assertEqual(type(h[0]), int)
  124. self.assertEqual(type(x), float)
  125. h = [10]
  126. x = self.module.heappushpop(h, 9)
  127. self.assertEqual((h, x), ([10], 9))
  128. h = [10]
  129. x = self.module.heappushpop(h, 11)
  130. self.assertEqual((h, x), ([11], 10))
  131. def test_heappop_max(self):
  132. # _heapop_max has an optimization for one-item lists which isn't
  133. # covered in other tests, so test that case explicitly here
  134. h = [3, 2]
  135. self.assertEqual(self.module._heappop_max(h), 3)
  136. self.assertEqual(self.module._heappop_max(h), 2)
  137. def test_heapsort(self):
  138. # Exercise everything with repeated heapsort checks
  139. for trial in range(100):
  140. size = random.randrange(50)
  141. data = [random.randrange(25) for i in range(size)]
  142. if trial & 1: # Half of the time, use heapify
  143. heap = data[:]
  144. self.module.heapify(heap)
  145. else: # The rest of the time, use heappush
  146. heap = []
  147. for item in data:
  148. self.module.heappush(heap, item)
  149. heap_sorted = [self.module.heappop(heap) for i in range(size)]
  150. self.assertEqual(heap_sorted, sorted(data))
  151. def test_merge(self):
  152. inputs = []
  153. for i in range(random.randrange(25)):
  154. row = []
  155. for j in range(random.randrange(100)):
  156. tup = random.choice('ABC'), random.randrange(-500, 500)
  157. row.append(tup)
  158. inputs.append(row)
  159. for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
  160. for reverse in [False, True]:
  161. seqs = []
  162. for seq in inputs:
  163. seqs.append(sorted(seq, key=key, reverse=reverse))
  164. self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse),
  165. list(self.module.merge(*seqs, key=key, reverse=reverse)))
  166. self.assertEqual(list(self.module.merge()), [])
  167. def test_empty_merges(self):
  168. # Merging two empty lists (with or without a key) should produce
  169. # another empty list.
  170. self.assertEqual(list(self.module.merge([], [])), [])
  171. self.assertEqual(list(self.module.merge([], [], key=lambda: 6)), [])
  172. def test_merge_does_not_suppress_index_error(self):
  173. # Issue 19018: Heapq.merge suppresses IndexError from user generator
  174. def iterable():
  175. s = list(range(10))
  176. for i in range(20):
  177. yield s[i] # IndexError when i > 10
  178. with self.assertRaises(IndexError):
  179. list(self.module.merge(iterable(), iterable()))
  180. def test_merge_stability(self):
  181. class Int(int):
  182. pass
  183. inputs = [[], [], [], []]
  184. for i in range(20000):
  185. stream = random.randrange(4)
  186. x = random.randrange(500)
  187. obj = Int(x)
  188. obj.pair = (x, stream)
  189. inputs[stream].append(obj)
  190. for stream in inputs:
  191. stream.sort()
  192. result = [i.pair for i in self.module.merge(*inputs)]
  193. self.assertEqual(result, sorted(result))
  194. def test_nsmallest(self):
  195. data = [(random.randrange(2000), i) for i in range(1000)]
  196. for f in (None, lambda x: x[0] * 547 % 2000):
  197. for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  198. self.assertEqual(list(self.module.nsmallest(n, data)),
  199. sorted(data)[:n])
  200. self.assertEqual(list(self.module.nsmallest(n, data, key=f)),
  201. sorted(data, key=f)[:n])
  202. def test_nlargest(self):
  203. data = [(random.randrange(2000), i) for i in range(1000)]
  204. for f in (None, lambda x: x[0] * 547 % 2000):
  205. for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  206. self.assertEqual(list(self.module.nlargest(n, data)),
  207. sorted(data, reverse=True)[:n])
  208. self.assertEqual(list(self.module.nlargest(n, data, key=f)),
  209. sorted(data, key=f, reverse=True)[:n])
  210. def test_comparison_operator(self):
  211. # Issue 3051: Make sure heapq works with both __lt__
  212. # For python 3.0, __le__ alone is not enough
  213. def hsort(data, comp):
  214. data = [comp(x) for x in data]
  215. self.module.heapify(data)
  216. return [self.module.heappop(data).x for i in range(len(data))]
  217. class LT:
  218. def __init__(self, x):
  219. self.x = x
  220. def __lt__(self, other):
  221. return self.x > other.x
  222. class LE:
  223. def __init__(self, x):
  224. self.x = x
  225. def __le__(self, other):
  226. return self.x >= other.x
  227. data = [random.random() for i in range(100)]
  228. target = sorted(data, reverse=True)
  229. self.assertEqual(hsort(data, LT), target)
  230. self.assertRaises(TypeError, data, LE)
  231. class TestHeapPython(TestHeap, TestCase):
  232. module = py_heapq
  233. @skipUnless(c_heapq, 'requires _heapq')
  234. class TestHeapC(TestHeap, TestCase):
  235. module = c_heapq
  236. #==============================================================================
  237. class LenOnly:
  238. "Dummy sequence class defining __len__ but not __getitem__."
  239. def __len__(self):
  240. return 10
  241. class CmpErr:
  242. "Dummy element that always raises an error during comparison"
  243. def __eq__(self, other):
  244. raise ZeroDivisionError
  245. __ne__ = __lt__ = __le__ = __gt__ = __ge__ = __eq__
  246. def R(seqn):
  247. 'Regular generator'
  248. for i in seqn:
  249. yield i
  250. class G:
  251. 'Sequence using __getitem__'
  252. def __init__(self, seqn):
  253. self.seqn = seqn
  254. def __getitem__(self, i):
  255. return self.seqn[i]
  256. class I:
  257. 'Sequence using iterator protocol'
  258. def __init__(self, seqn):
  259. self.seqn = seqn
  260. self.i = 0
  261. def __iter__(self):
  262. return self
  263. def __next__(self):
  264. if self.i >= len(self.seqn): raise StopIteration
  265. v = self.seqn[self.i]
  266. self.i += 1
  267. return v
  268. class Ig:
  269. 'Sequence using iterator protocol defined with a generator'
  270. def __init__(self, seqn):
  271. self.seqn = seqn
  272. self.i = 0
  273. def __iter__(self):
  274. for val in self.seqn:
  275. yield val
  276. class X:
  277. 'Missing __getitem__ and __iter__'
  278. def __init__(self, seqn):
  279. self.seqn = seqn
  280. self.i = 0
  281. def __next__(self):
  282. if self.i >= len(self.seqn): raise StopIteration
  283. v = self.seqn[self.i]
  284. self.i += 1
  285. return v
  286. class N:
  287. 'Iterator missing __next__()'
  288. def __init__(self, seqn):
  289. self.seqn = seqn
  290. self.i = 0
  291. def __iter__(self):
  292. return self
  293. class E:
  294. 'Test propagation of exceptions'
  295. def __init__(self, seqn):
  296. self.seqn = seqn
  297. self.i = 0
  298. def __iter__(self):
  299. return self
  300. def __next__(self):
  301. 3 // 0
  302. class S:
  303. 'Test immediate stop'
  304. def __init__(self, seqn):
  305. pass
  306. def __iter__(self):
  307. return self
  308. def __next__(self):
  309. raise StopIteration
  310. from itertools import chain
  311. def L(seqn):
  312. 'Test multiple tiers of iterators'
  313. return chain(map(lambda x:x, R(Ig(G(seqn)))))
  314. class SideEffectLT:
  315. def __init__(self, value, heap):
  316. self.value = value
  317. self.heap = heap
  318. def __lt__(self, other):
  319. self.heap[:] = []
  320. return self.value < other.value
  321. class TestErrorHandling:
  322. def test_non_sequence(self):
  323. for f in (self.module.heapify, self.module.heappop):
  324. self.assertRaises((TypeError, AttributeError), f, 10)
  325. for f in (self.module.heappush, self.module.heapreplace,
  326. self.module.nlargest, self.module.nsmallest):
  327. self.assertRaises((TypeError, AttributeError), f, 10, 10)
  328. def test_len_only(self):
  329. for f in (self.module.heapify, self.module.heappop):
  330. self.assertRaises((TypeError, AttributeError), f, LenOnly())
  331. for f in (self.module.heappush, self.module.heapreplace):
  332. self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
  333. for f in (self.module.nlargest, self.module.nsmallest):
  334. self.assertRaises(TypeError, f, 2, LenOnly())
  335. def test_cmp_err(self):
  336. seq = [CmpErr(), CmpErr(), CmpErr()]
  337. for f in (self.module.heapify, self.module.heappop):
  338. self.assertRaises(ZeroDivisionError, f, seq)
  339. for f in (self.module.heappush, self.module.heapreplace):
  340. self.assertRaises(ZeroDivisionError, f, seq, 10)
  341. for f in (self.module.nlargest, self.module.nsmallest):
  342. self.assertRaises(ZeroDivisionError, f, 2, seq)
  343. def test_arg_parsing(self):
  344. for f in (self.module.heapify, self.module.heappop,
  345. self.module.heappush, self.module.heapreplace,
  346. self.module.nlargest, self.module.nsmallest):
  347. self.assertRaises((TypeError, AttributeError), f, 10)
  348. def test_iterable_args(self):
  349. for f in (self.module.nlargest, self.module.nsmallest):
  350. for s in ("123", "", range(1000), (1, 1.2), range(2000,2200,5)):
  351. for g in (G, I, Ig, L, R):
  352. self.assertEqual(list(f(2, g(s))), list(f(2,s)))
  353. self.assertEqual(list(f(2, S(s))), [])
  354. self.assertRaises(TypeError, f, 2, X(s))
  355. self.assertRaises(TypeError, f, 2, N(s))
  356. self.assertRaises(ZeroDivisionError, f, 2, E(s))
  357. # Issue #17278: the heap may change size while it's being walked.
  358. def test_heappush_mutating_heap(self):
  359. heap = []
  360. heap.extend(SideEffectLT(i, heap) for i in range(200))
  361. # Python version raises IndexError, C version RuntimeError
  362. with self.assertRaises((IndexError, RuntimeError)):
  363. self.module.heappush(heap, SideEffectLT(5, heap))
  364. def test_heappop_mutating_heap(self):
  365. heap = []
  366. heap.extend(SideEffectLT(i, heap) for i in range(200))
  367. # Python version raises IndexError, C version RuntimeError
  368. with self.assertRaises((IndexError, RuntimeError)):
  369. self.module.heappop(heap)
  370. def test_comparison_operator_modifiying_heap(self):
  371. # See bpo-39421: Strong references need to be taken
  372. # when comparing objects as they can alter the heap
  373. class EvilClass(int):
  374. def __lt__(self, o):
  375. heap.clear()
  376. return NotImplemented
  377. heap = []
  378. self.module.heappush(heap, EvilClass(0))
  379. self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
  380. def test_comparison_operator_modifiying_heap_two_heaps(self):
  381. class h(int):
  382. def __lt__(self, o):
  383. list2.clear()
  384. return NotImplemented
  385. class g(int):
  386. def __lt__(self, o):
  387. list1.clear()
  388. return NotImplemented
  389. list1, list2 = [], []
  390. self.module.heappush(list1, h(0))
  391. self.module.heappush(list2, g(0))
  392. self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
  393. self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
  394. class TestErrorHandlingPython(TestErrorHandling, TestCase):
  395. module = py_heapq
  396. @skipUnless(c_heapq, 'requires _heapq')
  397. class TestErrorHandlingC(TestErrorHandling, TestCase):
  398. module = c_heapq
  399. if __name__ == "__main__":
  400. unittest.main()