test_threading_local.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import sys
  2. import unittest
  3. from doctest import DocTestSuite
  4. from test import support
  5. from test.support import threading_helper
  6. from test.support.import_helper import import_module
  7. import weakref
  8. import gc
  9. # Modules under test
  10. import _thread
  11. import threading
  12. import _threading_local
# Module-level guard, executed at import time with module=True so that the
# whole test module is skipped on platforms without usable threads
# (presumably by raising unittest.SkipTest -- see
# test.support.threading_helper).
threading_helper.requires_working_threading(module=True)
  14. class Weak(object):
  15. pass
  16. def target(local, weaklist):
  17. weak = Weak()
  18. local.weak = weak
  19. weaklist.append(weakref.ref(weak))
  20. class BaseLocalTest:
  21. def test_local_refs(self):
  22. self._local_refs(20)
  23. self._local_refs(50)
  24. self._local_refs(100)
  25. def _local_refs(self, n):
  26. local = self._local()
  27. weaklist = []
  28. for i in range(n):
  29. t = threading.Thread(target=target, args=(local, weaklist))
  30. t.start()
  31. t.join()
  32. del t
  33. support.gc_collect() # For PyPy or other GCs.
  34. self.assertEqual(len(weaklist), n)
  35. # XXX _threading_local keeps the local of the last stopped thread alive.
  36. deadlist = [weak for weak in weaklist if weak() is None]
  37. self.assertIn(len(deadlist), (n-1, n))
  38. # Assignment to the same thread local frees it sometimes (!)
  39. local.someothervar = None
  40. support.gc_collect() # For PyPy or other GCs.
  41. deadlist = [weak for weak in weaklist if weak() is None]
  42. self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))
  43. def test_derived(self):
  44. # Issue 3088: if there is a threads switch inside the __init__
  45. # of a threading.local derived class, the per-thread dictionary
  46. # is created but not correctly set on the object.
  47. # The first member set may be bogus.
  48. import time
  49. class Local(self._local):
  50. def __init__(self):
  51. time.sleep(0.01)
  52. local = Local()
  53. def f(i):
  54. local.x = i
  55. # Simply check that the variable is correctly set
  56. self.assertEqual(local.x, i)
  57. with threading_helper.start_threads(threading.Thread(target=f, args=(i,))
  58. for i in range(10)):
  59. pass
  60. def test_derived_cycle_dealloc(self):
  61. # http://bugs.python.org/issue6990
  62. class Local(self._local):
  63. pass
  64. locals = None
  65. passed = False
  66. e1 = threading.Event()
  67. e2 = threading.Event()
  68. def f():
  69. nonlocal passed
  70. # 1) Involve Local in a cycle
  71. cycle = [Local()]
  72. cycle.append(cycle)
  73. cycle[0].foo = 'bar'
  74. # 2) GC the cycle (triggers threadmodule.c::local_clear
  75. # before local_dealloc)
  76. del cycle
  77. support.gc_collect() # For PyPy or other GCs.
  78. e1.set()
  79. e2.wait()
  80. # 4) New Locals should be empty
  81. passed = all(not hasattr(local, 'foo') for local in locals)
  82. t = threading.Thread(target=f)
  83. t.start()
  84. e1.wait()
  85. # 3) New Locals should recycle the original's address. Creating
  86. # them in the thread overwrites the thread state and avoids the
  87. # bug
  88. locals = [Local() for i in range(10)]
  89. e2.set()
  90. t.join()
  91. self.assertTrue(passed)
  92. def test_arguments(self):
  93. # Issue 1522237
  94. class MyLocal(self._local):
  95. def __init__(self, *args, **kwargs):
  96. pass
  97. MyLocal(a=1)
  98. MyLocal(1)
  99. self.assertRaises(TypeError, self._local, a=1)
  100. self.assertRaises(TypeError, self._local, 1)
  101. def _test_one_class(self, c):
  102. self._failed = "No error message set or cleared."
  103. obj = c()
  104. e1 = threading.Event()
  105. e2 = threading.Event()
  106. def f1():
  107. obj.x = 'foo'
  108. obj.y = 'bar'
  109. del obj.y
  110. e1.set()
  111. e2.wait()
  112. def f2():
  113. try:
  114. foo = obj.x
  115. except AttributeError:
  116. # This is expected -- we haven't set obj.x in this thread yet!
  117. self._failed = "" # passed
  118. else:
  119. self._failed = ('Incorrectly got value %r from class %r\n' %
  120. (foo, c))
  121. sys.stderr.write(self._failed)
  122. t1 = threading.Thread(target=f1)
  123. t1.start()
  124. e1.wait()
  125. t2 = threading.Thread(target=f2)
  126. t2.start()
  127. t2.join()
  128. # The test is done; just let t1 know it can exit, and wait for it.
  129. e2.set()
  130. t1.join()
  131. self.assertFalse(self._failed, self._failed)
  132. def test_threading_local(self):
  133. self._test_one_class(self._local)
  134. def test_threading_local_subclass(self):
  135. class LocalSubclass(self._local):
  136. """To test that subclasses behave properly."""
  137. self._test_one_class(LocalSubclass)
  138. def _test_dict_attribute(self, cls):
  139. obj = cls()
  140. obj.x = 5
  141. self.assertEqual(obj.__dict__, {'x': 5})
  142. with self.assertRaises(AttributeError):
  143. obj.__dict__ = {}
  144. with self.assertRaises(AttributeError):
  145. del obj.__dict__
  146. def test_dict_attribute(self):
  147. self._test_dict_attribute(self._local)
  148. def test_dict_attribute_subclass(self):
  149. class LocalSubclass(self._local):
  150. """To test that subclasses behave properly."""
  151. self._test_dict_attribute(LocalSubclass)
  152. def test_cycle_collection(self):
  153. class X:
  154. pass
  155. x = X()
  156. x.local = self._local()
  157. x.local.x = x
  158. wr = weakref.ref(x)
  159. del x
  160. support.gc_collect() # For PyPy or other GCs.
  161. self.assertIsNone(wr())
  162. def test_threading_local_clear_race(self):
  163. # See https://github.com/python/cpython/issues/100892
  164. _testcapi = import_module('_testcapi')
  165. _testcapi.call_in_temporary_c_thread(lambda: None, False)
  166. for _ in range(1000):
  167. _ = threading.local()
  168. _testcapi.join_temporary_c_thread()
  169. class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
  170. _local = _thread._local
  171. class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
  172. _local = _threading_local.local
  173. def load_tests(loader, tests, pattern):
  174. tests.addTest(DocTestSuite('_threading_local'))
  175. local_orig = _threading_local.local
  176. def setUp(test):
  177. _threading_local.local = _thread._local
  178. def tearDown(test):
  179. _threading_local.local = local_orig
  180. tests.addTests(DocTestSuite('_threading_local',
  181. setUp=setUp, tearDown=tearDown)
  182. )
  183. return tests
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()