| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549 |
- from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
- NAME_MAPPING, REVERSE_NAME_MAPPING)
- import builtins
- import pickle
- import io
- import collections
- import struct
- import sys
- import warnings
- import weakref
- import doctest
- import unittest
- from test import support
- from test.support import import_helper
- from test.pickletester import AbstractHookTests
- from test.pickletester import AbstractUnpickleTests
- from test.pickletester import AbstractPickleTests
- from test.pickletester import AbstractPickleModuleTests
- from test.pickletester import AbstractPersistentPicklerTests
- from test.pickletester import AbstractIdentityPersistentPicklerTests
- from test.pickletester import AbstractPicklerUnpicklerObjectTests
- from test.pickletester import AbstractDispatchTableTests
- from test.pickletester import AbstractCustomPicklerClass
- from test.pickletester import BigmemPickleTests
# Probe for the optional C accelerator; the C-specific test classes below are
# only defined when it is importable.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Module-level API tests run against the pure-Python pickle functions."""

    Pickler, Unpickler = pickle._Pickler, pickle._Unpickler
    # staticmethod keeps the plain functions from being bound as methods.
    dump, dumps = staticmethod(pickle._dump), staticmethod(pickle._dumps)
    load, loads = staticmethod(pickle._load), staticmethod(pickle._loads)
class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Unpickling tests for the pure-Python ``pickle._Unpickler``."""

    unpickler = pickle._Unpickler
    # Exception types the inherited tests accept for a corrupted stack and
    # for a truncated stream, respectively.
    bad_stack_errors = (IndexError,)
    truncated_errors = (
        pickle.UnpicklingError, EOFError, AttributeError, ValueError,
        struct.error, IndexError, ImportError,
    )

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``; *kwds* go to its constructor."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Pickling tests for the pure-Python pickler/unpickler pair."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* into an in-memory stream and return the bytes.

        *proto* is the pickle protocol; *kwargs* go to the pickler constructor.
        """
        stream = io.BytesIO()
        self.pickler(stream, proto, **kwargs).dump(arg)
        return stream.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*; *kwds* are forwarded to the unpickler constructor."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Tests for the module-level ``pickle.dumps`` / ``pickle.loads`` helpers."""

    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (
        pickle.UnpicklingError, EOFError, AttributeError, ValueError,
        struct.error, IndexError, ImportError,
    )

    # Disable this inherited test for the in-memory helpers.
    # NOTE(review): presumably it needs a caller-supplied writer, which
    # pickle.dumps does not take — confirm against pickletester.
    test_framed_write_sizes_with_delayed_writer = None

    def dumps(self, arg, protocol=None, **kwargs):
        """Pickle *arg* via ``pickle.dumps``."""
        return pickle.dumps(arg, protocol=protocol, **kwargs)

    def loads(self, buf, **kwds):
        """Unpickle *buf* via ``pickle.loads``."""
        return pickle.loads(buf, **kwds)
