test_enumerate.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import unittest
  2. import operator
  3. import sys
  4. import pickle
  5. import gc
  6. from test import support
  7. class G:
  8. 'Sequence using __getitem__'
  9. def __init__(self, seqn):
  10. self.seqn = seqn
  11. def __getitem__(self, i):
  12. return self.seqn[i]
  13. class I:
  14. 'Sequence using iterator protocol'
  15. def __init__(self, seqn):
  16. self.seqn = seqn
  17. self.i = 0
  18. def __iter__(self):
  19. return self
  20. def __next__(self):
  21. if self.i >= len(self.seqn): raise StopIteration
  22. v = self.seqn[self.i]
  23. self.i += 1
  24. return v
  25. class Ig:
  26. 'Sequence using iterator protocol defined with a generator'
  27. def __init__(self, seqn):
  28. self.seqn = seqn
  29. self.i = 0
  30. def __iter__(self):
  31. for val in self.seqn:
  32. yield val
  33. class X:
  34. 'Missing __getitem__ and __iter__'
  35. def __init__(self, seqn):
  36. self.seqn = seqn
  37. self.i = 0
  38. def __next__(self):
  39. if self.i >= len(self.seqn): raise StopIteration
  40. v = self.seqn[self.i]
  41. self.i += 1
  42. return v
  43. class E:
  44. 'Test propagation of exceptions'
  45. def __init__(self, seqn):
  46. self.seqn = seqn
  47. self.i = 0
  48. def __iter__(self):
  49. return self
  50. def __next__(self):
  51. 3 // 0
  52. class N:
  53. 'Iterator missing __next__()'
  54. def __init__(self, seqn):
  55. self.seqn = seqn
  56. self.i = 0
  57. def __iter__(self):
  58. return self
  59. class PickleTest:
  60. # Helper to check picklability
  61. def check_pickle(self, itorg, seq):
  62. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  63. d = pickle.dumps(itorg, proto)
  64. it = pickle.loads(d)
  65. self.assertEqual(type(itorg), type(it))
  66. self.assertEqual(list(it), seq)
  67. it = pickle.loads(d)
  68. try:
  69. next(it)
  70. except StopIteration:
  71. self.assertFalse(seq[1:])
  72. continue
  73. d = pickle.dumps(it, proto)
  74. it = pickle.loads(d)
  75. self.assertEqual(list(it), seq[1:])
  76. class EnumerateTestCase(unittest.TestCase, PickleTest):
  77. enum = enumerate
  78. seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]
  79. def test_basicfunction(self):
  80. self.assertEqual(type(self.enum(self.seq)), self.enum)
  81. e = self.enum(self.seq)
  82. self.assertEqual(iter(e), e)
  83. self.assertEqual(list(self.enum(self.seq)), self.res)
  84. self.enum.__doc__
  85. def test_pickle(self):
  86. self.check_pickle(self.enum(self.seq), self.res)
  87. def test_getitemseqn(self):
  88. self.assertEqual(list(self.enum(G(self.seq))), self.res)
  89. e = self.enum(G(''))
  90. self.assertRaises(StopIteration, next, e)
  91. def test_iteratorseqn(self):
  92. self.assertEqual(list(self.enum(I(self.seq))), self.res)
  93. e = self.enum(I(''))
  94. self.assertRaises(StopIteration, next, e)
  95. def test_iteratorgenerator(self):
  96. self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
  97. e = self.enum(Ig(''))
  98. self.assertRaises(StopIteration, next, e)
  99. def test_noniterable(self):
  100. self.assertRaises(TypeError, self.enum, X(self.seq))
  101. def test_illformediterable(self):
  102. self.assertRaises(TypeError, self.enum, N(self.seq))
  103. def test_exception_propagation(self):
  104. self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))
  105. def test_argumentcheck(self):
  106. self.assertRaises(TypeError, self.enum) # no arguments
  107. self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
  108. self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type
  109. self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments
  110. def test_kwargs(self):
  111. self.assertEqual(list(self.enum(iterable=Ig(self.seq))), self.res)
  112. expected = list(self.enum(Ig(self.seq), 0))
  113. self.assertEqual(list(self.enum(iterable=Ig(self.seq), start=0)),
  114. expected)
  115. self.assertEqual(list(self.enum(start=0, iterable=Ig(self.seq))),
  116. expected)
  117. self.assertRaises(TypeError, self.enum, iterable=[], x=3)
  118. self.assertRaises(TypeError, self.enum, start=0, x=3)
  119. self.assertRaises(TypeError, self.enum, x=0, y=3)
  120. self.assertRaises(TypeError, self.enum, x=0)
  121. @support.cpython_only
  122. def test_tuple_reuse(self):
  123. # Tests an implementation detail where tuple is reused
  124. # whenever nothing else holds a reference to it
  125. self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
  126. self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq)))
  127. @support.cpython_only
  128. def test_enumerate_result_gc(self):
  129. # bpo-42536: enumerate's tuple-reuse speed trick breaks the GC's
  130. # assumptions about what can be untracked. Make sure we re-track result
  131. # tuples whenever we reuse them.
  132. it = self.enum([[]])
  133. gc.collect()
  134. # That GC collection probably untracked the recycled internal result
  135. # tuple, which is initialized to (None, None). Make sure it's re-tracked
  136. # when it's mutated and returned from __next__:
  137. self.assertTrue(gc.is_tracked(next(it)))
  138. class MyEnum(enumerate):
  139. pass
  140. class SubclassTestCase(EnumerateTestCase):
  141. enum = MyEnum
  142. class TestEmpty(EnumerateTestCase):
  143. seq, res = '', []
  144. class TestBig(EnumerateTestCase):
  145. seq = range(10,20000,2)
  146. res = list(zip(range(20000), seq))
  147. class TestReversed(unittest.TestCase, PickleTest):
  148. def test_simple(self):
  149. class A:
  150. def __getitem__(self, i):
  151. if i < 5:
  152. return str(i)
  153. raise StopIteration
  154. def __len__(self):
  155. return 5
  156. for data in ('abc', range(5), tuple(enumerate('abc')), A(),
  157. range(1,17,5), dict.fromkeys('abcde')):
  158. self.assertEqual(list(data)[::-1], list(reversed(data)))
  159. # don't allow keyword arguments
  160. self.assertRaises(TypeError, reversed, [], a=1)
  161. def test_range_optimization(self):
  162. x = range(1)
  163. self.assertEqual(type(reversed(x)), type(iter(x)))
  164. def test_len(self):
  165. for s in ('hello', tuple('hello'), list('hello'), range(5)):
  166. self.assertEqual(operator.length_hint(reversed(s)), len(s))
  167. r = reversed(s)
  168. list(r)
  169. self.assertEqual(operator.length_hint(r), 0)
  170. class SeqWithWeirdLen:
  171. called = False
  172. def __len__(self):
  173. if not self.called:
  174. self.called = True
  175. return 10
  176. raise ZeroDivisionError
  177. def __getitem__(self, index):
  178. return index
  179. r = reversed(SeqWithWeirdLen())
  180. self.assertRaises(ZeroDivisionError, operator.length_hint, r)
  181. def test_gc(self):
  182. class Seq:
  183. def __len__(self):
  184. return 10
  185. def __getitem__(self, index):
  186. return index
  187. s = Seq()
  188. r = reversed(s)
  189. s.r = r
  190. def test_args(self):
  191. self.assertRaises(TypeError, reversed)
  192. self.assertRaises(TypeError, reversed, [], 'extra')
  193. @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()')
  194. def test_bug1229429(self):
  195. # this bug was never in reversed, it was in
  196. # PyObject_CallMethod, and reversed_new calls that sometimes.
  197. def f():
  198. pass
  199. r = f.__reversed__ = object()
  200. rc = sys.getrefcount(r)
  201. for i in range(10):
  202. try:
  203. reversed(f)
  204. except TypeError:
  205. pass
  206. else:
  207. self.fail("non-callable __reversed__ didn't raise!")
  208. self.assertEqual(rc, sys.getrefcount(r))
  209. def test_objmethods(self):
  210. # Objects must have __len__() and __getitem__() implemented.
  211. class NoLen(object):
  212. def __getitem__(self, i): return 1
  213. nl = NoLen()
  214. self.assertRaises(TypeError, reversed, nl)
  215. class NoGetItem(object):
  216. def __len__(self): return 2
  217. ngi = NoGetItem()
  218. self.assertRaises(TypeError, reversed, ngi)
  219. class Blocked(object):
  220. def __getitem__(self, i): return 1
  221. def __len__(self): return 2
  222. __reversed__ = None
  223. b = Blocked()
  224. self.assertRaises(TypeError, reversed, b)
  225. def test_pickle(self):
  226. for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
  227. self.check_pickle(reversed(data), list(data)[::-1])
  228. class EnumerateStartTestCase(EnumerateTestCase):
  229. def test_basicfunction(self):
  230. e = self.enum(self.seq)
  231. self.assertEqual(iter(e), e)
  232. self.assertEqual(list(self.enum(self.seq)), self.res)
  233. class TestStart(EnumerateStartTestCase):
  234. def enum(self, iterable, start=11):
  235. return enumerate(iterable, start=start)
  236. seq, res = 'abc', [(11, 'a'), (12, 'b'), (13, 'c')]
  237. class TestLongStart(EnumerateStartTestCase):
  238. def enum(self, iterable, start=sys.maxsize + 1):
  239. return enumerate(iterable, start=start)
  240. seq, res = 'abc', [(sys.maxsize+1,'a'), (sys.maxsize+2,'b'),
  241. (sys.maxsize+3,'c')]
  242. if __name__ == "__main__":
  243. unittest.main()