test_defaultdict.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. """Unit tests for collections.defaultdict."""
  2. import os
  3. import copy
  4. import pickle
  5. import tempfile
  6. import unittest
  7. from collections import defaultdict
  8. def foobar():
  9. return list
  10. class TestDefaultDict(unittest.TestCase):
  11. def test_basic(self):
  12. d1 = defaultdict()
  13. self.assertEqual(d1.default_factory, None)
  14. d1.default_factory = list
  15. d1[12].append(42)
  16. self.assertEqual(d1, {12: [42]})
  17. d1[12].append(24)
  18. self.assertEqual(d1, {12: [42, 24]})
  19. d1[13]
  20. d1[14]
  21. self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
  22. self.assertTrue(d1[12] is not d1[13] is not d1[14])
  23. d2 = defaultdict(list, foo=1, bar=2)
  24. self.assertEqual(d2.default_factory, list)
  25. self.assertEqual(d2, {"foo": 1, "bar": 2})
  26. self.assertEqual(d2["foo"], 1)
  27. self.assertEqual(d2["bar"], 2)
  28. self.assertEqual(d2[42], [])
  29. self.assertIn("foo", d2)
  30. self.assertIn("foo", d2.keys())
  31. self.assertIn("bar", d2)
  32. self.assertIn("bar", d2.keys())
  33. self.assertIn(42, d2)
  34. self.assertIn(42, d2.keys())
  35. self.assertNotIn(12, d2)
  36. self.assertNotIn(12, d2.keys())
  37. d2.default_factory = None
  38. self.assertEqual(d2.default_factory, None)
  39. try:
  40. d2[15]
  41. except KeyError as err:
  42. self.assertEqual(err.args, (15,))
  43. else:
  44. self.fail("d2[15] didn't raise KeyError")
  45. self.assertRaises(TypeError, defaultdict, 1)
  46. def test_missing(self):
  47. d1 = defaultdict()
  48. self.assertRaises(KeyError, d1.__missing__, 42)
  49. d1.default_factory = list
  50. self.assertEqual(d1.__missing__(42), [])
  51. def test_repr(self):
  52. d1 = defaultdict()
  53. self.assertEqual(d1.default_factory, None)
  54. self.assertEqual(repr(d1), "defaultdict(None, {})")
  55. self.assertEqual(eval(repr(d1)), d1)
  56. d1[11] = 41
  57. self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
  58. d2 = defaultdict(int)
  59. self.assertEqual(d2.default_factory, int)
  60. d2[12] = 42
  61. self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
  62. def foo(): return 43
  63. d3 = defaultdict(foo)
  64. self.assertTrue(d3.default_factory is foo)
  65. d3[13]
  66. self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
  67. def test_copy(self):
  68. d1 = defaultdict()
  69. d2 = d1.copy()
  70. self.assertEqual(type(d2), defaultdict)
  71. self.assertEqual(d2.default_factory, None)
  72. self.assertEqual(d2, {})
  73. d1.default_factory = list
  74. d3 = d1.copy()
  75. self.assertEqual(type(d3), defaultdict)
  76. self.assertEqual(d3.default_factory, list)
  77. self.assertEqual(d3, {})
  78. d1[42]
  79. d4 = d1.copy()
  80. self.assertEqual(type(d4), defaultdict)
  81. self.assertEqual(d4.default_factory, list)
  82. self.assertEqual(d4, {42: []})
  83. d4[12]
  84. self.assertEqual(d4, {42: [], 12: []})
  85. # Issue 6637: Copy fails for empty default dict
  86. d = defaultdict()
  87. d['a'] = 42
  88. e = d.copy()
  89. self.assertEqual(e['a'], 42)
  90. def test_shallow_copy(self):
  91. d1 = defaultdict(foobar, {1: 1})
  92. d2 = copy.copy(d1)
  93. self.assertEqual(d2.default_factory, foobar)
  94. self.assertEqual(d2, d1)
  95. d1.default_factory = list
  96. d2 = copy.copy(d1)
  97. self.assertEqual(d2.default_factory, list)
  98. self.assertEqual(d2, d1)
  99. def test_deep_copy(self):
  100. d1 = defaultdict(foobar, {1: [1]})
  101. d2 = copy.deepcopy(d1)
  102. self.assertEqual(d2.default_factory, foobar)
  103. self.assertEqual(d2, d1)
  104. self.assertTrue(d1[1] is not d2[1])
  105. d1.default_factory = list
  106. d2 = copy.deepcopy(d1)
  107. self.assertEqual(d2.default_factory, list)
  108. self.assertEqual(d2, d1)
  109. def test_keyerror_without_factory(self):
  110. d1 = defaultdict()
  111. try:
  112. d1[(1,)]
  113. except KeyError as err:
  114. self.assertEqual(err.args[0], (1,))
  115. else:
  116. self.fail("expected KeyError")
  117. def test_recursive_repr(self):
  118. # Issue2045: stack overflow when default_factory is a bound method
  119. class sub(defaultdict):
  120. def __init__(self):
  121. self.default_factory = self._factory
  122. def _factory(self):
  123. return []
  124. d = sub()
  125. self.assertRegex(repr(d),
  126. r"sub\(<bound method .*sub\._factory "
  127. r"of sub\(\.\.\., \{\}\)>, \{\}\)")
  128. def test_callable_arg(self):
  129. self.assertRaises(TypeError, defaultdict, {})
  130. def test_pickling(self):
  131. d = defaultdict(int)
  132. d[1]
  133. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  134. s = pickle.dumps(d, proto)
  135. o = pickle.loads(s)
  136. self.assertEqual(d, o)
  137. def test_union(self):
  138. i = defaultdict(int, {1: 1, 2: 2})
  139. s = defaultdict(str, {0: "zero", 1: "one"})
  140. i_s = i | s
  141. self.assertIs(i_s.default_factory, int)
  142. self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"})
  143. self.assertEqual(list(i_s), [1, 2, 0])
  144. s_i = s | i
  145. self.assertIs(s_i.default_factory, str)
  146. self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2})
  147. self.assertEqual(list(s_i), [0, 1, 2])
  148. i_ds = i | dict(s)
  149. self.assertIs(i_ds.default_factory, int)
  150. self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"})
  151. self.assertEqual(list(i_ds), [1, 2, 0])
  152. ds_i = dict(s) | i
  153. self.assertIs(ds_i.default_factory, int)
  154. self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2})
  155. self.assertEqual(list(ds_i), [0, 1, 2])
  156. with self.assertRaises(TypeError):
  157. i | list(s.items())
  158. with self.assertRaises(TypeError):
  159. list(s.items()) | i
  160. # We inherit a fine |= from dict, so just a few sanity checks here:
  161. i |= list(s.items())
  162. self.assertIs(i.default_factory, int)
  163. self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"})
  164. self.assertEqual(list(i), [1, 2, 0])
  165. with self.assertRaises(TypeError):
  166. i |= None
  167. if __name__ == "__main__":
  168. unittest.main()