| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
import gc
import operator
import pickle
import sys
import unittest
import weakref

from test import support
class G:
    """Sequence using __getitem__.

    Exposes *seqn* only through indexing, so iterating a G falls back to
    the legacy sequence-iteration protocol (indices 0, 1, 2, ... until the
    wrapped sequence raises IndexError).
    """

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        wrapped = self.seqn
        return wrapped[i]
class I:
    """Sequence using iterator protocol.

    The object is its own iterator: ``__next__`` walks *seqn* by position,
    tracked in ``self.i``, and raises StopIteration once past the end.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            # Fetch before advancing so a failing lookup leaves the
            # position unchanged.
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
class Ig:
    """Sequence using iterator protocol defined with a generator.

    Each call to ``__iter__`` returns a fresh generator over *seqn*, so the
    object can be iterated more than once. ``self.i`` is initialised but
    never read or updated.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        yield from self.seqn
class X:
    """Missing __getitem__ and __iter__.

    Only ``__next__`` is defined, so the object is not iterable: ``iter()``
    on it raises TypeError even though calling ``__next__`` directly works.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
class E:
    """Test propagation of exceptions.

    Iterating an E raises ZeroDivisionError from the very first
    ``__next__`` call; *seqn* is stored but never consumed.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        # Deliberate integer division by zero -> ZeroDivisionError.
        3 // 0

    def __iter__(self):
        return self
class N:
    """Iterator missing __next__().

    ``__iter__`` returns self, but with no ``__next__`` the result is not a
    valid iterator, so ``iter()`` on an N raises TypeError.
    """

    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self
class PickleTest:
    """Mixin that checks an iterator survives a pickle round trip.

    Meant to be mixed into a TestCase: it relies on ``assertEqual`` and
    ``assertFalse`` being provided by the class it is combined with.
    """

    def check_pickle(self, itorg, seq):
        """Assert *itorg* pickles and unpickles correctly for every protocol.

        *seq* is the list of items a fresh *itorg* is expected to yield.
        For each protocol, a fresh unpickled copy must have the same type
        and yield *seq*. A second copy, advanced by one item and pickled
        again, must yield ``seq[1:]``.
        """
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            payload = pickle.dumps(itorg, proto)

            clone = pickle.loads(payload)
            self.assertEqual(type(itorg), type(clone))
            self.assertEqual(list(clone), seq)

            advanced = pickle.loads(payload)
            try:
                next(advanced)
            except StopIteration:
                # Nothing to advance over: only valid for an empty seq.
                self.assertFalse(seq[1:])
            else:
                # The saved position must survive a second round trip.
                resumed = pickle.loads(pickle.dumps(advanced, proto))
                self.assertEqual(list(resumed), seq[1:])
