test_functools.py 109 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066
  1. import abc
  2. import builtins
  3. import collections
  4. import collections.abc
  5. import copy
  6. from itertools import permutations
  7. import pickle
  8. from random import choice
  9. import sys
  10. from test import support
  11. import threading
  12. import time
  13. import typing
  14. import unittest
  15. import unittest.mock
  16. import os
  17. import weakref
  18. import gc
  19. from weakref import proxy
  20. import contextlib
  21. from test.support import import_helper
  22. from test.support import threading_helper
  23. from test.support.script_helper import assert_python_ok
  24. import functools
# functools imported with the C accelerator module ``_functools`` blocked;
# this is the copy the ``*Py`` test classes below exercise.
py_functools = import_helper.import_fresh_module('functools',
                                                 blocked=['_functools'])
# functools imported without blocking ``_functools``.  Later code guards on
# ``if c_functools:``, so this is presumably falsy when unavailable -- confirm
# against test.support.import_helper.
c_functools = import_helper.import_fresh_module('functools')
# Fresh import of decimal with ``_decimal`` listed as a fresh dependency.
decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
  29. @contextlib.contextmanager
  30. def replaced_module(name, replacement):
  31. original_module = sys.modules[name]
  32. sys.modules[name] = replacement
  33. try:
  34. yield
  35. finally:
  36. sys.modules[name] = original_module
  37. def capture(*args, **kw):
  38. """capture all positional and keyword arguments"""
  39. return args, kw
  40. def signature(part):
  41. """ return the signature of a partial object """
  42. return (part.func, part.args, part.keywords, part.__dict__)
  43. class MyTuple(tuple):
  44. pass
  45. class BadTuple(tuple):
  46. def __add__(self, other):
  47. return list(self) + list(other)
  48. class MyDict(dict):
  49. pass
  50. class TestPartial:
  51. def test_basic_examples(self):
  52. p = self.partial(capture, 1, 2, a=10, b=20)
  53. self.assertTrue(callable(p))
  54. self.assertEqual(p(3, 4, b=30, c=40),
  55. ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
  56. p = self.partial(map, lambda x: x*10)
  57. self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
  58. def test_attributes(self):
  59. p = self.partial(capture, 1, 2, a=10, b=20)
  60. # attributes should be readable
  61. self.assertEqual(p.func, capture)
  62. self.assertEqual(p.args, (1, 2))
  63. self.assertEqual(p.keywords, dict(a=10, b=20))
  64. def test_argument_checking(self):
  65. self.assertRaises(TypeError, self.partial) # need at least a func arg
  66. try:
  67. self.partial(2)()
  68. except TypeError:
  69. pass
  70. else:
  71. self.fail('First arg not checked for callability')
  72. def test_protection_of_callers_dict_argument(self):
  73. # a caller's dictionary should not be altered by partial
  74. def func(a=10, b=20):
  75. return a
  76. d = {'a':3}
  77. p = self.partial(func, a=5)
  78. self.assertEqual(p(**d), 3)
  79. self.assertEqual(d, {'a':3})
  80. p(b=7)
  81. self.assertEqual(d, {'a':3})
  82. def test_kwargs_copy(self):
  83. # Issue #29532: Altering a kwarg dictionary passed to a constructor
  84. # should not affect a partial object after creation
  85. d = {'a': 3}
  86. p = self.partial(capture, **d)
  87. self.assertEqual(p(), ((), {'a': 3}))
  88. d['a'] = 5
  89. self.assertEqual(p(), ((), {'a': 3}))
  90. def test_arg_combinations(self):
  91. # exercise special code paths for zero args in either partial
  92. # object or the caller
  93. p = self.partial(capture)
  94. self.assertEqual(p(), ((), {}))
  95. self.assertEqual(p(1,2), ((1,2), {}))
  96. p = self.partial(capture, 1, 2)
  97. self.assertEqual(p(), ((1,2), {}))
  98. self.assertEqual(p(3,4), ((1,2,3,4), {}))
  99. def test_kw_combinations(self):
  100. # exercise special code paths for no keyword args in
  101. # either the partial object or the caller
  102. p = self.partial(capture)
  103. self.assertEqual(p.keywords, {})
  104. self.assertEqual(p(), ((), {}))
  105. self.assertEqual(p(a=1), ((), {'a':1}))
  106. p = self.partial(capture, a=1)
  107. self.assertEqual(p.keywords, {'a':1})
  108. self.assertEqual(p(), ((), {'a':1}))
  109. self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
  110. # keyword args in the call override those in the partial object
  111. self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
  112. def test_positional(self):
  113. # make sure positional arguments are captured correctly
  114. for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
  115. p = self.partial(capture, *args)
  116. expected = args + ('x',)
  117. got, empty = p('x')
  118. self.assertTrue(expected == got and empty == {})
  119. def test_keyword(self):
  120. # make sure keyword arguments are captured correctly
  121. for a in ['a', 0, None, 3.5]:
  122. p = self.partial(capture, a=a)
  123. expected = {'a':a,'x':None}
  124. empty, got = p(x=None)
  125. self.assertTrue(expected == got and empty == ())
  126. def test_no_side_effects(self):
  127. # make sure there are no side effects that affect subsequent calls
  128. p = self.partial(capture, 0, a=1)
  129. args1, kw1 = p(1, b=2)
  130. self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
  131. args2, kw2 = p()
  132. self.assertTrue(args2 == (0,) and kw2 == {'a':1})
  133. def test_error_propagation(self):
  134. def f(x, y):
  135. x / y
  136. self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
  137. self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
  138. self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
  139. self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
  140. def test_weakref(self):
  141. f = self.partial(int, base=16)
  142. p = proxy(f)
  143. self.assertEqual(f.func, p.func)
  144. f = None
  145. support.gc_collect() # For PyPy or other GCs.
  146. self.assertRaises(ReferenceError, getattr, p, 'func')
  147. def test_with_bound_and_unbound_methods(self):
  148. data = list(map(str, range(10)))
  149. join = self.partial(str.join, '')
  150. self.assertEqual(join(data), '0123456789')
  151. join = self.partial(''.join)
  152. self.assertEqual(join(data), '0123456789')
  153. def test_nested_optimization(self):
  154. partial = self.partial
  155. inner = partial(signature, 'asdf')
  156. nested = partial(inner, bar=True)
  157. flat = partial(signature, 'asdf', bar=True)
  158. self.assertEqual(signature(nested), signature(flat))
  159. def test_nested_partial_with_attribute(self):
  160. # see issue 25137
  161. partial = self.partial
  162. def foo(bar):
  163. return bar
  164. p = partial(foo, 'first')
  165. p2 = partial(p, 'second')
  166. p2.new_attr = 'spam'
  167. self.assertEqual(p2.new_attr, 'spam')
  168. def test_repr(self):
  169. args = (object(), object())
  170. args_repr = ', '.join(repr(a) for a in args)
  171. kwargs = {'a': object(), 'b': object()}
  172. kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
  173. 'b={b!r}, a={a!r}'.format_map(kwargs)]
  174. if self.partial in (c_functools.partial, py_functools.partial):
  175. name = 'functools.partial'
  176. else:
  177. name = self.partial.__name__
  178. f = self.partial(capture)
  179. self.assertEqual(f'{name}({capture!r})', repr(f))
  180. f = self.partial(capture, *args)
  181. self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
  182. f = self.partial(capture, **kwargs)
  183. self.assertIn(repr(f),
  184. [f'{name}({capture!r}, {kwargs_repr})'
  185. for kwargs_repr in kwargs_reprs])
  186. f = self.partial(capture, *args, **kwargs)
  187. self.assertIn(repr(f),
  188. [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
  189. for kwargs_repr in kwargs_reprs])
  190. def test_recursive_repr(self):
  191. if self.partial in (c_functools.partial, py_functools.partial):
  192. name = 'functools.partial'
  193. else:
  194. name = self.partial.__name__
  195. f = self.partial(capture)
  196. f.__setstate__((f, (), {}, {}))
  197. try:
  198. self.assertEqual(repr(f), '%s(...)' % (name,))
  199. finally:
  200. f.__setstate__((capture, (), {}, {}))
  201. f = self.partial(capture)
  202. f.__setstate__((capture, (f,), {}, {}))
  203. try:
  204. self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
  205. finally:
  206. f.__setstate__((capture, (), {}, {}))
  207. f = self.partial(capture)
  208. f.__setstate__((capture, (), {'a': f}, {}))
  209. try:
  210. self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
  211. finally:
  212. f.__setstate__((capture, (), {}, {}))
  213. def test_pickle(self):
  214. with self.AllowPickle():
  215. f = self.partial(signature, ['asdf'], bar=[True])
  216. f.attr = []
  217. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  218. f_copy = pickle.loads(pickle.dumps(f, proto))
  219. self.assertEqual(signature(f_copy), signature(f))
  220. def test_copy(self):
  221. f = self.partial(signature, ['asdf'], bar=[True])
  222. f.attr = []
  223. f_copy = copy.copy(f)
  224. self.assertEqual(signature(f_copy), signature(f))
  225. self.assertIs(f_copy.attr, f.attr)
  226. self.assertIs(f_copy.args, f.args)
  227. self.assertIs(f_copy.keywords, f.keywords)
  228. def test_deepcopy(self):
  229. f = self.partial(signature, ['asdf'], bar=[True])
  230. f.attr = []
  231. f_copy = copy.deepcopy(f)
  232. self.assertEqual(signature(f_copy), signature(f))
  233. self.assertIsNot(f_copy.attr, f.attr)
  234. self.assertIsNot(f_copy.args, f.args)
  235. self.assertIsNot(f_copy.args[0], f.args[0])
  236. self.assertIsNot(f_copy.keywords, f.keywords)
  237. self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])
  238. def test_setstate(self):
  239. f = self.partial(signature)
  240. f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
  241. self.assertEqual(signature(f),
  242. (capture, (1,), dict(a=10), dict(attr=[])))
  243. self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
  244. f.__setstate__((capture, (1,), dict(a=10), None))
  245. self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
  246. self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
  247. f.__setstate__((capture, (1,), None, None))
  248. #self.assertEqual(signature(f), (capture, (1,), {}, {}))
  249. self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
  250. self.assertEqual(f(2), ((1, 2), {}))
  251. self.assertEqual(f(), ((1,), {}))
  252. f.__setstate__((capture, (), {}, None))
  253. self.assertEqual(signature(f), (capture, (), {}, {}))
  254. self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
  255. self.assertEqual(f(2), ((2,), {}))
  256. self.assertEqual(f(), ((), {}))
  257. def test_setstate_errors(self):
  258. f = self.partial(signature)
  259. self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
  260. self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
  261. self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
  262. self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
  263. self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
  264. self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
  265. self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))
  266. def test_setstate_subclasses(self):
  267. f = self.partial(signature)
  268. f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
  269. s = signature(f)
  270. self.assertEqual(s, (capture, (1,), dict(a=10), {}))
  271. self.assertIs(type(s[1]), tuple)
  272. self.assertIs(type(s[2]), dict)
  273. r = f()
  274. self.assertEqual(r, ((1,), {'a': 10}))
  275. self.assertIs(type(r[0]), tuple)
  276. self.assertIs(type(r[1]), dict)
  277. f.__setstate__((capture, BadTuple((1,)), {}, None))
  278. s = signature(f)
  279. self.assertEqual(s, (capture, (1,), {}, {}))
  280. self.assertIs(type(s[1]), tuple)
  281. r = f(2)
  282. self.assertEqual(r, ((1, 2), {}))
  283. self.assertIs(type(r[0]), tuple)
  284. def test_recursive_pickle(self):
  285. with self.AllowPickle():
  286. f = self.partial(capture)
  287. f.__setstate__((f, (), {}, {}))
  288. try:
  289. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  290. with self.assertRaises(RecursionError):
  291. pickle.dumps(f, proto)
  292. finally:
  293. f.__setstate__((capture, (), {}, {}))
  294. f = self.partial(capture)
  295. f.__setstate__((capture, (f,), {}, {}))
  296. try:
  297. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  298. f_copy = pickle.loads(pickle.dumps(f, proto))
  299. try:
  300. self.assertIs(f_copy.args[0], f_copy)
  301. finally:
  302. f_copy.__setstate__((capture, (), {}, {}))
  303. finally:
  304. f.__setstate__((capture, (), {}, {}))
  305. f = self.partial(capture)
  306. f.__setstate__((capture, (), {'a': f}, {}))
  307. try:
  308. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  309. f_copy = pickle.loads(pickle.dumps(f, proto))
  310. try:
  311. self.assertIs(f_copy.keywords['a'], f_copy)
  312. finally:
  313. f_copy.__setstate__((capture, (), {}, {}))
  314. finally:
  315. f.__setstate__((capture, (), {}, {}))
  316. # Issue 6083: Reference counting bug
  317. def test_setstate_refcount(self):
  318. class BadSequence:
  319. def __len__(self):
  320. return 4
  321. def __getitem__(self, key):
  322. if key == 0:
  323. return max
  324. elif key == 1:
  325. return tuple(range(1000000))
  326. elif key in (2, 3):
  327. return {}
  328. raise IndexError
  329. f = self.partial(object)
  330. self.assertRaises(TypeError, f.__setstate__, BadSequence())
  331. @unittest.skipUnless(c_functools, 'requires the C _functools module')
  332. class TestPartialC(TestPartial, unittest.TestCase):
  333. if c_functools:
  334. partial = c_functools.partial
  335. class AllowPickle:
  336. def __enter__(self):
  337. return self
  338. def __exit__(self, type, value, tb):
  339. return False
  340. def test_attributes_unwritable(self):
  341. # attributes should not be writable
  342. p = self.partial(capture, 1, 2, a=10, b=20)
  343. self.assertRaises(AttributeError, setattr, p, 'func', map)
  344. self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
  345. self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
  346. p = self.partial(hex)
  347. try:
  348. del p.__dict__
  349. except TypeError:
  350. pass
  351. else:
  352. self.fail('partial object allowed __dict__ to be deleted')
  353. def test_manually_adding_non_string_keyword(self):
  354. p = self.partial(capture)
  355. # Adding a non-string/unicode keyword to partial kwargs
  356. p.keywords[1234] = 'value'
  357. r = repr(p)
  358. self.assertIn('1234', r)
  359. self.assertIn("'value'", r)
  360. with self.assertRaises(TypeError):
  361. p()
  362. def test_keystr_replaces_value(self):
  363. p = self.partial(capture)
  364. class MutatesYourDict(object):
  365. def __str__(self):
  366. p.keywords[self] = ['sth2']
  367. return 'astr'
  368. # Replacing the value during key formatting should keep the original
  369. # value alive (at least long enough).
  370. p.keywords[MutatesYourDict()] = ['sth']
  371. r = repr(p)
  372. self.assertIn('astr', r)
  373. self.assertIn("['sth']", r)
  374. class TestPartialPy(TestPartial, unittest.TestCase):
  375. partial = py_functools.partial
  376. class AllowPickle:
  377. def __init__(self):
  378. self._cm = replaced_module("functools", py_functools)
  379. def __enter__(self):
  380. return self._cm.__enter__()
  381. def __exit__(self, type, value, tb):
  382. return self._cm.__exit__(type, value, tb)