class PersistentPicklerUnpicklerMixin(object):
    """Route persistent IDs through hooks defined on the test case itself.

    Subclasses provide ``pickler`` / ``unpickler`` classes plus
    ``persistent_id(obj)`` and ``persistent_load(pid)`` methods.  Both are
    forwarded to from per-call pickler/unpickler subclasses.
    """

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with protocol *proto*, consulting self.persistent_id.

        Extra keyword arguments are forwarded to the pickler constructor.
        This mirrors ``loads``, which forwards its keywords to the unpickler.
        """
        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                # ``subself`` is the pickler; ``self`` is the test case.
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto, **kwargs)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*, resolving persistent IDs via self.persistent_load."""
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, obj):
                # ``subself`` is the unpickler; ``self`` is the test case.
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Persistent-ID tests for the pure-Python pickler and unpickler."""

    unpickler, pickler = pickle._Unpickler, pickle._Pickler
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Identity persistent-ID tests for the pure-Python implementation.

    Also holds CPython-only checks that custom hooks and dispatch tables do
    not keep pickler/unpickler objects (or their tables) alive.
    """

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        def check(pickler_cls):
            # Round-trip a simple string through every protocol.
            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
                buf = io.BytesIO()
                pickler = pickler_cls(buf, protocol)
                pickler.dump('abc')
                self.assertEqual(self.loads(buf.getvalue()), 'abc')
            # The hook must be callable directly.  Dropping the last strong
            # reference must free the pickler at once, so no reference
            # cycle may exist.
            pickler = pickler_cls(io.BytesIO())
            self.assertEqual(pickler.persistent_id('def'), 'def')
            ref = weakref.ref(pickler)
            del pickler
            self.assertIsNone(ref())

        class MethodHook(self.pickler):
            def persistent_id(subself, obj):
                return obj

        class ClassMethodHook(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj

        class StaticMethodHook(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj

        for variant in (MethodHook, ClassMethodHook, StaticMethodHook):
            check(variant)

    @support.cpython_only
    def test_custom_pickler_dispatch_table_memleak(self):
        # See https://github.com/python/cpython/issues/89988
        class DispatchTable:
            pass

        class Pickler(self.pickler):
            def __init__(self, *args, **kwargs):
                # ``table`` is read from the enclosing test's local scope.
                self.dispatch_table = table
                super().__init__(*args, **kwargs)

        table = DispatchTable()
        pickler = Pickler(io.BytesIO())
        self.assertIs(pickler.dispatch_table, table)
        weak_table = weakref.ref(table)
        self.assertIsNotNone(weak_table())
        # Drop the pickler first, then the table itself; after a collection
        # nothing may still reference the table.
        del pickler
        del table
        support.gc_collect()
        self.assertIsNone(weak_table())

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        def check(unpickler_cls):
            # Load a simple string pickled with every protocol.
            for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
                unpickler = unpickler_cls(
                    io.BytesIO(self.dumps('abc', protocol)))
                self.assertEqual(unpickler.load(), 'abc')
            # The hook must be callable directly, and the unpickler must be
            # freed as soon as its last reference goes away.
            unpickler = unpickler_cls(io.BytesIO())
            self.assertEqual(unpickler.persistent_load('def'), 'def')
            ref = weakref.ref(unpickler)
            del unpickler
            self.assertIsNone(ref())

        class MethodHook(self.unpickler):
            def persistent_load(subself, pid):
                return pid

        class ClassMethodHook(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid

        class StaticMethodHook(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid

        for variant in (MethodHook, ClassMethodHook, StaticMethodHook):
            check(variant)
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests,
                                    unittest.TestCase):
    """Pickler/Unpickler object-behaviour tests for the pure-Python classes."""

    unpickler_class = pickle._Unpickler
    pickler_class = pickle._Pickler
class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests for the pure-Python pickler using a plain dict."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a private dict copy of the global pickle dispatch table."""
        return dict(pickle.dispatch_table)
class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests for the pure-Python pickler using a ChainMap."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a ChainMap with an empty writable layer over the global table."""
        return collections.ChainMap(pickle.dispatch_table).new_child()
class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Hook tests run against a pure-Python pickler subclass."""

    class CustomPyPicklerClass(pickle._Pickler, AbstractCustomPicklerClass):
        # Combines the pure-Python pickler with the shared custom hooks.
        pass

    pickler_class = CustomPyPicklerClass
if has_c_implementation:
    # Everything below exercises the C accelerator ``_pickle`` and is only
    # defined when it could be imported.

    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
        """Module-level API tests run against the C ``_pickle`` functions."""
        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        """Unpickling tests for the C unpickler."""
        unpickler = _pickle.Unpickler
        # The C unpickler is expected to raise only UnpicklingError for a
        # corrupted stack or a truncated stream.
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CDumpPickle_LoadPickle(PyPicklerTests):
        # Cross-implementation: C pickler output read by the Python unpickler.
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    class DumpPickle_CLoadPickle(PyPicklerTests):
        # Cross-implementation: Python pickler output read by the C unpickler.
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests,
                                       unittest.TestCase):
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            unpickler = self.unpickler_class(io.BytesIO())
            # memo must be a dict-like object, not an arbitrary type.
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        # Name the C pickler explicitly, like the other C test classes.
        # pickle.Pickler is expected to be this same class when _pickle
        # is importable.
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests,
                                   unittest.TestCase):
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            # Empty writable layer over the shared global dispatch table.
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        """Check ``__sizeof__`` of the C Pickler and Unpickler objects."""

        check_sizeof = support.check_sizeof

        def test_pickler(self):
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                  MT_size + 8 * ME_size +  # Minimal memo table size.
                  sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                  MT_size + 32 * ME_size +  # Size of memo table required to
                                            # save references to 6 objects.
                  0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    # The encoding and errors strings are included in the
                    # reported size, each with one extra byte.
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1

            def check_unpickler(data, memo_size, marks_size):
                # Load a pickle of *data*, then compare the reported size
                # with the expected memo-table and mark-stack sizes.
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)

            def recurse(deep):
                # Build a list nested *deep* levels, sharing each sub-list.
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data

            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
# (python2_module, python3_module) pairs that CompatPickleTests exempts from
# the requirement that REVERSE_IMPORT_MAPPING maps the Python 3 module back
# to the same Python 2 module.
ALT_IMPORT_MAPPING = {
    ('StringIO', 'io'),
    ('_elementtree', 'xml.etree.ElementTree'),
    ('cPickle', 'pickle'),
    ('cStringIO', 'io'),
}

# (module2, name2, module3, name3) tuples whose reverse name mapping is
# allowed to differ from the forward one.
ALT_NAME_MAPPING = {
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
def mapping(module, name):
    """Translate a Python 2 ``(module, name)`` pair to its Python 3 location.

    An exact ``(module, name)`` entry in NAME_MAPPING wins.  Otherwise only
    the module is renamed through IMPORT_MAPPING.  Unknown pairs are returned
    unchanged.
    """
    key = (module, name)
    if key in NAME_MAPPING:
        return NAME_MAPPING[key]
    return IMPORT_MAPPING.get(module, module), name
def reverse_mapping(module, name):
    """Translate a Python 3 ``(module, name)`` pair back to Python 2 naming.

    An exact entry in REVERSE_NAME_MAPPING wins.  Otherwise only the module
    is renamed through REVERSE_IMPORT_MAPPING.  Unknown pairs are returned
    unchanged.
    """
    pair = (module, name)
    if pair in REVERSE_NAME_MAPPING:
        return REVERSE_NAME_MAPPING[pair]
    return REVERSE_IMPORT_MAPPING.get(module, module), name
def getmodule(module):
    """Return the module named *module*, importing it first if necessary.

    DeprecationWarnings emitted while importing are shown when the test run
    is verbose and suppressed otherwise.

    Raises ImportError if the module cannot be imported.  An AttributeError
    raised during the import is converted to an ImportError that carries a
    message and is chained to the original exception, so callers only need
    to catch ImportError.
    """
    try:
        return sys.modules[module]
    except KeyError:
        pass
    # The import runs outside the KeyError handler, so failures are not
    # reported as "during handling of KeyError".
    try:
        with warnings.catch_warnings():
            action = 'always' if support.verbose else 'ignore'
            warnings.simplefilter(action, DeprecationWarning)
            __import__(module)
    except AttributeError as exc:
        msg = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(msg)
        raise ImportError(msg) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
def getattribute(module, name):
    """Resolve the dotted attribute path *name* inside module *module*.

    The module is obtained through getmodule(), so it is imported on demand.
    Missing attributes raise AttributeError.
    """
    target = getmodule(module)
    parts = iter(name.split('.'))
    for part in parts:
        target = getattr(target, part)
    return target
def get_exceptions(mod):
    """Yield ``(name, cls)`` for every exception class attribute of *mod*.

    Attributes are visited in ``dir(mod)`` order.  Only classes derived from
    BaseException are produced.
    """
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, BaseException):
            yield attr_name, candidate
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the Python 2/3 tables in ``_compat_pickle``."""

    def test_import(self):
        # Every module mentioned by any mapping table should be importable
        # where available; import failures are tolerated (best effort).
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                # Private (underscore) Python 3 modules need not round-trip.
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3) != module2):
                    # Without a direct reverse import mapping, at least one
                    # name-level reverse mapping must cover the pair.
                    for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
                        if (module3, module2) == (m3, m2):
                            break
                    else:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # Probe that the target resolves where importable; the
                # resolved object itself is not needed here.
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # Exceptions newer than Python 2 have no mapping to check.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError,
                           EncodingWarning,
                           BaseExceptionGroup,
                           ExceptionGroup):
                    continue
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        module = import_helper.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))
def load_tests(loader, tests, pattern):
    """unittest load_tests hook: append this module's doctests to *tests*."""
    module_doctests = doctest.DocTestSuite()
    tests.addTest(module_doctests)
    return tests
# Allow running this test module directly as a script.
if "__main__" == __name__:
    unittest.main()
|