class EnumerateTestCase(unittest.TestCase, PickleTest):
    """Tests for the enumerate() builtin.

    Subclasses rerun every test by overriding ``enum`` (the callable under
    test), ``seq`` (the input) and ``res`` (the expected list of
    ``(index, item)`` pairs).
    """

    # Callable under test: enumerate itself here; subclasses substitute an
    # enumerate subclass or a wrapper function.
    enum = enumerate
    # Input and the pairs enum(seq) is expected to produce.
    seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]

    def test_basicfunction(self):
        # enum(seq) must be an instance of exactly enum, be its own
        # iterator, and produce res. EnumerateStartTestCase overrides this
        # test without the type check.
        self.assertEqual(type(self.enum(self.seq)), self.enum)
        e = self.enum(self.seq)
        self.assertEqual(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)
        # Only checks that looking up __doc__ does not raise.
        self.enum.__doc__

    def test_pickle(self):
        self.check_pickle(self.enum(self.seq), self.res)

    def test_getitemseqn(self):
        # Input supporting only the __getitem__ sequence protocol (G).
        self.assertEqual(list(self.enum(G(self.seq))), self.res)
        e = self.enum(G(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorseqn(self):
        # Input that is its own iterator (I).
        self.assertEqual(list(self.enum(I(self.seq))), self.res)
        e = self.enum(I(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorgenerator(self):
        # Input whose __iter__ is a generator function (Ig).
        self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
        e = self.enum(Ig(''))
        self.assertRaises(StopIteration, next, e)

    def test_noniterable(self):
        # X defines __next__ but neither __iter__ nor __getitem__.
        self.assertRaises(TypeError, self.enum, X(self.seq))

    def test_illformediterable(self):
        # N.__iter__ returns an object that has no __next__.
        self.assertRaises(TypeError, self.enum, N(self.seq))

    def test_exception_propagation(self):
        # The ZeroDivisionError raised by E.__next__ must reach the caller.
        self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))

    def test_argumentcheck(self):
        self.assertRaises(TypeError, self.enum) # no arguments
        self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
        self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type (start is not an int)
        self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments

    def test_kwargs(self):
        # iterable and start are accepted by keyword, in either order; any
        # other keyword must be rejected with TypeError.
        self.assertEqual(list(self.enum(iterable=Ig(self.seq))), self.res)
        expected = list(self.enum(Ig(self.seq), 0))
        self.assertEqual(list(self.enum(iterable=Ig(self.seq), start=0)),
                         expected)
        self.assertEqual(list(self.enum(start=0, iterable=Ig(self.seq))),
                         expected)
        self.assertRaises(TypeError, self.enum, iterable=[], x=3)
        self.assertRaises(TypeError, self.enum, start=0, x=3)
        self.assertRaises(TypeError, self.enum, x=0, y=3)
        self.assertRaises(TypeError, self.enum, x=0)

    @support.cpython_only
    def test_tuple_reuse(self):
        # Tests an implementation detail where tuple is reused
        # whenever nothing else holds a reference to it
        # list() keeps every result tuple alive, so all ids are distinct;
        # map(id, ...) holds no result past the id() call, so at most one
        # distinct id is expected. Uses the builtin enumerate directly,
        # not self.enum. The exact expression shapes matter here.
        self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
        self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq)))

    @support.cpython_only
    def test_enumerate_result_gc(self):
        # bpo-42536: enumerate's tuple-reuse speed trick breaks the GC's
        # assumptions about what can be untracked. Make sure we re-track result
        # tuples whenever we reuse them.
        it = self.enum([[]])
        gc.collect()
        # That GC collection probably untracked the recycled internal result
        # tuple, which is initialized to (None, None). Make sure it's re-tracked
        # when it's mutated and returned from __next__:
        self.assertTrue(gc.is_tracked(next(it)))
class MyEnum(enumerate):
    """Trivial enumerate subclass; SubclassTestCase reruns the suite on it."""
class SubclassTestCase(EnumerateTestCase):
    """Rerun EnumerateTestCase against a trivial enumerate subclass."""

    enum = MyEnum
class TestEmpty(EnumerateTestCase):
    """Rerun EnumerateTestCase on an empty input."""

    seq = ''
    res = []
class TestBig(EnumerateTestCase):
    """Rerun EnumerateTestCase on a long range input."""

    seq = range(10, 20000, 2)
    # Expected pairs are built with zip, not enumerate, so the expectation
    # does not depend on the code under test.
    res = list(zip(range(len(seq)), seq))