# Trivial subclass of the C partial, defined only when c_functools is truthy;
# used as the ``partial`` under test by TestPartialCSubclass.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        pass
# Trivial subclass of py_functools.partial; used as the ``partial`` under test
# by TestPartialPySubclass.
class PyPartialSubclass(py_functools.partial):
    pass
# Re-runs every TestPartialC test with CPartialSubclass as the callable.
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    # (rebinding the inherited name to None removes that test from this class)
    test_nested_optimization = None
# Re-runs every TestPartialPy test with PyPartialSubclass as the callable.
class TestPartialPySubclass(TestPartialPy):
    partial = PyPartialSubclass
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` using the fixture class ``A``."""

    class A(object):
        # Fixture: every attribute is a partialmethod around capture(), so a
        # call returns exactly the (args, kwargs) the underlying callable saw.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # keyword names that coincide with partialmethod's own parameter names
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        # partialmethod layered on another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod layered on a plain functools.partial
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    # Shared instance used by the tests below.
    a = A()

    def test_arg_combinations(self):
        # The instance is always passed first, then the stored positionals,
        # then the call's positionals; stored and call keywords are merged.
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        # Access through the class, with the instance passed explicitly.
        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        # Bound results expose __self__: the instance for ordinary methods,
        # the class for the classmethod-based one.
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        # Retrieved from the class (or wrapping a staticmethod), the result
        # has no __self__.
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # static/cls behave the same whether reached via the class or an
        # instance.
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        # A keyword given at call time replaces the stored one.
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # Each class body below must fail with TypeError while it is being
        # executed: non-callable func, no func, func passed by keyword.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        # vars() fetches the raw partialmethod object, bypassing __get__.
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # NOTE(review): Abstract subclasses abc.ABCMeta rather than using it
        # as a metaclass; the assertions only read __isabstractmethod__.
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        # None of the fixture's partialmethods wraps an abstract callable.
        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        # Exercises functools.partial (not partialmethod) with a function
        # that takes positional-only parameters.
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
  class TestUpdateWrapper(unittest.TestCase):
      """Tests for functools.update_wrapper()."""

      def check_wrapper(self, wrapper, wrapped,
                        assigned=functools.WRAPPER_ASSIGNMENTS,
                        updated=functools.WRAPPER_UPDATES):
          """Assert that *wrapper* carries the expected attributes of *wrapped*.

          Every name in *assigned* must be the identical object on both
          functions, every key of each mapping named in *updated* must map
          to the identical object on both, and ``wrapper.__wrapped__`` must
          be *wrapped* itself.
          """
          # Attributes that are copied by assignment.
          for attr_name in assigned:
              self.assertIs(getattr(wrapper, attr_name),
                            getattr(wrapped, attr_name))
          # Mapping attributes that are merged into the wrapper.
          for attr_name in updated:
              target = getattr(wrapper, attr_name)
              source = getattr(wrapped, attr_name)
              for key in source:
                  # __wrapped__ is overwritten by the update code
                  if (attr_name, key) == ("__dict__", "__wrapped__"):
                      continue
                  self.assertIs(source[key], target[key])
          # The wrapper must point back at the wrapped function.
          self.assertIs(wrapper.__wrapped__, wrapped)

      def _default_update(self):
          # Build a (wrapper, wrapped) pair using update_wrapper()'s default
          # attribute lists.  The bogus __wrapped__ set on `f` must not
          # survive on the wrapper.
          def f(a:'This is a new annotation'):
              """This is a test"""
              pass
          f.attr = 'This is also a test'
          f.__wrapped__ = "This is a bald faced lie"

          def wrapper(b:'This is the prior annotation'):
              pass

          functools.update_wrapper(wrapper, f)
          return wrapper, f

      def test_default_update(self):
          decorated, original = self._default_update()
          self.check_wrapper(decorated, original)
          self.assertIs(decorated.__wrapped__, original)
          self.assertEqual(decorated.__name__, 'f')
          self.assertEqual(decorated.__qualname__, original.__qualname__)
          self.assertEqual(decorated.attr, 'This is also a test')
          annotations = decorated.__annotations__
          self.assertEqual(annotations['a'], 'This is a new annotation')
          self.assertNotIn('b', annotations)

      @unittest.skipIf(sys.flags.optimize >= 2,
                       "Docstrings are omitted with -O2 and above")
      def test_default_update_doc(self):
          decorated, _ = self._default_update()
          self.assertEqual(decorated.__doc__, 'This is a test')

      def test_no_update(self):
          # With empty attribute lists nothing is copied from `f`.
          def f():
              """This is a test"""
              pass
          f.attr = 'This is also a test'

          def wrapper():
              pass

          functools.update_wrapper(wrapper, f, (), ())
          self.check_wrapper(wrapper, f, (), ())
          self.assertEqual(wrapper.__name__, 'wrapper')
          self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
          self.assertIsNone(wrapper.__doc__)
          self.assertEqual(wrapper.__annotations__, {})
          self.assertFalse(hasattr(wrapper, 'attr'))

      def test_selective_update(self):
          # Only the explicitly listed attributes are assigned or merged.
          def f():
              pass
          f.attr = 'This is a different test'
          f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

          def wrapper():
              pass
          wrapper.dict_attr = {}

          assign = ('attr',)
          update = ('dict_attr',)
          functools.update_wrapper(wrapper, f, assign, update)
          self.check_wrapper(wrapper, f, assign, update)
          self.assertEqual(wrapper.__name__, 'wrapper')
          self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
          self.assertIsNone(wrapper.__doc__)
          self.assertEqual(wrapper.attr, 'This is a different test')
          self.assertEqual(wrapper.dict_attr, f.dict_attr)

      def test_missing_attributes(self):
          def f():
              pass

          def wrapper():
              pass
          wrapper.dict_attr = {}

          assign = ('attr',)
          update = ('dict_attr',)
          # Missing attributes on wrapped object are ignored
          functools.update_wrapper(wrapper, f, assign, update)
          self.assertNotIn('attr', wrapper.__dict__)
          self.assertEqual(wrapper.dict_attr, {})
          # Wrapper must have expected attributes for updating
          del wrapper.dict_attr
          with self.assertRaises(AttributeError):
              functools.update_wrapper(wrapper, f, assign, update)
          # An attribute without an update() method is rejected as well.
          wrapper.dict_attr = 1
          with self.assertRaises(AttributeError):
              functools.update_wrapper(wrapper, f, assign, update)

      @support.requires_docstrings
      @unittest.skipIf(sys.flags.optimize >= 2,
                       "Docstrings are omitted with -O2 and above")
      def test_builtin_update(self):
          # Test for bug #1576241
          def wrapper():
              pass

          functools.update_wrapper(wrapper, max)
          self.assertEqual(wrapper.__name__, 'max')
          self.assertTrue(wrapper.__doc__.startswith('max('))
          self.assertEqual(wrapper.__annotations__, {})
  class TestWraps(TestUpdateWrapper):
      """Re-run the update_wrapper() scenarios through the wraps() decorator."""

      def _default_update(self):
          # Same contract as the base-class helper, but the wrapper is
          # produced by decorating with functools.wraps(f).
          def f():
              """This is a test"""
              pass
          f.attr = 'This is also a test'
          f.__wrapped__ = "This is still a bald faced lie"

          @functools.wraps(f)
          def wrapper():
              pass

          return wrapper, f

      def test_default_update(self):
          decorated, original = self._default_update()
          self.check_wrapper(decorated, original)
          self.assertEqual(decorated.__name__, 'f')
          self.assertEqual(decorated.__qualname__, original.__qualname__)
          self.assertEqual(decorated.attr, 'This is also a test')

      @unittest.skipIf(sys.flags.optimize >= 2,
                       "Docstrings are omitted with -O2 and above")
      def test_default_update_doc(self):
          decorated, _ = self._default_update()
          self.assertEqual(decorated.__doc__, 'This is a test')

      def test_no_update(self):
          # With empty attribute lists nothing is copied from `f`.
          def f():
              """This is a test"""
              pass
          f.attr = 'This is also a test'

          @functools.wraps(f, (), ())
          def wrapper():
              pass

          self.check_wrapper(wrapper, f, (), ())
          self.assertEqual(wrapper.__name__, 'wrapper')
          self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
          self.assertIsNone(wrapper.__doc__)
          self.assertFalse(hasattr(wrapper, 'attr'))

      def test_selective_update(self):
          def f():
              pass
          f.attr = 'This is a different test'
          f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

          # Gives the wrapper an empty dict_attr before wraps() merges
          # into it.
          def add_dict_attr(func):
              func.dict_attr = {}
              return func

          assign = ('attr',)
          update = ('dict_attr',)

          @functools.wraps(f, assign, update)
          @add_dict_attr
          def wrapper():
              pass

          self.check_wrapper(wrapper, f, assign, update)
          self.assertEqual(wrapper.__name__, 'wrapper')
          self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
          self.assertIsNone(wrapper.__doc__)
          self.assertEqual(wrapper.attr, 'This is a different test')
          self.assertEqual(wrapper.dict_attr, f.dict_attr)
  653. class TestReduce:
  654. def test_reduce(self):
  655. class Squares:
  656. def __init__(self, max):
  657. self.max = max
  658. self.sofar = []
  659. def __len__(self):
  660. return len(self.sofar)
  661. def __getitem__(self, i):
  662. if not 0 <= i < self.max: raise IndexError
  663. n = len(self.sofar)
  664. while n <= i:
  665. self.sofar.append(n*n)
  666. n += 1
  667. return self.sofar[i]
  668. def add(x, y):
  669. return x + y
  670. self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
  671. self.assertEqual(
  672. self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
  673. ['a','c','d','w']
  674. )
  675. self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
  676. self.assertEqual(
  677. self.reduce(lambda x, y: x*y, range(2,21), 1),
  678. 2432902008176640000
  679. )
  680. self.assertEqual(self.reduce(add, Squares(10)), 285)
  681. self.assertEqual(self.reduce(add, Squares(10), 0), 285)
  682. self.assertEqual(self.reduce(add, Squares(0), 0), 0)
  683. self.assertRaises(TypeError, self.reduce)
  684. self.assertRaises(TypeError, self.reduce, 42, 42)
  685. self.assertRaises(TypeError, self.reduce, 42, 42, 42)
  686. self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
  687. self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
  688. self.assertRaises(TypeError, self.reduce, 42, (42, 42))
  689. self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
  690. self.assertRaises(TypeError, self.reduce, add, "")
  691. self.assertRaises(TypeError, self.reduce, add, ())
  692. self.assertRaises(TypeError, self.reduce, add, object())
  693. class TestFailingIter:
  694. def __iter__(self):
  695. raise RuntimeError
  696. self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
  697. self.assertEqual(self.reduce(add, [], None), None)
  698. self.assertEqual(self.reduce(add, [], 42), 42)
  699. class BadSeq:
  700. def __getitem__(self, index):
  701. raise ValueError
  702. self.assertRaises(ValueError, self.reduce, 42, BadSeq())
  703. # Test reduce()'s use of iterators.
  704. def test_iterator_usage(self):
  705. class SequenceClass:
  706. def __init__(self, n):
  707. self.n = n
  708. def __getitem__(self, i):
  709. if 0 <= i < self.n:
  710. return i
  711. else:
  712. raise IndexError
  713. from operator import add
  714. self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
  715. self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
  716. self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
  717. self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
  718. self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
  719. self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
  720. d = {"one": 1, "two": 2, "three": 3}
  721. self.assertEqual(self.reduce(add, d), "".join(d.keys()))
  @unittest.skipUnless(c_functools, 'requires the C _functools module')
  class TestReduceC(TestReduce, unittest.TestCase):
      """Run the shared reduce() tests against ``c_functools.reduce``."""

      # The class body still executes when the class is skipped, so only
      # touch c_functools when it is truthy.
      if c_functools:
          reduce = c_functools.reduce
  class TestReducePy(TestReduce, unittest.TestCase):
      """Run the shared reduce() tests against ``py_functools.reduce``."""

      # staticmethod() keeps the function from being bound to the test
      # instance when it is looked up as self.reduce.
      reduce = staticmethod(py_functools.reduce)
  728. class TestCmpToKey:
  729. def test_cmp_to_key(self):
  730. def cmp1(x, y):
  731. return (x > y) - (x < y)
  732. key = self.cmp_to_key(cmp1)
  733. self.assertEqual(key(3), key(3))
  734. self.assertGreater(key(3), key(1))
  735. self.assertGreaterEqual(key(3), key(3))
  736. def cmp2(x, y):
  737. return int(x) - int(y)
  738. key = self.cmp_to_key(cmp2)
  739. self.assertEqual(key(4.0), key('4'))
  740. self.assertLess(key(2), key('35'))
  741. self.assertLessEqual(key(2), key('35'))
  742. self.assertNotEqual(key(2), key('35'))
  743. def test_cmp_to_key_arguments(self):
  744. def cmp1(x, y):
  745. return (x > y) - (x < y)
  746. key = self.cmp_to_key(mycmp=cmp1)
  747. self.assertEqual(key(obj=3), key(obj=3))
  748. self.assertGreater(key(obj=3), key(obj=1))
  749. with self.assertRaises((TypeError, AttributeError)):
  750. key(3) > 1 # rhs is not a K object
  751. with self.assertRaises((TypeError, AttributeError)):
  752. 1 < key(3) # lhs is not a K object
  753. with self.assertRaises(TypeError):
  754. key = self.cmp_to_key() # too few args
  755. with self.assertRaises(TypeError):
  756. key = self.cmp_to_key(cmp1, None) # too many args
  757. key = self.cmp_to_key(cmp1)
  758. with self.assertRaises(TypeError):
  759. key() # too few args
  760. with self.assertRaises(TypeError):
  761. key(None, None) # too many args
  762. def test_bad_cmp(self):
  763. def cmp1(x, y):
  764. raise ZeroDivisionError
  765. key = self.cmp_to_key(cmp1)
  766. with self.assertRaises(ZeroDivisionError):
  767. key(3) > key(1)
  768. class BadCmp:
  769. def __lt__(self, other):
  770. raise ZeroDivisionError
  771. def cmp1(x, y):
  772. return BadCmp()
  773. with self.assertRaises(ZeroDivisionError):
  774. key(3) > key(1)
  775. def test_obj_field(self):
  776. def cmp1(x, y):
  777. return (x > y) - (x < y)
  778. key = self.cmp_to_key(mycmp=cmp1)
  779. self.assertEqual(key(50).obj, 50)
  780. def test_sort_int(self):
  781. def mycmp(x, y):
  782. return y - x
  783. self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
  784. [4, 3, 2, 1, 0])
  785. def test_sort_int_str(self):
  786. def mycmp(x, y):
  787. x, y = int(x), int(y)
  788. return (x > y) - (x < y)
  789. values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
  790. values = sorted(values, key=self.cmp_to_key(mycmp))
  791. self.assertEqual([int(value) for value in values],
  792. [0, 1, 1, 2, 3, 4, 5, 7, 10])
  793. def test_hash(self):
  794. def mycmp(x, y):
  795. return y - x
  796. key = self.cmp_to_key(mycmp)
  797. k = key(10)
  798. self.assertRaises(TypeError, hash, k)
  799. self.assertNotIsInstance(k, collections.abc.Hashable)
  @unittest.skipUnless(c_functools, 'requires the C _functools module')
  class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
      """Run the shared cmp_to_key() tests against ``c_functools.cmp_to_key``."""

      # The class body still executes when the class is skipped, so only
      # touch c_functools when it is truthy.
      if c_functools:
          cmp_to_key = c_functools.cmp_to_key

      @support.cpython_only
      def test_disallow_instantiation(self):
          # Ensure that the type disallows instantiation (bpo-43916)
          key_type = type(c_functools.cmp_to_key(None))
          support.check_disallow_instantiation(self, key_type)
  class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
      """Run the shared cmp_to_key() tests against ``py_functools.cmp_to_key``."""

      # staticmethod() keeps the function from being bound to the test
      # instance when it is looked up as self.cmp_to_key.
      cmp_to_key = staticmethod(py_functools.cmp_to_key)
  class TestTotalOrdering(unittest.TestCase):
      """Tests for the functools.total_ordering() class decorator."""

      def _check_basic_ordering(self, cls):
          # The six comparisons every basic test expects to hold for a
          # class whose instances wrap an integer value.
          self.assertTrue(cls(1) < cls(2))
          self.assertTrue(cls(2) > cls(1))
          self.assertTrue(cls(1) <= cls(2))
          self.assertTrue(cls(2) >= cls(1))
          self.assertTrue(cls(2) <= cls(2))
          self.assertTrue(cls(2) >= cls(2))

      def test_total_ordering_lt(self):
          @functools.total_ordering
          class A:
              def __init__(self, value):
                  self.value = value
              def __lt__(self, other):
                  return self.value < other.value
              def __eq__(self, other):
                  return self.value == other.value
          self._check_basic_ordering(A)
          self.assertFalse(A(1) > A(2))

      def test_total_ordering_le(self):
          @functools.total_ordering
          class A:
              def __init__(self, value):
                  self.value = value
              def __le__(self, other):
                  return self.value <= other.value
              def __eq__(self, other):
                  return self.value == other.value
          self._check_basic_ordering(A)
          self.assertFalse(A(1) >= A(2))

      def test_total_ordering_gt(self):
          @functools.total_ordering
          class A:
              def __init__(self, value):
                  self.value = value
              def __gt__(self, other):
                  return self.value > other.value
              def __eq__(self, other):
                  return self.value == other.value
          self._check_basic_ordering(A)
          self.assertFalse(A(2) < A(1))

      def test_total_ordering_ge(self):
          @functools.total_ordering
          class A:
              def __init__(self, value):
                  self.value = value
              def __ge__(self, other):
                  return self.value >= other.value
              def __eq__(self, other):
                  return self.value == other.value
          self._check_basic_ordering(A)
          self.assertFalse(A(2) <= A(1))

      def test_total_ordering_no_overwrite(self):
          # new methods should not overwrite existing
          @functools.total_ordering
          class A(int):
              pass
          self._check_basic_ordering(A)

      def test_no_operations_defined(self):
          # A class without any ordering method cannot be completed.
          with self.assertRaises(ValueError):
              @functools.total_ordering
              class A:
                  pass

      def test_notimplemented(self):
          # Verify NotImplemented results are correctly handled
          @functools.total_ordering
          class ImplementsLessThan:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsLessThan):
                      return self.value == other.value
                  return False
              def __lt__(self, other):
                  if isinstance(other, ImplementsLessThan):
                      return self.value < other.value
                  return NotImplemented

          @functools.total_ordering
          class ImplementsLessThanEqualTo:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsLessThanEqualTo):
                      return self.value == other.value
                  return False
              def __le__(self, other):
                  if isinstance(other, ImplementsLessThanEqualTo):
                      return self.value <= other.value
                  return NotImplemented

          @functools.total_ordering
          class ImplementsGreaterThan:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsGreaterThan):
                      return self.value == other.value
                  return False
              def __gt__(self, other):
                  if isinstance(other, ImplementsGreaterThan):
                      return self.value > other.value
                  return NotImplemented

          @functools.total_ordering
          class ImplementsGreaterThanEqualTo:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsGreaterThanEqualTo):
                      return self.value == other.value
                  return False
              def __ge__(self, other):
                  if isinstance(other, ImplementsGreaterThanEqualTo):
                      return self.value >= other.value
                  return NotImplemented

          # For each class, the three methods the decorator derived must
          # pass NotImplemented through when compared with a plain int.
          derived_methods = [
              (ImplementsLessThan, ('__le__', '__gt__', '__ge__')),
              (ImplementsLessThanEqualTo, ('__lt__', '__gt__', '__ge__')),
              (ImplementsGreaterThan, ('__lt__', '__le__', '__ge__')),
              (ImplementsGreaterThanEqualTo, ('__lt__', '__le__', '__gt__')),
          ]
          for cls, method_names in derived_methods:
              for method_name in method_names:
                  result = getattr(cls(1), method_name)(1)
                  self.assertIs(result, NotImplemented)

      def test_type_error_when_not_implemented(self):
          # bug 10042; ensure stack overflow does not occur
          # when decorated types return NotImplemented
          @functools.total_ordering
          class ImplementsLessThan:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsLessThan):
                      return self.value == other.value
                  return False
              def __lt__(self, other):
                  if isinstance(other, ImplementsLessThan):
                      return self.value < other.value
                  return NotImplemented

          @functools.total_ordering
          class ImplementsGreaterThan:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsGreaterThan):
                      return self.value == other.value
                  return False
              def __gt__(self, other):
                  if isinstance(other, ImplementsGreaterThan):
                      return self.value > other.value
                  return NotImplemented

          @functools.total_ordering
          class ImplementsLessThanEqualTo:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsLessThanEqualTo):
                      return self.value == other.value
                  return False
              def __le__(self, other):
                  if isinstance(other, ImplementsLessThanEqualTo):
                      return self.value <= other.value
                  return NotImplemented

          @functools.total_ordering
          class ImplementsGreaterThanEqualTo:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ImplementsGreaterThanEqualTo):
                      return self.value == other.value
                  return False
              def __ge__(self, other):
                  if isinstance(other, ImplementsGreaterThanEqualTo):
                      return self.value >= other.value
                  return NotImplemented

          @functools.total_ordering
          class ComparatorNotImplemented:
              def __init__(self, value):
                  self.value = value
              def __eq__(self, other):
                  if isinstance(other, ComparatorNotImplemented):
                      return self.value == other.value
                  return False
              def __lt__(self, other):
                  return NotImplemented

          with self.subTest("LT < 1"), self.assertRaises(TypeError):
              ImplementsLessThan(-1) < 1

          with self.subTest("LT < LE"), self.assertRaises(TypeError):
              ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

          with self.subTest("LT < GT"), self.assertRaises(TypeError):
              ImplementsLessThan(1) < ImplementsGreaterThan(1)

          with self.subTest("LE <= LT"), self.assertRaises(TypeError):
              ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

          with self.subTest("LE <= GE"), self.assertRaises(TypeError):
              ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

          with self.subTest("GT > GE"), self.assertRaises(TypeError):
              ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

          with self.subTest("GT > LT"), self.assertRaises(TypeError):
              ImplementsGreaterThan(5) > ImplementsLessThan(5)

          with self.subTest("GE >= GT"), self.assertRaises(TypeError):
              ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

          with self.subTest("GE >= LE"), self.assertRaises(TypeError):
              ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

          # Equality alone must not make the derived operators succeed.
          with self.subTest("GE when equal"):
              a = ComparatorNotImplemented(8)
              b = ComparatorNotImplemented(8)
              self.assertEqual(a, b)
              with self.assertRaises(TypeError):
                  a >= b

          with self.subTest("LE when equal"):
              a = ComparatorNotImplemented(9)
              b = ComparatorNotImplemented(9)
              self.assertEqual(a, b)
              with self.assertRaises(TypeError):
                  a <= b

      def test_pickle(self):
          # Comparison methods of a decorated module-level class must
          # round-trip through pickle to the identical object.
          for proto in range(pickle.HIGHEST_PROTOCOL + 1):
              for name in '__lt__', '__gt__', '__le__', '__ge__':
                  with self.subTest(method=name, proto=proto):
                      method = getattr(Orderable_LT, name)
                      method_copy = pickle.loads(pickle.dumps(method, proto))
                      self.assertIs(method_copy, method)

      def test_total_ordering_for_metaclasses_issue_44605(self):
          @functools.total_ordering
          class SortableMeta(type):
              def __new__(cls, name, bases, ns):
                  return super().__new__(cls, name, bases, ns)

              def __lt__(self, other):
                  # Defer to the other operand for foreign types instead of
                  # falling through to the __name__ comparison.
                  if not isinstance(other, SortableMeta):
                      return NotImplemented
                  return self.__name__ < other.__name__

              def __eq__(self, other):
                  if not isinstance(other, SortableMeta):
                      return NotImplemented
                  return self.__name__ == other.__name__

          class B(metaclass=SortableMeta):
              pass

          class A(metaclass=SortableMeta):
              pass

          # The classes themselves are ordered by their names.
          self.assertTrue(A < B)
          self.assertFalse(A > B)
  @functools.total_ordering
  class Orderable_LT:
      """Value wrapper that defines only ``__lt__`` and ``__eq__`` itself.

      The remaining ordering methods come from functools.total_ordering.
      """

      def __init__(self, value):
          self.value = value

      def __lt__(self, other):
          mine, theirs = self.value, other.value
          return mine < theirs

      def __eq__(self, other):
          mine, theirs = self.value, other.value
          return mine == theirs
  1080. class TestCache:
  1081. # This tests that the pass-through is working as designed.
  1082. # The underlying functionality is tested in TestLRU.
  1083. def test_cache(self):
  1084. @self.module.cache
  1085. def fib(n):
  1086. if n < 2:
  1087. return n
  1088. return fib(n-1) + fib(n-2)
  1089. self.assertEqual([fib(n) for n in range(16)],
  1090. [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
  1091. self.assertEqual(fib.cache_info(),
  1092. self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
  1093. fib.cache_clear()
  1094. self.assertEqual(fib.cache_info(),
  1095. self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
  1096. class TestLRU:
  1097. def test_lru(self):
  1098. def orig(x, y):
  1099. return 3 * x + y
  1100. f = self.module.lru_cache(maxsize=20)(orig)
  1101. hits, misses, maxsize, currsize = f.cache_info()
  1102. self.assertEqual(maxsize, 20)
  1103. self.assertEqual(currsize, 0)
  1104. self.assertEqual(hits, 0)
  1105. self.assertEqual(misses, 0)
  1106. domain = range(5)
  1107. for i in range(1000):
  1108. x, y = choice(domain), choice(domain)
  1109. actual = f(x, y)
  1110. expected = orig(x, y)
  1111. self.assertEqual(actual, expected)
  1112. hits, misses, maxsize, currsize = f.cache_info()
  1113. self.assertTrue(hits > misses)
  1114. self.assertEqual(hits + misses, 1000)
  1115. self.assertEqual(currsize, 20)
  1116. f.cache_clear() # test clearing
  1117. hits, misses, maxsize, currsize = f.cache_info()
  1118. self.assertEqual(hits, 0)
  1119. self.assertEqual(misses, 0)
  1120. self.assertEqual(currsize, 0)
  1121. f(x, y)
  1122. hits, misses, maxsize, currsize = f.cache_info()
  1123. self.assertEqual(hits, 0)
  1124. self.assertEqual(misses, 1)
  1125. self.assertEqual(currsize, 1)
  1126. # Test bypassing the cache
  1127. self.assertIs(f.__wrapped__, orig)
  1128. f.__wrapped__(x, y)
  1129. hits, misses, maxsize, currsize = f.cache_info()
  1130. self.assertEqual(hits, 0)
  1131. self.assertEqual(misses, 1)
  1132. self.assertEqual(currsize, 1)
  1133. # test size zero (which means "never-cache")
  1134. @self.module.lru_cache(0)
  1135. def f():
  1136. nonlocal f_cnt
  1137. f_cnt += 1
  1138. return 20
  1139. self.assertEqual(f.cache_info().maxsize, 0)
  1140. f_cnt = 0
  1141. for i in range(5):
  1142. self.assertEqual(f(), 20)
  1143. self.assertEqual(f_cnt, 5)
  1144. hits, misses, maxsize, currsize = f.cache_info()
  1145. self.assertEqual(hits, 0)
  1146. self.assertEqual(misses, 5)
  1147. self.assertEqual(currsize, 0)
  1148. # test size one
  1149. @self.module.lru_cache(1)
  1150. def f():
  1151. nonlocal f_cnt
  1152. f_cnt += 1
  1153. return 20
  1154. self.assertEqual(f.cache_info().maxsize, 1)
  1155. f_cnt = 0
  1156. for i in range(5):
  1157. self.assertEqual(f(), 20)
  1158. self.assertEqual(f_cnt, 1)
  1159. hits, misses, maxsize, currsize = f.cache_info()
  1160. self.assertEqual(hits, 4)
  1161. self.assertEqual(misses, 1)
  1162. self.assertEqual(currsize, 1)
  1163. # test size two
  1164. @self.module.lru_cache(2)
  1165. def f(x):
  1166. nonlocal f_cnt
  1167. f_cnt += 1
  1168. return x*10
  1169. self.assertEqual(f.cache_info().maxsize, 2)
  1170. f_cnt = 0
  1171. for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
  1172. # * * * *
  1173. self.assertEqual(f(x), x*10)
  1174. self.assertEqual(f_cnt, 4)
  1175. hits, misses, maxsize, currsize = f.cache_info()
  1176. self.assertEqual(hits, 12)
  1177. self.assertEqual(misses, 4)
  1178. self.assertEqual(currsize, 2)
  1179. def test_lru_no_args(self):
  1180. @self.module.lru_cache
  1181. def square(x):
  1182. return x ** 2
  1183. self.assertEqual(list(map(square, [10, 20, 10])),
  1184. [100, 400, 100])
  1185. self.assertEqual(square.cache_info().hits, 1)
  1186. self.assertEqual(square.cache_info().misses, 2)
  1187. self.assertEqual(square.cache_info().maxsize, 128)
  1188. self.assertEqual(square.cache_info().currsize, 2)
  1189. def test_lru_bug_35780(self):
  1190. # C version of the lru_cache was not checking to see if
  1191. # the user function call has already modified the cache
  1192. # (this arises in recursive calls and in multi-threading).
  1193. # This cause the cache to have orphan links not referenced
  1194. # by the cache dictionary.
  1195. once = True # Modified by f(x) below
  1196. @self.module.lru_cache(maxsize=10)
  1197. def f(x):
  1198. nonlocal once
  1199. rv = f'.{x}.'
  1200. if x == 20 and once:
  1201. once = False
  1202. rv = f(x)
  1203. return rv
  1204. # Fill the cache
  1205. for x in range(15):
  1206. self.assertEqual(f(x), f'.{x}.')
  1207. self.assertEqual(f.cache_info().currsize, 10)
  1208. # Make a recursive call and make sure the cache remains full
  1209. self.assertEqual(f(20), '.20.')
  1210. self.assertEqual(f.cache_info().currsize, 10)
  1211. def test_lru_bug_36650(self):
  1212. # C version of lru_cache was treating a call with an empty **kwargs
  1213. # dictionary as being distinct from a call with no keywords at all.
  1214. # This did not result in an incorrect answer, but it did trigger
  1215. # an unexpected cache miss.
  1216. @self.module.lru_cache()
  1217. def f(x):
  1218. pass
  1219. f(0)
  1220. f(0, **{})
  1221. self.assertEqual(f.cache_info().hits, 1)
  1222. def test_lru_hash_only_once(self):
  1223. # To protect against weird reentrancy bugs and to improve
  1224. # efficiency when faced with slow __hash__ methods, the
  1225. # LRU cache guarantees that it will only call __hash__
  1226. # only once per use as an argument to the cached function.
  1227. @self.module.lru_cache(maxsize=1)
  1228. def f(x, y):
  1229. return x * 3 + y
  1230. # Simulate the integer 5
  1231. mock_int = unittest.mock.Mock()
  1232. mock_int.__mul__ = unittest.mock.Mock(return_value=15)
  1233. mock_int.__hash__ = unittest.mock.Mock(return_value=999)
  1234. # Add to cache: One use as an argument gives one call
  1235. self.assertEqual(f(mock_int, 1), 16)
  1236. self.assertEqual(mock_int.__hash__.call_count, 1)
  1237. self.assertEqual(f.cache_info(), (0, 1, 1, 1))
  1238. # Cache hit: One use as an argument gives one additional call
  1239. self.assertEqual(f(mock_int, 1), 16)
  1240. self.assertEqual(mock_int.__hash__.call_count, 2)
  1241. self.assertEqual(f.cache_info(), (1, 1, 1, 1))
  1242. # Cache eviction: No use as an argument gives no additional call
  1243. self.assertEqual(f(6, 2), 20)
  1244. self.assertEqual(mock_int.__hash__.call_count, 2)
  1245. self.assertEqual(f.cache_info(), (1, 2, 1, 1))
  1246. # Cache miss: One use as an argument gives one additional call
  1247. self.assertEqual(f(mock_int, 1), 16)
  1248. self.assertEqual(mock_int.__hash__.call_count, 3)
  1249. self.assertEqual(f.cache_info(), (1, 3, 1, 1))
  1250. def test_lru_reentrancy_with_len(self):
  1251. # Test to make sure the LRU cache code isn't thrown-off by
  1252. # caching the built-in len() function. Since len() can be
  1253. # cached, we shouldn't use it inside the lru code itself.
  1254. old_len = builtins.len
  1255. try:
  1256. builtins.len = self.module.lru_cache(4)(len)
  1257. for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
  1258. self.assertEqual(len('abcdefghijklmn'[:i]), i)
  1259. finally:
  1260. builtins.len = old_len
  1261. def test_lru_star_arg_handling(self):
  1262. # Test regression that arose in ea064ff3c10f
  1263. @self.module.lru_cache()
  1264. def f(*args):
  1265. return args
  1266. self.assertEqual(f(1, 2), (1, 2))
  1267. self.assertEqual(f((1, 2)), ((1, 2),))
  1268. def test_lru_type_error(self):
  1269. # Regression test for issue #28653.
  1270. # lru_cache was leaking when one of the arguments
  1271. # wasn't cacheable.
  1272. @self.module.lru_cache(maxsize=None)
  1273. def infinite_cache(o):
  1274. pass
  1275. @self.module.lru_cache(maxsize=10)
  1276. def limited_cache(o):
  1277. pass
  1278. with self.assertRaises(TypeError):
  1279. infinite_cache([])
  1280. with self.assertRaises(TypeError):
  1281. limited_cache([])
  1282. def test_lru_with_maxsize_none(self):
  1283. @self.module.lru_cache(maxsize=None)
  1284. def fib(n):
  1285. if n < 2:
  1286. return n
  1287. return fib(n-1) + fib(n-2)
  1288. self.assertEqual([fib(n) for n in range(16)],
  1289. [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
  1290. self.assertEqual(fib.cache_info(),
  1291. self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
  1292. fib.cache_clear()
  1293. self.assertEqual(fib.cache_info(),
  1294. self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
  1295. def test_lru_with_maxsize_negative(self):
  1296. @self.module.lru_cache(maxsize=-10)
  1297. def eq(n):
  1298. return n
  1299. for i in (0, 1):
  1300. self.assertEqual([eq(n) for n in range(150)], list(range(150)))
  1301. self.assertEqual(eq.cache_info(),
  1302. self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
  1303. def test_lru_with_exceptions(self):
  1304. # Verify that user_function exceptions get passed through without
  1305. # creating a hard-to-read chained exception.
  1306. # http://bugs.python.org/issue13177
  1307. for maxsize in (None, 128):
  1308. @self.module.lru_cache(maxsize)
  1309. def func(i):
  1310. return 'abc'[i]
  1311. self.assertEqual(func(0), 'a')
  1312. with self.assertRaises(IndexError) as cm:
  1313. func(15)
  1314. self.assertIsNone(cm.exception.__context__)
  1315. # Verify that the previous exception did not result in a cached entry
  1316. with self.assertRaises(IndexError):
  1317. func(15)
  1318. def test_lru_with_types(self):
  1319. for maxsize in (None, 128):
  1320. @self.module.lru_cache(maxsize=maxsize, typed=True)
  1321. def square(x):
  1322. return x * x
  1323. self.assertEqual(square(3), 9)
  1324. self.assertEqual(type(square(3)), type(9))
  1325. self.assertEqual(square(3.0), 9.0)
  1326. self.assertEqual(type(square(3.0)), type(9.0))
  1327. self.assertEqual(square(x=3), 9)
  1328. self.assertEqual(type(square(x=3)), type(9))
  1329. self.assertEqual(square(x=3.0), 9.0)
  1330. self.assertEqual(type(square(x=3.0)), type(9.0))
  1331. self.assertEqual(square.cache_info().hits, 4)
  1332. self.assertEqual(square.cache_info().misses, 4)
  1333. def test_lru_cache_typed_is_not_recursive(self):
  1334. cached = self.module.lru_cache(typed=True)(repr)
  1335. self.assertEqual(cached(1), '1')
  1336. self.assertEqual(cached(True), 'True')
  1337. self.assertEqual(cached(1.0), '1.0')
  1338. self.assertEqual(cached(0), '0')
  1339. self.assertEqual(cached(False), 'False')
  1340. self.assertEqual(cached(0.0), '0.0')
  1341. self.assertEqual(cached((1,)), '(1,)')
  1342. self.assertEqual(cached((True,)), '(1,)')
  1343. self.assertEqual(cached((1.0,)), '(1,)')
  1344. self.assertEqual(cached((0,)), '(0,)')
  1345. self.assertEqual(cached((False,)), '(0,)')
  1346. self.assertEqual(cached((0.0,)), '(0,)')
  1347. class T(tuple):
  1348. pass
  1349. self.assertEqual(cached(T((1,))), '(1,)')
  1350. self.assertEqual(cached(T((True,))), '(1,)')
  1351. self.assertEqual(cached(T((1.0,))), '(1,)')
  1352. self.assertEqual(cached(T((0,))), '(0,)')
  1353. self.assertEqual(cached(T((False,))), '(0,)')
  1354. self.assertEqual(cached(T((0.0,))), '(0,)')
  1355. def test_lru_with_keyword_args(self):
  1356. @self.module.lru_cache()
  1357. def fib(n):
  1358. if n < 2:
  1359. return n
  1360. return fib(n=n-1) + fib(n=n-2)
  1361. self.assertEqual(
  1362. [fib(n=number) for number in range(16)],
  1363. [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
  1364. )
  1365. self.assertEqual(fib.cache_info(),
  1366. self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
  1367. fib.cache_clear()
  1368. self.assertEqual(fib.cache_info(),
  1369. self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
  1370. def test_lru_with_keyword_args_maxsize_none(self):
  1371. @self.module.lru_cache(maxsize=None)
  1372. def fib(n):
  1373. if n < 2:
  1374. return n
  1375. return fib(n=n-1) + fib(n=n-2)
  1376. self.assertEqual([fib(n=number) for number in range(16)],
  1377. [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
  1378. self.assertEqual(fib.cache_info(),
  1379. self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
  1380. fib.cache_clear()
  1381. self.assertEqual(fib.cache_info(),
  1382. self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
  1383. def test_kwargs_order(self):
  1384. # PEP 468: Preserving Keyword Argument Order
  1385. @self.module.lru_cache(maxsize=10)
  1386. def f(**kwargs):
  1387. return list(kwargs.items())
  1388. self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
  1389. self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
  1390. self.assertEqual(f.cache_info(),
  1391. self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
  1392. def test_lru_cache_decoration(self):
  1393. def f(zomg: 'zomg_annotation'):
  1394. """f doc string"""
  1395. return 42
  1396. g = self.module.lru_cache()(f)
  1397. for attr in self.module.WRAPPER_ASSIGNMENTS:
  1398. self.assertEqual(getattr(g, attr), getattr(f, attr))
  @threading_helper.requires_working_threading()
  def test_lru_cache_threaded(self):
      """Call one cached function from several threads and check its stats.

      Phase one starts n threads, each calling f(k, 0) m times for its
      own k.  Phase two does the same while one extra thread repeatedly
      clears the cache; that phase only checks the return values.
      """
      n, m = 5, 11
      def orig(x, y):
          return 3 * x + y
      f = self.module.lru_cache(maxsize=n*m)(orig)
      hits, misses, maxsize, currsize = f.cache_info()
      self.assertEqual(currsize, 0)

      # Worker threads block on this event until it is set.
      start = threading.Event()
      def full(k):
          start.wait(10)
          for _ in range(m):
              self.assertEqual(f(k, 0), orig(k, 0))

      def clear():
          start.wait(10)
          for _ in range(2*m):
              f.cache_clear()

      # Remember the switch interval so the finally clause can restore it.
      orig_si = sys.getswitchinterval()
      support.setswitchinterval(1e-6)
      try:
          # create n threads in order to fill cache
          threads = [threading.Thread(target=full, args=[k])
                     for k in range(n)]
          with threading_helper.start_threads(threads):
              start.set()

          hits, misses, maxsize, currsize = f.cache_info()
          if self.module is py_functools:
              # XXX: Why can be not equal?
              self.assertLessEqual(misses, n)
              self.assertLessEqual(hits, m*n - misses)
          else:
              self.assertEqual(misses, n)
              self.assertEqual(hits, m*n - misses)
          self.assertEqual(currsize, n)

          # create n threads in order to fill cache and 1 to clear it
          threads = [threading.Thread(target=clear)]
          threads += [threading.Thread(target=full, args=[k])
                      for k in range(n)]
          # Re-arm the event so the new threads wait for a fresh signal.
          start.clear()
          with threading_helper.start_threads(threads):
              start.set()
      finally:
          sys.setswitchinterval(orig_si)
  @threading_helper.requires_working_threading()
  def test_lru_cache_threaded2(self):
      """Check the statistics when n threads make the same call at once."""
      # Simultaneous call with the same arguments
      n, m = 5, 7
      # Each barrier has n+1 parties: the n worker threads plus this thread.
      start = threading.Barrier(n+1)
      pause = threading.Barrier(n+1)
      stop = threading.Barrier(n+1)
      @self.module.lru_cache(maxsize=m*n)
      def f(x):
          # Hold every caller inside the function until all have arrived.
          pause.wait(10)
          return 3 * x
      self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
      def test():
          for i in range(m):
              start.wait(10)
              self.assertEqual(f(i), 3 * i)
              stop.wait(10)
      threads = [threading.Thread(target=test) for k in range(n)]
      with threading_helper.start_threads(threads):
          for i in range(m):
              # Each barrier is reset once the round no longer needs it.
              start.wait(10)
              stop.reset()
              pause.wait(10)
              start.reset()
              stop.wait(10)
              pause.reset()
              # Expected after round i: no hits, n misses per round,
              # and one stored entry per round.
              self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
  @threading_helper.requires_working_threading()
  def test_lru_cache_threaded3(self):
      """Call a slow function cached with maxsize=2 from five threads.

      The threads pass the arguments 1, 2, 2, 3, 2; each one only checks
      that it received the correct result.
      """
      @self.module.lru_cache(maxsize=2)
      def f(x):
          # The sleep keeps each call in flight for a short while.
          time.sleep(.01)
          return 3 * x
      def test(i, x):
          with self.subTest(thread=i):
              self.assertEqual(f(x), 3 * x, i)

      threads = [threading.Thread(target=test, args=(i, v))
                 for i, v in enumerate([1, 2, 2, 3, 2])]
      # The helper runs the threads for the duration of the with block;
      # there is nothing else to do in the body.
      with threading_helper.start_threads(threads):
          pass
  1482. def test_need_for_rlock(self):
  1483. # This will deadlock on an LRU cache that uses a regular lock
  1484. @self.module.lru_cache(maxsize=10)
  1485. def test_func(x):
  1486. 'Used to demonstrate a reentrant lru_cache call within a single thread'
  1487. return x
  1488. class DoubleEq:
  1489. 'Demonstrate a reentrant lru_cache call within a single thread'
  1490. def __init__(self, x):
  1491. self.x = x
  1492. def __hash__(self):
  1493. return self.x
  1494. def __eq__(self, other):
  1495. if self.x == 2:
  1496. test_func(DoubleEq(1))
  1497. return self.x == other.x
  1498. test_func(DoubleEq(1)) # Load the cache
  1499. test_func(DoubleEq(2)) # Load the cache
  1500. self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
  1501. DoubleEq(2)) # Verify the correct return value
  1502. def test_lru_method(self):
  1503. class X(int):
  1504. f_cnt = 0
  1505. @self.module.lru_cache(2)
  1506. def f(self, x):
  1507. self.f_cnt += 1
  1508. return x*10+self
  1509. a = X(5)
  1510. b = X(5)
  1511. c = X(7)
  1512. self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
  1513. for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
  1514. self.assertEqual(a.f(x), x*10 + 5)
  1515. self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
  1516. self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
  1517. for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
  1518. self.assertEqual(b.f(x), x*10 + 5)
  1519. self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
  1520. self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
  1521. for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
  1522. self.assertEqual(c.f(x), x*10 + 7)
  1523. self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
  1524. self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
  1525. self.assertEqual(a.f.cache_info(), X.f.cache_info())
  1526. self.assertEqual(b.f.cache_info(), X.f.cache_info())
  1527. self.assertEqual(c.f.cache_info(), X.f.cache_info())
  def test_pickle(self):
      """Pickling a cached callable round-trips to the very same object."""
      cls = self.__class__
      candidates = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth)
      protocols = range(pickle.HIGHEST_PROTOCOL + 1)
      for f in candidates:
          for proto in protocols:
              with self.subTest(proto=proto, func=f):
                  restored = pickle.loads(pickle.dumps(f, proto))
                  self.assertIs(restored, f)
  def test_copy(self):
      """copy.copy() of a cached callable returns the object itself."""
      cls = self.__class__

      def orig(x, y):
          return 3 * x + y

      cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
      candidates = (cls.cached_func[0], cls.cached_meth,
                    cls.cached_staticmeth, cached_partial)
      for f in candidates:
          with self.subTest(func=f):
              self.assertIs(copy.copy(f), f)
  def test_deepcopy(self):
      """copy.deepcopy() of a cached callable returns the object itself."""
      cls = self.__class__

      def orig(x, y):
          return 3 * x + y

      cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
      candidates = (cls.cached_func[0], cls.cached_meth,
                    cls.cached_staticmeth, cached_partial)
      for f in candidates:
          with self.subTest(func=f):
              self.assertIs(copy.deepcopy(f), f)
  1557. def test_lru_cache_parameters(self):
  1558. @self.module.lru_cache(maxsize=2)
  1559. def f():
  1560. return 1
  1561. self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
  1562. @self.module.lru_cache(maxsize=1000, typed=True)
  1563. def f():
  1564. return 1
  1565. self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
  1566. def test_lru_cache_weakrefable(self):
  1567. @self.module.lru_cache
  1568. def test_function(x):
  1569. return x
  1570. class A:
  1571. @self.module.lru_cache
  1572. def test_method(self, x):
  1573. return (self, x)
  1574. @staticmethod
  1575. @self.module.lru_cache
  1576. def test_staticmethod(x):
  1577. return (self, x)
  1578. refs = [weakref.ref(test_function),
  1579. weakref.ref(A.test_method),
  1580. weakref.ref(A.test_staticmethod)]
  1581. for ref in refs:
  1582. self.assertIsNotNone(ref())
  1583. del A
  1584. del test_function
  1585. gc.collect()
  1586. for ref in refs:
  1587. self.assertIsNone(ref())
  # Module-level function cached with py_functools.lru_cache; TestLRUPy
  # below stores it as its cached_func fixture.
  @py_functools.lru_cache()
  def py_cached_func(x, y):
      return 3 * x + y
  # Module-level function cached with c_functools.lru_cache; TestLRUC
  # below stores it as its cached_func fixture.
  @c_functools.lru_cache()
  def c_cached_func(x, y):
      return 3 * x + y
  class TestLRUPy(TestLRU, unittest.TestCase):
      """Run the shared TestLRU tests against py_functools."""

      module = py_functools
      # Kept in a one-element tuple; tests read it as cached_func[0].
      cached_func = (py_cached_func,)

      @module.lru_cache()
      def cached_meth(self, x, y):
          return 3 * x + y

      @staticmethod
      @module.lru_cache()
      def cached_staticmeth(x, y):
          return 3 * x + y
  class TestLRUC(TestLRU, unittest.TestCase):
      """Run the shared TestLRU tests against c_functools."""

      module = c_functools
      # Kept in a one-element tuple; tests read it as cached_func[0].
      cached_func = (c_cached_func,)

      @module.lru_cache()
      def cached_meth(self, x, y):
          return 3 * x + y

      @staticmethod
      @module.lru_cache()
      def cached_staticmeth(x, y):
          return 3 * x + y
  1614. class TestSingleDispatch(unittest.TestCase):
  1615. def test_simple_overloads(self):
  1616. @functools.singledispatch
  1617. def g(obj):
  1618. return "base"
  1619. def g_int(i):
  1620. return "integer"
  1621. g.register(int, g_int)
  1622. self.assertEqual(g("str"), "base")
  1623. self.assertEqual(g(1), "integer")
  1624. self.assertEqual(g([1,2,3]), "base")
  1625. def test_mro(self):
  1626. @functools.singledispatch
  1627. def g(obj):
  1628. return "base"
  1629. class A:
  1630. pass
  1631. class C(A):
  1632. pass
  1633. class B(A):
  1634. pass
  1635. class D(C, B):
  1636. pass
  1637. def g_A(a):
  1638. return "A"
  1639. def g_B(b):
  1640. return "B"
  1641. g.register(A, g_A)
  1642. g.register(B, g_B)
  1643. self.assertEqual(g(A()), "A")
  1644. self.assertEqual(g(B()), "B")
  1645. self.assertEqual(g(C()), "A")
  1646. self.assertEqual(g(D()), "B")
  1647. def test_register_decorator(self):
  1648. @functools.singledispatch
  1649. def g(obj):
  1650. return "base"
  1651. @g.register(int)
  1652. def g_int(i):
  1653. return "int %s" % (i,)
  1654. self.assertEqual(g(""), "base")
  1655. self.assertEqual(g(12), "int 12")
  1656. self.assertIs(g.dispatch(int), g_int)
  1657. self.assertIs(g.dispatch(object), g.dispatch(str))
  1658. # Note: in the assert above this is not g.
  1659. # @singledispatch returns the wrapper.
  1660. def test_wrapping_attributes(self):
  1661. @functools.singledispatch
  1662. def g(obj):
  1663. "Simple test"
  1664. return "Test"
  1665. self.assertEqual(g.__name__, "g")
  1666. if sys.flags.optimize < 2:
  1667. self.assertEqual(g.__doc__, "Simple test")
  @unittest.skipUnless(decimal, 'requires _decimal')
  @support.cpython_only
  def test_c_classes(self):
      """Dispatch on decimal exception classes and a more specific subclass."""
      @functools.singledispatch
      def g(obj):
          return "base"

      @g.register(decimal.DecimalException)
      def _(obj):
          return obj.args

      subnormal = decimal.Subnormal("Exponent < Emin")
      rounded = decimal.Rounded("Number got rounded")
      # Both are handled by the DecimalException implementation.
      self.assertEqual(g(subnormal), ("Exponent < Emin",))
      self.assertEqual(g(rounded), ("Number got rounded",))

      @g.register(decimal.Subnormal)
      def _(obj):
          return "Too small to care."

      # Only Subnormal is affected by the newly registered implementation.
      self.assertEqual(g(subnormal), "Too small to care.")
      self.assertEqual(g(rounded), ("Number got rounded",))
  1686. def test_compose_mro(self):
  1687. # None of the examples in this test depend on haystack ordering.
  1688. c = collections.abc
  1689. mro = functools._compose_mro
  1690. bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
  1691. for haystack in permutations(bases):
  1692. m = mro(dict, haystack)
  1693. self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
  1694. c.Collection, c.Sized, c.Iterable,
  1695. c.Container, object])
  1696. bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
  1697. for haystack in permutations(bases):
  1698. m = mro(collections.ChainMap, haystack)
  1699. self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
  1700. c.Collection, c.Sized, c.Iterable,
  1701. c.Container, object])
  1702. # If there's a generic function with implementations registered for
  1703. # both Sized and Container, passing a defaultdict to it results in an
  1704. # ambiguous dispatch which will cause a RuntimeError (see
  1705. # test_mro_conflicts).
  1706. bases = [c.Container, c.Sized, str]
  1707. for haystack in permutations(bases):
  1708. m = mro(collections.defaultdict, [c.Sized, c.Container, str])
  1709. self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
  1710. c.Container, object])
  1711. # MutableSequence below is registered directly on D. In other words, it
  1712. # precedes MutableMapping which means single dispatch will always
  1713. # choose MutableSequence here.
  1714. class D(collections.defaultdict):
  1715. pass
  1716. c.MutableSequence.register(D)
  1717. bases = [c.MutableSequence, c.MutableMapping]
  1718. for haystack in permutations(bases):
  1719. m = mro(D, bases)
  1720. self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
  1721. collections.defaultdict, dict, c.MutableMapping, c.Mapping,
  1722. c.Collection, c.Sized, c.Iterable, c.Container,
  1723. object])
  1724. # Container and Callable are registered on different base classes and
  1725. # a generic function supporting both should always pick the Callable
  1726. # implementation if a C instance is passed.
  1727. class C(collections.defaultdict):
  1728. def __call__(self):
  1729. pass
  1730. bases = [c.Sized, c.Callable, c.Container, c.Mapping]
  1731. for haystack in permutations(bases):
  1732. m = mro(C, haystack)
  1733. self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
  1734. c.Collection, c.Sized, c.Iterable,
  1735. c.Container, object])
  1736. def test_register_abc(self):
  1737. c = collections.abc
  1738. d = {"a": "b"}
  1739. l = [1, 2, 3]
  1740. s = {object(), None}
  1741. f = frozenset(s)
  1742. t = (1, 2, 3)
  1743. @functools.singledispatch
  1744. def g(obj):
  1745. return "base"
  1746. self.assertEqual(g(d), "base")
  1747. self.assertEqual(g(l), "base")
  1748. self.assertEqual(g(s), "base")
  1749. self.assertEqual(g(f), "base")
  1750. self.assertEqual(g(t), "base")
  1751. g.register(c.Sized, lambda obj: "sized")
  1752. self.assertEqual(g(d), "sized")
  1753. self.assertEqual(g(l), "sized")
  1754. self.assertEqual(g(s), "sized")
  1755. self.assertEqual(g(f), "sized")
  1756. self.assertEqual(g(t), "sized")
  1757. g.register(c.MutableMapping, lambda obj: "mutablemapping")
  1758. self.assertEqual(g(d), "mutablemapping")
  1759. self.assertEqual(g(l), "sized")
  1760. self.assertEqual(g(s), "sized")
  1761. self.assertEqual(g(f), "sized")
  1762. self.assertEqual(g(t), "sized")
  1763. g.register(collections.ChainMap, lambda obj: "chainmap")
  1764. self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
  1765. self.assertEqual(g(l), "sized")
  1766. self.assertEqual(g(s), "sized")
  1767. self.assertEqual(g(f), "sized")
  1768. self.assertEqual(g(t), "sized")
  1769. g.register(c.MutableSequence, lambda obj: "mutablesequence")
  1770. self.assertEqual(g(d), "mutablemapping")
  1771. self.assertEqual(g(l), "mutablesequence")
  1772. self.assertEqual(g(s), "sized")
  1773. self.assertEqual(g(f), "sized")
  1774. self.assertEqual(g(t), "sized")
  1775. g.register(c.MutableSet, lambda obj: "mutableset")
  1776. self.assertEqual(g(d), "mutablemapping")
  1777. self.assertEqual(g(l), "mutablesequence")
  1778. self.assertEqual(g(s), "mutableset")
  1779. self.assertEqual(g(f), "sized")
  1780. self.assertEqual(g(t), "sized")
  1781. g.register(c.Mapping, lambda obj: "mapping")
  1782. self.assertEqual(g(d), "mutablemapping") # not specific enough
  1783. self.assertEqual(g(l), "mutablesequence")
  1784. self.assertEqual(g(s), "mutableset")
  1785. self.assertEqual(g(f), "sized")
  1786. self.assertEqual(g(t), "sized")
  1787. g.register(c.Sequence, lambda obj: "sequence")
  1788. self.assertEqual(g(d), "mutablemapping")
  1789. self.assertEqual(g(l), "mutablesequence")
  1790. self.assertEqual(g(s), "mutableset")
  1791. self.assertEqual(g(f), "sized")
  1792. self.assertEqual(g(t), "sequence")
  1793. g.register(c.Set, lambda obj: "set")
  1794. self.assertEqual(g(d), "mutablemapping")
  1795. self.assertEqual(g(l), "mutablesequence")
  1796. self.assertEqual(g(s), "mutableset")
  1797. self.assertEqual(g(f), "set")
  1798. self.assertEqual(g(t), "sequence")
  1799. g.register(dict, lambda obj: "dict")
  1800. self.assertEqual(g(d), "dict")
  1801. self.assertEqual(g(l), "mutablesequence")
  1802. self.assertEqual(g(s), "mutableset")
  1803. self.assertEqual(g(f), "set")
  1804. self.assertEqual(g(t), "sequence")
  1805. g.register(list, lambda obj: "list")
  1806. self.assertEqual(g(d), "dict")
  1807. self.assertEqual(g(l), "list")
  1808. self.assertEqual(g(s), "mutableset")
  1809. self.assertEqual(g(f), "set")
  1810. self.assertEqual(g(t), "sequence")
  1811. g.register(set, lambda obj: "concrete-set")
  1812. self.assertEqual(g(d), "dict")
  1813. self.assertEqual(g(l), "list")
  1814. self.assertEqual(g(s), "concrete-set")
  1815. self.assertEqual(g(f), "set")
  1816. self.assertEqual(g(t), "sequence")
  1817. g.register(frozenset, lambda obj: "frozen-set")
  1818. self.assertEqual(g(d), "dict")
  1819. self.assertEqual(g(l), "list")
  1820. self.assertEqual(g(s), "concrete-set")
  1821. self.assertEqual(g(f), "frozen-set")
  1822. self.assertEqual(g(t), "sequence")
  1823. g.register(tuple, lambda obj: "tuple")
  1824. self.assertEqual(g(d), "dict")
  1825. self.assertEqual(g(l), "list")
  1826. self.assertEqual(g(s), "concrete-set")
  1827. self.assertEqual(g(f), "frozen-set")
  1828. self.assertEqual(g(t), "tuple")
  1829. def test_c3_abc(self):
  1830. c = collections.abc
  1831. mro = functools._c3_mro
  1832. class A(object):
  1833. pass
  1834. class B(A):
  1835. def __len__(self):
  1836. return 0 # implies Sized
  1837. @c.Container.register
  1838. class C(object):
  1839. pass
  1840. class D(object):
  1841. pass # unrelated
  1842. class X(D, C, B):
  1843. def __call__(self):
  1844. pass # implies Callable
  1845. expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
  1846. for abcs in permutations([c.Sized, c.Callable, c.Container]):
  1847. self.assertEqual(mro(X, abcs=abcs), expected)
  1848. # unrelated ABCs don't appear in the resulting MRO
  1849. many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
  1850. self.assertEqual(mro(X, abcs=many_abcs), expected)
  1851. def test_false_meta(self):
  1852. # see issue23572
  1853. class MetaA(type):
  1854. def __len__(self):
  1855. return 0
  1856. class A(metaclass=MetaA):
  1857. pass
  1858. class AA(A):
  1859. pass
  1860. @functools.singledispatch
  1861. def fun(a):
  1862. return 'base A'
  1863. @fun.register(A)
  1864. def _(a):
  1865. return 'fun A'
  1866. aa = AA()
  1867. self.assertEqual(fun(aa), 'fun A')
  def test_mro_conflicts(self):
      """Check how dispatch resolves, or refuses to resolve, competing ABCs.

      ABCs are registered on the local classes step by step, and after
      each step the chosen implementation (or the "Ambiguous dispatch"
      RuntimeError) is verified.  The statements are order-dependent.
      """
      c = collections.abc
      @functools.singledispatch
      def g(arg):
          return "base"
      class O(c.Sized):
          def __len__(self):
              return 0
      o = O()
      self.assertEqual(g(o), "base")
      g.register(c.Iterable, lambda arg: "iterable")
      g.register(c.Container, lambda arg: "container")
      g.register(c.Sized, lambda arg: "sized")
      g.register(c.Set, lambda arg: "set")
      self.assertEqual(g(o), "sized")
      c.Iterable.register(O)
      self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
      c.Container.register(O)
      self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
      c.Set.register(O)
      self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                        # c.Sized and c.Container
      # P has no ABC among its bases; ABCs are only registered on it.
      class P:
          pass
      p = P()
      self.assertEqual(g(p), "base")
      c.Iterable.register(P)
      self.assertEqual(g(p), "iterable")
      c.Container.register(P)
      with self.assertRaises(RuntimeError) as re_one:
          g(p)
      # Either ordering of the two classes in the message is accepted.
      self.assertIn(
          str(re_one.exception),
          (("Ambiguous dispatch: <class 'collections.abc.Container'> "
            "or <class 'collections.abc.Iterable'>"),
           ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
            "or <class 'collections.abc.Container'>")),
      )
      class Q(c.Sized):
          def __len__(self):
              return 0
      q = Q()
      self.assertEqual(g(q), "sized")
      c.Iterable.register(Q)
      self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
      c.Set.register(Q)
      self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                        # c.Sized and c.Iterable
      @functools.singledispatch
      def h(arg):
          return "base"
      @h.register(c.Sized)
      def _(arg):
          return "sized"
      @h.register(c.Container)
      def _(arg):
          return "container"
      # Even though Sized and Container are explicit bases of MutableMapping,
      # this ABC is implicitly registered on defaultdict which makes all of
      # MutableMapping's bases implicit as well from defaultdict's
      # perspective.
      with self.assertRaises(RuntimeError) as re_two:
          h(collections.defaultdict(lambda: 0))
      self.assertIn(
          str(re_two.exception),
          (("Ambiguous dispatch: <class 'collections.abc.Container'> "
            "or <class 'collections.abc.Sized'>"),
           ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
            "or <class 'collections.abc.Container'>")),
      )
      class R(collections.defaultdict):
          pass
      c.MutableSequence.register(R)
      @functools.singledispatch
      def i(arg):
          return "base"
      @i.register(c.MutableMapping)
      def _(arg):
          return "mapping"
      @i.register(c.MutableSequence)
      def _(arg):
          return "sequence"
      r = R()
      self.assertEqual(i(r), "sequence")
      class S:
          pass
      class T(S, c.Sized):
          def __len__(self):
              return 0
      t = T()
      self.assertEqual(h(t), "sized")
      c.Container.register(T)
      self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
      class U:
          def __len__(self):
              return 0
      u = U()
      self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                        # from the existence of __len__()
      c.Container.register(U)
      # There is no preference for registered versus inferred ABCs.
      with self.assertRaises(RuntimeError) as re_three:
          h(u)
      self.assertIn(
          str(re_three.exception),
          (("Ambiguous dispatch: <class 'collections.abc.Container'> "
            "or <class 'collections.abc.Sized'>"),
           ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
            "or <class 'collections.abc.Container'>")),
      )
      class V(c.Sized, S):
          def __len__(self):
              return 0
      @functools.singledispatch
      def j(arg):
          return "base"
      @j.register(S)
      def _(arg):
          return "s"
      @j.register(c.Container)
      def _(arg):
          return "container"
      v = V()
      self.assertEqual(j(v), "s")
      c.Container.register(V)
      self.assertEqual(j(v), "container")   # because it ends up right after
                                            # Sized in the MRO
  def test_cache_invalidation(self):
      """Trace reads and writes of the dispatch cache across registrations.

      weakref.WeakKeyDictionary is temporarily replaced by a factory
      returning a TracingDict, which records every key set and every key
      successfully read.  The assertions check those traces and the dict
      size after calls, after new registrations and after ABC
      registrations.
      """
      from collections import UserDict
      import weakref

      class TracingDict(UserDict):
          def __init__(self, *args, **kwargs):
              super(TracingDict, self).__init__(*args, **kwargs)
              self.set_ops = []
              self.get_ops = []
          def __getitem__(self, key):
              # A missing key raises here, before anything is recorded.
              result = self.data[key]
              self.get_ops.append(key)
              return result
          def __setitem__(self, key, value):
              self.set_ops.append(key)
              self.data[key] = value
          def clear(self):
              self.data.clear()

      td = TracingDict()
      with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
          c = collections.abc
          @functools.singledispatch
          def g(arg):
              return "base"
          d = {}
          l = []
          self.assertEqual(len(td), 0)
          self.assertEqual(g(d), "base")
          self.assertEqual(len(td), 1)
          self.assertEqual(td.get_ops, [])
          self.assertEqual(td.set_ops, [dict])
          self.assertEqual(td.data[dict], g.registry[object])
          self.assertEqual(g(l), "base")
          self.assertEqual(len(td), 2)
          self.assertEqual(td.get_ops, [])
          self.assertEqual(td.set_ops, [dict, list])
          self.assertEqual(td.data[dict], g.registry[object])
          self.assertEqual(td.data[list], g.registry[object])
          self.assertEqual(td.data[dict], td.data[list])
          # Repeated calls are served by reads only; no new sets.
          self.assertEqual(g(l), "base")
          self.assertEqual(g(d), "base")
          self.assertEqual(td.get_ops, [list, dict])
          self.assertEqual(td.set_ops, [dict, list])
          # Registering a new implementation empties the cache.
          g.register(list, lambda arg: "list")
          self.assertEqual(td.get_ops, [list, dict])
          self.assertEqual(len(td), 0)
          self.assertEqual(g(d), "base")
          self.assertEqual(len(td), 1)
          self.assertEqual(td.get_ops, [list, dict])
          self.assertEqual(td.set_ops, [dict, list, dict])
          self.assertEqual(td.data[dict],
                           functools._find_impl(dict, g.registry))
          self.assertEqual(g(l), "list")
          self.assertEqual(len(td), 2)
          self.assertEqual(td.get_ops, [list, dict])
          self.assertEqual(td.set_ops, [dict, list, dict, list])
          self.assertEqual(td.data[list],
                           functools._find_impl(list, g.registry))
          class X:
              pass
          c.MutableMapping.register(X)   # Will not invalidate the cache,
                                         # not using ABCs yet.
          self.assertEqual(g(d), "base")
          self.assertEqual(g(l), "list")
          self.assertEqual(td.get_ops, [list, dict, dict, list])
          self.assertEqual(td.set_ops, [dict, list, dict, list])
          g.register(c.Sized, lambda arg: "sized")
          self.assertEqual(len(td), 0)
          self.assertEqual(g(d), "sized")
          self.assertEqual(len(td), 1)
          self.assertEqual(td.get_ops, [list, dict, dict, list])
          self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
          self.assertEqual(g(l), "list")
          self.assertEqual(len(td), 2)
          self.assertEqual(td.get_ops, [list, dict, dict, list])
          self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
          self.assertEqual(g(l), "list")
          self.assertEqual(g(d), "sized")
          self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
          self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
          # dispatch() reads from the same cache as a call does.
          g.dispatch(list)
          g.dispatch(dict)
          self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                        list, dict])
          self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
          c.MutableSet.register(X)       # Will invalidate the cache.
          self.assertEqual(len(td), 2)   # Stale cache.
          self.assertEqual(g(l), "list")
          self.assertEqual(len(td), 1)
          g.register(c.MutableMapping, lambda arg: "mutablemapping")
          self.assertEqual(len(td), 0)
          self.assertEqual(g(d), "mutablemapping")
          self.assertEqual(len(td), 1)
          self.assertEqual(g(l), "list")
          self.assertEqual(len(td), 2)
          g.register(dict, lambda arg: "dict")
          self.assertEqual(g(d), "dict")
          self.assertEqual(g(l), "list")
          g._clear_cache()
          self.assertEqual(len(td), 0)
  2094. def test_annotations(self):
  2095. @functools.singledispatch
  2096. def i(arg):
  2097. return "base"
  2098. @i.register
  2099. def _(arg: collections.abc.Mapping):
  2100. return "mapping"
  2101. @i.register
  2102. def _(arg: "collections.abc.Sequence"):
  2103. return "sequence"
  2104. self.assertEqual(i(None), "base")
  2105. self.assertEqual(i({"a": 1}), "mapping")
  2106. self.assertEqual(i([1, 2, 3]), "sequence")
  2107. self.assertEqual(i((1, 2, 3)), "sequence")
  2108. self.assertEqual(i("str"), "sequence")
  2109. # Registering classes as callables doesn't work with annotations,
  2110. # you need to pass the type explicitly.
  2111. @i.register(str)
  2112. class _:
  2113. def __init__(self, arg):
  2114. self.arg = arg
  2115. def __eq__(self, other):
  2116. return self.arg == other
  2117. self.assertEqual(i("str"), "str")
  2118. def test_method_register(self):
  2119. class A:
  2120. @functools.singledispatchmethod
  2121. def t(self, arg):
  2122. self.arg = "base"
  2123. @t.register(int)
  2124. def _(self, arg):
  2125. self.arg = "int"
  2126. @t.register(str)
  2127. def _(self, arg):
  2128. self.arg = "str"
  2129. a = A()
  2130. a.t(0)
  2131. self.assertEqual(a.arg, "int")
  2132. aa = A()
  2133. self.assertFalse(hasattr(aa, 'arg'))
  2134. a.t('')
  2135. self.assertEqual(a.arg, "str")
  2136. aa = A()
  2137. self.assertFalse(hasattr(aa, 'arg'))
  2138. a.t(0.0)
  2139. self.assertEqual(a.arg, "base")
  2140. aa = A()
  2141. self.assertFalse(hasattr(aa, 'arg'))
  2142. def test_staticmethod_register(self):
  2143. class A:
  2144. @functools.singledispatchmethod
  2145. @staticmethod
  2146. def t(arg):
  2147. return arg
  2148. @t.register(int)
  2149. @staticmethod
  2150. def _(arg):
  2151. return isinstance(arg, int)
  2152. @t.register(str)
  2153. @staticmethod
  2154. def _(arg):
  2155. return isinstance(arg, str)
  2156. a = A()
  2157. self.assertTrue(A.t(0))
  2158. self.assertTrue(A.t(''))
  2159. self.assertEqual(A.t(0.0), 0.0)
  2160. def test_classmethod_register(self):
  2161. class A:
  2162. def __init__(self, arg):
  2163. self.arg = arg
  2164. @functools.singledispatchmethod
  2165. @classmethod
  2166. def t(cls, arg):
  2167. return cls("base")
  2168. @t.register(int)
  2169. @classmethod
  2170. def _(cls, arg):
  2171. return cls("int")
  2172. @t.register(str)
  2173. @classmethod
  2174. def _(cls, arg):
  2175. return cls("str")
  2176. self.assertEqual(A.t(0).arg, "int")
  2177. self.assertEqual(A.t('').arg, "str")
  2178. self.assertEqual(A.t(0.0).arg, "base")
  2179. def test_callable_register(self):
  2180. class A:
  2181. def __init__(self, arg):
  2182. self.arg = arg
  2183. @functools.singledispatchmethod
  2184. @classmethod
  2185. def t(cls, arg):
  2186. return cls("base")
  2187. @A.t.register(int)
  2188. @classmethod
  2189. def _(cls, arg):
  2190. return cls("int")
  2191. @A.t.register(str)
  2192. @classmethod
  2193. def _(cls, arg):
  2194. return cls("str")
  2195. self.assertEqual(A.t(0).arg, "int")
  2196. self.assertEqual(A.t('').arg, "str")
  2197. self.assertEqual(A.t(0.0).arg, "base")
  2198. def test_abstractmethod_register(self):
  2199. class Abstract(metaclass=abc.ABCMeta):
  2200. @functools.singledispatchmethod
  2201. @abc.abstractmethod
  2202. def add(self, x, y):
  2203. pass
  2204. self.assertTrue(Abstract.add.__isabstractmethod__)
  2205. self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
  2206. with self.assertRaises(TypeError):
  2207. Abstract()
  2208. def test_type_ann_register(self):
  2209. class A:
  2210. @functools.singledispatchmethod
  2211. def t(self, arg):
  2212. return "base"
  2213. @t.register
  2214. def _(self, arg: int):
  2215. return "int"
  2216. @t.register
  2217. def _(self, arg: str):
  2218. return "str"
  2219. a = A()
  2220. self.assertEqual(a.t(0), "int")
  2221. self.assertEqual(a.t(''), "str")
  2222. self.assertEqual(a.t(0.0), "base")
  2223. def test_staticmethod_type_ann_register(self):
  2224. class A:
  2225. @functools.singledispatchmethod
  2226. @staticmethod
  2227. def t(arg):
  2228. return arg
  2229. @t.register
  2230. @staticmethod
  2231. def _(arg: int):
  2232. return isinstance(arg, int)
  2233. @t.register
  2234. @staticmethod
  2235. def _(arg: str):
  2236. return isinstance(arg, str)
  2237. a = A()
  2238. self.assertTrue(A.t(0))
  2239. self.assertTrue(A.t(''))
  2240. self.assertEqual(A.t(0.0), 0.0)
  2241. def test_classmethod_type_ann_register(self):
  2242. class A:
  2243. def __init__(self, arg):
  2244. self.arg = arg
  2245. @functools.singledispatchmethod
  2246. @classmethod
  2247. def t(cls, arg):
  2248. return cls("base")
  2249. @t.register
  2250. @classmethod
  2251. def _(cls, arg: int):
  2252. return cls("int")
  2253. @t.register
  2254. @classmethod
  2255. def _(cls, arg: str):
  2256. return cls("str")
  2257. self.assertEqual(A.t(0).arg, "int")
  2258. self.assertEqual(A.t('').arg, "str")
  2259. self.assertEqual(A.t(0.0).arg, "base")
  2260. def test_method_wrapping_attributes(self):
  2261. class A:
  2262. @functools.singledispatchmethod
  2263. def func(self, arg: int) -> str:
  2264. """My function docstring"""
  2265. return str(arg)
  2266. @functools.singledispatchmethod
  2267. @classmethod
  2268. def cls_func(cls, arg: int) -> str:
  2269. """My function docstring"""
  2270. return str(arg)
  2271. @functools.singledispatchmethod
  2272. @staticmethod
  2273. def static_func(arg: int) -> str:
  2274. """My function docstring"""
  2275. return str(arg)
  2276. for meth in (
  2277. A.func,
  2278. A().func,
  2279. A.cls_func,
  2280. A().cls_func,
  2281. A.static_func,
  2282. A().static_func
  2283. ):
  2284. with self.subTest(meth=meth):
  2285. self.assertEqual(meth.__doc__, 'My function docstring')
  2286. self.assertEqual(meth.__annotations__['arg'], int)
  2287. self.assertEqual(A.func.__name__, 'func')
  2288. self.assertEqual(A().func.__name__, 'func')
  2289. self.assertEqual(A.cls_func.__name__, 'cls_func')
  2290. self.assertEqual(A().cls_func.__name__, 'cls_func')
  2291. self.assertEqual(A.static_func.__name__, 'static_func')
  2292. self.assertEqual(A().static_func.__name__, 'static_func')
  2293. def test_double_wrapped_methods(self):
  2294. def classmethod_friendly_decorator(func):
  2295. wrapped = func.__func__
  2296. @classmethod
  2297. @functools.wraps(wrapped)
  2298. def wrapper(*args, **kwargs):
  2299. return wrapped(*args, **kwargs)
  2300. return wrapper
  2301. class WithoutSingleDispatch:
  2302. @classmethod
  2303. @contextlib.contextmanager
  2304. def cls_context_manager(cls, arg: int) -> str:
  2305. try:
  2306. yield str(arg)
  2307. finally:
  2308. return 'Done'
  2309. @classmethod_friendly_decorator
  2310. @classmethod
  2311. def decorated_classmethod(cls, arg: int) -> str:
  2312. return str(arg)
  2313. class WithSingleDispatch:
  2314. @functools.singledispatchmethod
  2315. @classmethod
  2316. @contextlib.contextmanager
  2317. def cls_context_manager(cls, arg: int) -> str:
  2318. """My function docstring"""
  2319. try:
  2320. yield str(arg)
  2321. finally:
  2322. return 'Done'
  2323. @functools.singledispatchmethod
  2324. @classmethod_friendly_decorator
  2325. @classmethod
  2326. def decorated_classmethod(cls, arg: int) -> str:
  2327. """My function docstring"""
  2328. return str(arg)
  2329. # These are sanity checks
  2330. # to test the test itself is working as expected
  2331. with WithoutSingleDispatch.cls_context_manager(5) as foo:
  2332. without_single_dispatch_foo = foo
  2333. with WithSingleDispatch.cls_context_manager(5) as foo:
  2334. single_dispatch_foo = foo
  2335. self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
  2336. self.assertEqual(single_dispatch_foo, '5')
  2337. self.assertEqual(
  2338. WithoutSingleDispatch.decorated_classmethod(5),
  2339. WithSingleDispatch.decorated_classmethod(5)
  2340. )
  2341. self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')
  2342. # Behavioural checks now follow
  2343. for method_name in ('cls_context_manager', 'decorated_classmethod'):
  2344. with self.subTest(method=method_name):
  2345. self.assertEqual(
  2346. getattr(WithSingleDispatch, method_name).__name__,
  2347. getattr(WithoutSingleDispatch, method_name).__name__
  2348. )
  2349. self.assertEqual(
  2350. getattr(WithSingleDispatch(), method_name).__name__,
  2351. getattr(WithoutSingleDispatch(), method_name).__name__
  2352. )
  2353. for meth in (
  2354. WithSingleDispatch.cls_context_manager,
  2355. WithSingleDispatch().cls_context_manager,
  2356. WithSingleDispatch.decorated_classmethod,
  2357. WithSingleDispatch().decorated_classmethod
  2358. ):
  2359. with self.subTest(meth=meth):
  2360. self.assertEqual(meth.__doc__, 'My function docstring')
  2361. self.assertEqual(meth.__annotations__['arg'], int)
  2362. self.assertEqual(
  2363. WithSingleDispatch.cls_context_manager.__name__,
  2364. 'cls_context_manager'
  2365. )
  2366. self.assertEqual(
  2367. WithSingleDispatch().cls_context_manager.__name__,
  2368. 'cls_context_manager'
  2369. )
  2370. self.assertEqual(
  2371. WithSingleDispatch.decorated_classmethod.__name__,
  2372. 'decorated_classmethod'
  2373. )
  2374. self.assertEqual(
  2375. WithSingleDispatch().decorated_classmethod.__name__,
  2376. 'decorated_classmethod'
  2377. )
  2378. def test_invalid_registrations(self):
  2379. msg_prefix = "Invalid first argument to `register()`: "
  2380. msg_suffix = (
  2381. ". Use either `@register(some_class)` or plain `@register` on an "
  2382. "annotated function."
  2383. )
  2384. @functools.singledispatch
  2385. def i(arg):
  2386. return "base"
  2387. with self.assertRaises(TypeError) as exc:
  2388. @i.register(42)
  2389. def _(arg):
  2390. return "I annotated with a non-type"
  2391. self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
  2392. self.assertTrue(str(exc.exception).endswith(msg_suffix))
  2393. with self.assertRaises(TypeError) as exc:
  2394. @i.register
  2395. def _(arg):
  2396. return "I forgot to annotate"
  2397. self.assertTrue(str(exc.exception).startswith(msg_prefix +
  2398. "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
  2399. ))
  2400. self.assertTrue(str(exc.exception).endswith(msg_suffix))
  2401. with self.assertRaises(TypeError) as exc:
  2402. @i.register
  2403. def _(arg: typing.Iterable[str]):
  2404. # At runtime, dispatching on generics is impossible.
  2405. # When registering implementations with singledispatch, avoid
  2406. # types from `typing`. Instead, annotate with regular types
  2407. # or ABCs.
  2408. return "I annotated with a generic collection"
  2409. self.assertTrue(str(exc.exception).startswith(
  2410. "Invalid annotation for 'arg'."
  2411. ))
  2412. self.assertTrue(str(exc.exception).endswith(
  2413. 'typing.Iterable[str] is not a class.'
  2414. ))
  2415. with self.assertRaises(TypeError) as exc:
  2416. @i.register
  2417. def _(arg: typing.Union[int, typing.Iterable[str]]):
  2418. return "Invalid Union"
  2419. self.assertTrue(str(exc.exception).startswith(
  2420. "Invalid annotation for 'arg'."
  2421. ))
  2422. self.assertTrue(str(exc.exception).endswith(
  2423. 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
  2424. ))
  2425. def test_invalid_positional_argument(self):
  2426. @functools.singledispatch
  2427. def f(*args):
  2428. pass
  2429. msg = 'f requires at least 1 positional argument'
  2430. with self.assertRaisesRegex(TypeError, msg):
  2431. f()
  2432. def test_union(self):
  2433. @functools.singledispatch
  2434. def f(arg):
  2435. return "default"
  2436. @f.register
  2437. def _(arg: typing.Union[str, bytes]):
  2438. return "typing.Union"
  2439. @f.register
  2440. def _(arg: int | float):
  2441. return "types.UnionType"
  2442. self.assertEqual(f([]), "default")
  2443. self.assertEqual(f(""), "typing.Union")
  2444. self.assertEqual(f(b""), "typing.Union")
  2445. self.assertEqual(f(1), "types.UnionType")
  2446. self.assertEqual(f(1.0), "types.UnionType")
  2447. def test_union_conflict(self):
  2448. @functools.singledispatch
  2449. def f(arg):
  2450. return "default"
  2451. @f.register
  2452. def _(arg: typing.Union[str, bytes]):
  2453. return "typing.Union"
  2454. @f.register
  2455. def _(arg: int | str):
  2456. return "types.UnionType"
  2457. self.assertEqual(f([]), "default")
  2458. self.assertEqual(f(""), "types.UnionType") # last one wins
  2459. self.assertEqual(f(b""), "typing.Union")
  2460. self.assertEqual(f(1), "types.UnionType")
  2461. def test_union_None(self):
  2462. @functools.singledispatch
  2463. def typing_union(arg):
  2464. return "default"
  2465. @typing_union.register
  2466. def _(arg: typing.Union[str, None]):
  2467. return "typing.Union"
  2468. self.assertEqual(typing_union(1), "default")
  2469. self.assertEqual(typing_union(""), "typing.Union")
  2470. self.assertEqual(typing_union(None), "typing.Union")
  2471. @functools.singledispatch
  2472. def types_union(arg):
  2473. return "default"
  2474. @types_union.register
  2475. def _(arg: int | None):
  2476. return "types.UnionType"
  2477. self.assertEqual(types_union(""), "default")
  2478. self.assertEqual(types_union(1), "types.UnionType")
  2479. self.assertEqual(types_union(None), "types.UnionType")
  2480. def test_register_genericalias(self):
  2481. @functools.singledispatch
  2482. def f(arg):
  2483. return "default"
  2484. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2485. f.register(list[int], lambda arg: "types.GenericAlias")
  2486. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2487. f.register(typing.List[int], lambda arg: "typing.GenericAlias")
  2488. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2489. f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
  2490. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2491. f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
  2492. self.assertEqual(f([1]), "default")
  2493. self.assertEqual(f([1.0]), "default")
  2494. self.assertEqual(f(""), "default")
  2495. self.assertEqual(f(b""), "default")
  2496. def test_register_genericalias_decorator(self):
  2497. @functools.singledispatch
  2498. def f(arg):
  2499. return "default"
  2500. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2501. f.register(list[int])
  2502. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2503. f.register(typing.List[int])
  2504. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2505. f.register(list[int] | str)
  2506. with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
  2507. f.register(typing.List[int] | str)
  2508. def test_register_genericalias_annotation(self):
  2509. @functools.singledispatch
  2510. def f(arg):
  2511. return "default"
  2512. with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
  2513. @f.register
  2514. def _(arg: list[int]):
  2515. return "types.GenericAlias"
  2516. with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
  2517. @f.register
  2518. def _(arg: typing.List[float]):
  2519. return "typing.GenericAlias"
  2520. with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
  2521. @f.register
  2522. def _(arg: list[int] | str):
  2523. return "types.UnionType(types.GenericAlias)"
  2524. with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
  2525. @f.register
  2526. def _(arg: typing.List[float] | bytes):
  2527. return "typing.Union[typing.GenericAlias]"
  2528. self.assertEqual(f([1]), "default")
  2529. self.assertEqual(f([1.0]), "default")
  2530. self.assertEqual(f(""), "default")
  2531. self.assertEqual(f(b""), "default")
  class CachedCostItem:
      """Fixture whose ``cost`` getter bumps a counter each time it runs.

      ``_cost`` starts at 1 on the class, so a first read of ``cost``
      yields 2; reading 2 again afterwards shows the getter was not re-run.
      """

      _cost = 1

      def __init__(self):
          self.lock = py_functools.RLock()

      @py_functools.cached_property
      def cost(self):
          """The cost of the item."""
          with self.lock:
              self._cost += 1
          return self._cost
  class OptionallyCachedCostItem:
      """Fixture exposing one getter both directly and as a cached property.

      ``get_cost`` increments ``_cost`` on every call; ``cached_cost`` wraps
      the same function under a different attribute name.
      """

      _cost = 1

      def get_cost(self):
          """The cost of the item."""
          self._cost += 1
          return self._cost

      # The attribute name deliberately differs from the function name.
      cached_cost = py_functools.cached_property(get_cost)
  class CachedCostItemWait:
      """Fixture whose ``cost`` getter waits on an event before counting.

      The getter waits up to one second on ``event``, then increments
      ``_cost`` under ``lock`` and returns it.
      """

      def __init__(self, event):
          self._cost = 1
          self.lock = py_functools.RLock()
          self.event = event

      @py_functools.cached_property
      def cost(self):
          # Hold callers here until the event is set or one second elapses.
          self.event.wait(1)
          with self.lock:
              self._cost += 1
          return self._cost
  class CachedCostItemWithSlots:
      """Fixture with ``__slots__`` and therefore no instance ``__dict__``.

      The ``cost`` getter raises RuntimeError if it is ever executed.
      """

      # One-element tuple: the original ('_cost') was just a parenthesised
      # string, which declares the same single slot less explicitly.
      __slots__ = ('_cost',)

      def __init__(self):
          self._cost = 1

      @py_functools.cached_property
      def cost(self):
          raise RuntimeError('never called, slots not supported')
  class TestCachedProperty(unittest.TestCase):
      """Tests for ``py_functools.cached_property``."""

      def test_cached(self):
          """A second read returns the cached value instead of recomputing."""
          item = CachedCostItem()
          self.assertEqual(item.cost, 2)
          self.assertEqual(item.cost, 2)  # not 3

      def test_cached_attribute_name_differs_from_func_name(self):
          """Caching works when the attribute name differs from the function name."""
          item = OptionallyCachedCostItem()
          self.assertEqual(item.get_cost(), 2)
          self.assertEqual(item.cached_cost, 3)
          # Direct calls keep incrementing; the cached attribute stays at 3.
          self.assertEqual(item.get_cost(), 4)
          self.assertEqual(item.cached_cost, 3)

      @threading_helper.requires_working_threading()
      def test_threaded(self):
          """Several threads read ``cost`` at once; the final value must be 2."""
          go = threading.Event()
          item = CachedCostItemWait(go)

          num_threads = 3

          # Shrink the switch interval for the duration of the threaded
          # section and restore it afterwards.
          orig_si = sys.getswitchinterval()
          sys.setswitchinterval(1e-6)
          try:
              threads = [
                  threading.Thread(target=lambda: item.cost)
                  for k in range(num_threads)
              ]
              with threading_helper.start_threads(threads):
                  go.set()
          finally:
              sys.setswitchinterval(orig_si)

          self.assertEqual(item.cost, 2)

      def test_object_with_slots(self):
          """Reading the property on a ``__slots__`` instance raises TypeError."""
          item = CachedCostItemWithSlots()
          with self.assertRaisesRegex(
                  TypeError,
                  "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
          ):
              item.cost

      def test_immutable_dict(self):
          """Reading a cached_property defined on a metaclass raises TypeError."""
          class MyMeta(type):
              @py_functools.cached_property
              def prop(self):
                  return True

          class MyClass(metaclass=MyMeta):
              pass

          with self.assertRaisesRegex(
              TypeError,
              "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
          ):
              MyClass.prop

      def test_reuse_different_names(self):
          """Disallow this case because decorated function a would not be cached."""
          with self.assertRaises(RuntimeError) as ctx:
              class ReusedCachedProperty:
                  @py_functools.cached_property
                  def a(self):
                      pass

                  b = a

          # The RuntimeError's __context__ is expected to carry the
          # TypeError text naming both attribute names.
          self.assertEqual(
              str(ctx.exception.__context__),
              str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
          )

      def test_reuse_same_name(self):
          """Reusing a cached_property on different classes under the same name is OK."""
          counter = 0

          @py_functools.cached_property
          def _cp(_self):
              nonlocal counter
              counter += 1
              return counter

          class A:
              cp = _cp

          class B:
              cp = _cp

          a = A()
          b = B()

          self.assertEqual(a.cp, 1)
          self.assertEqual(b.cp, 2)
          self.assertEqual(a.cp, 1)

      def test_set_name_not_called(self):
          """A property attached after class creation raises TypeError on access."""
          cp = py_functools.cached_property(lambda s: None)

          class Foo:
              pass

          # Assigned after the class body has finished executing.
          Foo.cp = cp

          with self.assertRaisesRegex(
                  TypeError,
                  "Cannot use cached_property instance without calling __set_name__ on it.",
          ):
              Foo().cp

      def test_access_from_class(self):
          """Access through the class returns the cached_property object itself."""
          self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

      def test_doc(self):
          """The property exposes the wrapped function's docstring."""
          self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
  # Run the test suite when this file is executed as a script.
  if __name__ == '__main__':
      unittest.main()