class TestReversed(unittest.TestCase, PickleTest):
    """Tests for the reversed() builtin and the iterators it returns."""

    def test_simple(self):
        # Minimal user-defined sequence: __len__ plus __getitem__, where
        # __getitem__ ends forward iteration by raising StopIteration.
        class A:
            def __getitem__(self, i):
                if i < 5:
                    return str(i)
                raise StopIteration
            def __len__(self):
                return 5
        # reversed() must yield the same items as reversing a list copy.
        for data in ('abc', range(5), tuple(enumerate('abc')), A(),
                     range(1,17,5), dict.fromkeys('abcde')):
            self.assertEqual(list(data)[::-1], list(reversed(data)))
        # don't allow keyword arguments
        self.assertRaises(TypeError, reversed, [], a=1)

    def test_range_optimization(self):
        # reversed(range) returns the same iterator type as iter(range).
        x = range(1)
        self.assertEqual(type(reversed(x)), type(iter(x)))

    def test_len(self):
        # length_hint() reports the full length before iteration and 0 once
        # the reversed iterator is exhausted.
        for s in ('hello', tuple('hello'), list('hello'), range(5)):
            self.assertEqual(operator.length_hint(reversed(s)), len(s))
            r = reversed(s)
            list(r)
            self.assertEqual(operator.length_hint(r), 0)
        # The first __len__ call (made when reversed() is constructed)
        # returns 10; any later call raises. That error must propagate out
        # of length_hint() rather than be swallowed.
        class SeqWithWeirdLen:
            called = False
            def __len__(self):
                if not self.called:
                    self.called = True
                    return 10
                raise ZeroDivisionError
            def __getitem__(self, index):
                return index
        r = reversed(SeqWithWeirdLen())
        self.assertRaises(ZeroDivisionError, operator.length_hint, r)

    def test_gc(self):
        # A sequence holding its own reversed iterator forms a reference
        # cycle (s -> s.r -> r -> s); the cyclic GC must be able to free it.
        class Seq:
            def __len__(self):
                return 10
            def __getitem__(self, index):
                return index
        s = Seq()
        r = reversed(s)
        s.r = r
        wr = weakref.ref(s)
        # Drop the local references so only the cycle keeps s alive.
        del s, r
        gc.collect()
        self.assertIsNone(wr())

    def test_args(self):
        # reversed() takes exactly one positional argument.
        self.assertRaises(TypeError, reversed)
        self.assertRaises(TypeError, reversed, [], 'extra')

    @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()')
    def test_bug1229429(self):
        # this bug was never in reversed, it was in
        # PyObject_CallMethod, and reversed_new calls that sometimes.
        # reversed(f) must raise TypeError every time and must leave the
        # reference count of r unchanged.
        # NOTE(review): special methods are looked up on the type, so a
        # __reversed__ set on the function instance may not be consulted at
        # all; confirm this still exercises the intended code path.
        def f():
            pass
        r = f.__reversed__ = object()
        rc = sys.getrefcount(r)
        for _ in range(10):
            try:
                reversed(f)
            except TypeError:
                pass
            else:
                self.fail("non-callable __reversed__ didn't raise!")
        self.assertEqual(rc, sys.getrefcount(r))

    def test_objmethods(self):
        # Objects must have __len__() and __getitem__() implemented.
        class NoLen(object):
            def __getitem__(self, i): return 1
        nl = NoLen()
        self.assertRaises(TypeError, reversed, nl)

        class NoGetItem(object):
            def __len__(self): return 2
        ngi = NoGetItem()
        self.assertRaises(TypeError, reversed, ngi)

        # Setting __reversed__ to None opts out of reversal even though
        # both sequence methods are present.
        class Blocked(object):
            def __getitem__(self, i): return 1
            def __len__(self): return 2
            __reversed__ = None
        b = Blocked()
        self.assertRaises(TypeError, reversed, b)

    def test_pickle(self):
        # Reversed iterators must round-trip through check_pickle.
        for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
            self.check_pickle(reversed(data), list(data)[::-1])
class EnumerateStartTestCase(EnumerateTestCase):
    """Base for suites whose ``enum`` is a wrapper function, not a type.

    Overrides test_basicfunction to drop the check that the result's type
    equals ``enum``, which cannot hold when ``enum`` is a method.
    """

    def test_basicfunction(self):
        self.assertEqual(list(self.enum(self.seq)), self.res)
        obj = self.enum(self.seq)
        self.assertEqual(iter(obj), obj)
class TestStart(EnumerateStartTestCase):
    """Rerun the suite with enumerate counting from 11."""

    seq = 'abc'
    res = [(11, 'a'), (12, 'b'), (13, 'c')]

    def enum(self, iterable, start=11):
        # Parameter names match enumerate's so test_kwargs can pass them
        # by keyword.
        return enumerate(iterable, start)
class TestLongStart(EnumerateStartTestCase):
    """Rerun the suite with a start value beyond sys.maxsize."""

    seq = 'abc'
    res = [(sys.maxsize + 1, 'a'),
           (sys.maxsize + 2, 'b'),
           (sys.maxsize + 3, 'c')]

    def enum(self, iterable, start=sys.maxsize + 1):
        # Parameter names match enumerate's so test_kwargs can pass them
        # by keyword.
        return enumerate(iterable, start)
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()
|