| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
7772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066 |
- import abc
- import builtins
- import collections
- import collections.abc
- import copy
- from itertools import permutations
- import pickle
- from random import choice
- import sys
- from test import support
- import threading
- import time
- import typing
- import unittest
- import unittest.mock
- import os
- import weakref
- import gc
- from weakref import proxy
- import contextlib
- from test.support import import_helper
- from test.support import threading_helper
- from test.support.script_helper import assert_python_ok
- import functools
- py_functools = import_helper.import_fresh_module('functools',
- blocked=['_functools'])
- c_functools = import_helper.import_fresh_module('functools')
- decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    On exit, normal or via an exception, the previous entry is put back.
    If *name* was not present in ``sys.modules`` beforehand, the temporary
    entry is removed instead. This means the helper also works for modules
    that have not been imported yet, rather than failing with ``KeyError``.
    """
    missing = object()
    original_module = sys.modules.get(name, missing)
    sys.modules[name] = replacement
    try:
        yield
    finally:
        if original_module is missing:
            # Nothing was registered before: leave sys.modules as we found it.
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = original_module
def capture(*args, **kw):
    """Echo the call back: return the positional tuple and the keyword dict."""
    call_record = (args, kw)
    return call_record
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
class MyTuple(tuple):
    """Trivial ``tuple`` subclass, used as partial state to check type normalisation."""
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a ``list``."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
class MyDict(dict):
    """Trivial ``dict`` subclass, used as partial state to check type normalisation."""
class TestPartial:
    """Behavioural tests shared by every ``partial`` implementation.

    This is a mixin. Concrete subclasses combine it with
    ``unittest.TestCase`` and provide two class attributes: ``partial``, the
    implementation under test, and ``AllowPickle``, a context manager
    entered around the pickling tests.
    """

    def _expected_repr_name(self):
        # The stock implementations present themselves as functools.partial;
        # any other class (e.g. a subclass) is shown under its own name.
        if self.partial in (c_functools.partial, py_functools.partial):
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # The first argument must be callable.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, kwargs = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(kwargs, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                args, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(args, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should be flattened into a single partial.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Accept the keywords in either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Self-reference through func.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        # Self-reference through args.
        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        # Self-reference through keywords.
        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None namespace is treated as an empty __dict__.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # With keywords=None only the call behaviour is checked.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are normalised to exact types.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # A tuple subclass with a broken __add__ must not leak into calls.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-reference through args survives a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-reference through keywords survives a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4

            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus C-specific checks, run against ``c_functools.partial``."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No setup is needed here: entering does nothing, and exiting does
        # not suppress exceptions.
        def __enter__(self):
            return self

        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # __dict__ itself must not be deletable.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        # ...is tolerated by repr() but rejected when the partial is called.
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python ``py_functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """While active, ``sys.modules['functools']`` is the pure-Python module.

        The swap is delegated to ``replaced_module``. Presumably this makes
        pickled references to ``functools.partial`` resolve to the
        pure-Python class (TODO confirm).
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            swap = self._swap
            return swap.__enter__()

        def __exit__(self, type, value, tb):
            swap = self._swap
            return swap.__exit__(type, value, tb)
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Empty subclass of ``c_functools.partial``."""


class PyPartialSubclass(py_functools.partial):
    """Empty subclass of the pure-Python ``partial``."""
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the C partial tests using ``CPartialSubclass``."""

    if c_functools:
        partial = CPartialSubclass

    # Nested calls are only optimised for partial itself, not subclasses,
    # so the inherited check is switched off by replacing it with None.
    test_nested_optimization = None
class TestPartialPySubclass(TestPartialPy):
    """Repeat the pure-Python partial tests using ``PyPartialSubclass``."""

    partial = PyPartialSubclass
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        # Each attribute ends up calling ``capture``, so a call reports the
        # exact arguments it received.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' must be usable as ordinary keyword arguments.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        # A partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # A partialmethod wrapping a plain functools.partial object.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # Wrapped staticmethod/classmethod keep their binding behaviour.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # A non-callable func is rejected.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        # The callable is required.
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        # ...and cannot be passed by keyword.
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Build a real ABC (metaclass=ABCMeta), not a subclass of ABCMeta,
        # so the ABC machinery is exercised as well.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must also treat the partialmethod as abstract.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        # functools.partial must forward into positional-only parameters.
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``.

    ``_default_update`` builds the ``(wrapper, wrapped)`` pair used by the
    default-argument tests. A subclass may override it to produce that pair
    another way.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* mirrors *wrapped*.

        Checks three things:

        * every attribute named in *assigned* is the very same object on
          both functions;
        * every entry of each mapping named in *updated* was carried over
          (the wrapped ``__dict__``'s own ``__wrapped__`` key is skipped);
        * ``wrapper.__wrapped__`` is *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name), getattr(wrapped, attr_name))
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for key in source:
                # __wrapped__ is overwritten by the update code
                if (attr_name, key) == ("__dict__", "__wrapped__"):
                    continue
                self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # f carries a docstring, an annotation, a custom attribute and a bogus
        # __wrapped__ value that must not survive the update.
        def f(a: 'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b: 'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass

        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign_names = ('attr',)
        update_names = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign_names, update_names)
        self.check_wrapper(wrapper, f, assign_names, update_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign_names = ('attr',)
        update_names = ('dict_attr',)
        # Attributes absent from the wrapped object are silently skipped.
        functools.update_wrapper(wrapper, f, assign_names, update_names)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})

        # The wrapper, however, must already have each attribute to update:
        # a missing one raises AttributeError...
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign_names, update_names)
        # ...and so does a non-mapping value such as an int.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign_names, update_names)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin function.
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the functools.wraps() decorator.

    Tests not overridden here are inherited from TestUpdateWrapper.
    """

    def _default_update(self):
        """Return ``(wrapper, f)`` where ``wrapper`` is decorated with ``wraps(f)``."""
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expected = {
            '__name__': 'f',
            '__qualname__': f.__qualname__,
            'attr': 'This is also a test',
        }
        for attr, value in expected.items():
            self.assertEqual(getattr(wrapper, attr), value)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        no_attrs = ()

        @functools.wraps(f, no_attrs, no_attrs)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, no_attrs, no_attrs)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Give the decorated function an empty mapping for wraps() to update.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
- class TestReduce:
- def test_reduce(self):
- class Squares:
- def __init__(self, max):
- self.max = max
- self.sofar = []
- def __len__(self):
- return len(self.sofar)
- def __getitem__(self, i):
- if not 0 <= i < self.max: raise IndexError
- n = len(self.sofar)
- while n <= i:
- self.sofar.append(n*n)
- n += 1
- return self.sofar[i]
- def add(x, y):
- return x + y
- self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
- self.assertEqual(
- self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
- ['a','c','d','w']
- )
- self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
- self.assertEqual(
- self.reduce(lambda x, y: x*y, range(2,21), 1),
- 2432902008176640000
- )
- self.assertEqual(self.reduce(add, Squares(10)), 285)
- self.assertEqual(self.reduce(add, Squares(10), 0), 285)
- self.assertEqual(self.reduce(add, Squares(0), 0), 0)
- self.assertRaises(TypeError, self.reduce)
- self.assertRaises(TypeError, self.reduce, 42, 42)
- self.assertRaises(TypeError, self.reduce, 42, 42, 42)
- self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
- self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
- self.assertRaises(TypeError, self.reduce, 42, (42, 42))
- self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
- self.assertRaises(TypeError, self.reduce, add, "")
- self.assertRaises(TypeError, self.reduce, add, ())
- self.assertRaises(TypeError, self.reduce, add, object())
- class TestFailingIter:
- def __iter__(self):
- raise RuntimeError
- self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
- self.assertEqual(self.reduce(add, [], None), None)
- self.assertEqual(self.reduce(add, [], 42), 42)
- class BadSeq:
- def __getitem__(self, index):
- raise ValueError
- self.assertRaises(ValueError, self.reduce, 42, BadSeq())
- # Test reduce()'s use of iterators.
- def test_iterator_usage(self):
- class SequenceClass:
- def __init__(self, n):
- self.n = n
- def __getitem__(self, i):
- if 0 <= i < self.n:
- return i
- else:
- raise IndexError
- from operator import add
- self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
- self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
- self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
- self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
- self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
- self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
- d = {"one": 1, "two": 2, "three": 3}
- self.assertEqual(self.reduce(add, d), "".join(d.keys()))
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the TestReduce checks against the C ``_functools`` reduce()."""

    # The class is skipped when c_functools is unavailable; the guard keeps
    # the class body itself from failing in that case.
    # NOTE(review): unlike TestReducePy there is no staticmethod() wrapper,
    # presumably because the C function does not bind as a method — confirm.
    if c_functools:
        reduce = c_functools.reduce
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the TestReduce checks against ``py_functools.reduce``.

    ``py_functools`` is presumably the pure-Python functools module; confirm
    against its definition at the top of the file.
    """

    # staticmethod() stops self.reduce from receiving the test instance
    # as an implicit first argument.
    reduce = staticmethod(py_functools.reduce)
- class TestCmpToKey:
- def test_cmp_to_key(self):
- def cmp1(x, y):
- return (x > y) - (x < y)
- key = self.cmp_to_key(cmp1)
- self.assertEqual(key(3), key(3))
- self.assertGreater(key(3), key(1))
- self.assertGreaterEqual(key(3), key(3))
- def cmp2(x, y):
- return int(x) - int(y)
- key = self.cmp_to_key(cmp2)
- self.assertEqual(key(4.0), key('4'))
- self.assertLess(key(2), key('35'))
- self.assertLessEqual(key(2), key('35'))
- self.assertNotEqual(key(2), key('35'))
- def test_cmp_to_key_arguments(self):
- def cmp1(x, y):
- return (x > y) - (x < y)
- key = self.cmp_to_key(mycmp=cmp1)
- self.assertEqual(key(obj=3), key(obj=3))
- self.assertGreater(key(obj=3), key(obj=1))
- with self.assertRaises((TypeError, AttributeError)):
- key(3) > 1 # rhs is not a K object
- with self.assertRaises((TypeError, AttributeError)):
- 1 < key(3) # lhs is not a K object
- with self.assertRaises(TypeError):
- key = self.cmp_to_key() # too few args
- with self.assertRaises(TypeError):
- key = self.cmp_to_key(cmp1, None) # too many args
- key = self.cmp_to_key(cmp1)
- with self.assertRaises(TypeError):
- key() # too few args
- with self.assertRaises(TypeError):
- key(None, None) # too many args
- def test_bad_cmp(self):
- def cmp1(x, y):
- raise ZeroDivisionError
- key = self.cmp_to_key(cmp1)
- with self.assertRaises(ZeroDivisionError):
- key(3) > key(1)
- class BadCmp:
- def __lt__(self, other):
- raise ZeroDivisionError
- def cmp1(x, y):
- return BadCmp()
- with self.assertRaises(ZeroDivisionError):
- key(3) > key(1)
- def test_obj_field(self):
- def cmp1(x, y):
- return (x > y) - (x < y)
- key = self.cmp_to_key(mycmp=cmp1)
- self.assertEqual(key(50).obj, 50)
- def test_sort_int(self):
- def mycmp(x, y):
- return y - x
- self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
- [4, 3, 2, 1, 0])
- def test_sort_int_str(self):
- def mycmp(x, y):
- x, y = int(x), int(y)
- return (x > y) - (x < y)
- values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
- values = sorted(values, key=self.cmp_to_key(mycmp))
- self.assertEqual([int(value) for value in values],
- [0, 1, 1, 2, 3, 4, 5, 7, 10])
- def test_hash(self):
- def mycmp(x, y):
- return y - x
- key = self.cmp_to_key(mycmp)
- k = key(10)
- self.assertRaises(TypeError, hash, k)
- self.assertNotIsInstance(k, collections.abc.Hashable)
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey checks against the C ``_functools`` cmp_to_key()."""

    # Guarded so the class body does not fail when the C module is absent
    # (the class is skipped in that case).
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # The key type produced by cmp_to_key() must not be instantiable
        # directly (bpo-43916).
        key_type = type(c_functools.cmp_to_key(None))
        support.check_disallow_instantiation(self, key_type)
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey checks against ``py_functools.cmp_to_key``.

    ``py_functools`` is presumably the pure-Python functools module; confirm
    against its definition at the top of the file.
    """

    # staticmethod() stops self.cmp_to_key from receiving the test
    # instance as an implicit first argument.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
- class TestTotalOrdering(unittest.TestCase):
- def test_total_ordering_lt(self):
- @functools.total_ordering
- class A:
- def __init__(self, value):
- self.value = value
- def __lt__(self, other):
- return self.value < other.value
- def __eq__(self, other):
- return self.value == other.value
- self.assertTrue(A(1) < A(2))
- self.assertTrue(A(2) > A(1))
- self.assertTrue(A(1) <= A(2))
- self.assertTrue(A(2) >= A(1))
- self.assertTrue(A(2) <= A(2))
- self.assertTrue(A(2) >= A(2))
- self.assertFalse(A(1) > A(2))
- def test_total_ordering_le(self):
- @functools.total_ordering
- class A:
- def __init__(self, value):
- self.value = value
- def __le__(self, other):
- return self.value <= other.value
- def __eq__(self, other):
- return self.value == other.value
- self.assertTrue(A(1) < A(2))
- self.assertTrue(A(2) > A(1))
- self.assertTrue(A(1) <= A(2))
- self.assertTrue(A(2) >= A(1))
- self.assertTrue(A(2) <= A(2))
- self.assertTrue(A(2) >= A(2))
- self.assertFalse(A(1) >= A(2))
- def test_total_ordering_gt(self):
- @functools.total_ordering
- class A:
- def __init__(self, value):
- self.value = value
- def __gt__(self, other):
- return self.value > other.value
- def __eq__(self, other):
- return self.value == other.value
- self.assertTrue(A(1) < A(2))
- self.assertTrue(A(2) > A(1))
- self.assertTrue(A(1) <= A(2))
- self.assertTrue(A(2) >= A(1))
- self.assertTrue(A(2) <= A(2))
- self.assertTrue(A(2) >= A(2))
- self.assertFalse(A(2) < A(1))
- def test_total_ordering_ge(self):
- @functools.total_ordering
- class A:
- def __init__(self, value):
- self.value = value
- def __ge__(self, other):
- return self.value >= other.value
- def __eq__(self, other):
- return self.value == other.value
- self.assertTrue(A(1) < A(2))
- self.assertTrue(A(2) > A(1))
- self.assertTrue(A(1) <= A(2))
- self.assertTrue(A(2) >= A(1))
- self.assertTrue(A(2) <= A(2))
- self.assertTrue(A(2) >= A(2))
- self.assertFalse(A(2) <= A(1))
- def test_total_ordering_no_overwrite(self):
- # new methods should not overwrite existing
- @functools.total_ordering
- class A(int):
- pass
- self.assertTrue(A(1) < A(2))
- self.assertTrue(A(2) > A(1))
- self.assertTrue(A(1) <= A(2))
- self.assertTrue(A(2) >= A(1))
- self.assertTrue(A(2) <= A(2))
- self.assertTrue(A(2) >= A(2))
- def test_no_operations_defined(self):
- with self.assertRaises(ValueError):
- @functools.total_ordering
- class A:
- pass
- def test_notimplemented(self):
- # Verify NotImplemented results are correctly handled
- @functools.total_ordering
- class ImplementsLessThan:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsLessThan):
- return self.value == other.value
- return False
- def __lt__(self, other):
- if isinstance(other, ImplementsLessThan):
- return self.value < other.value
- return NotImplemented
- @functools.total_ordering
- class ImplementsLessThanEqualTo:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsLessThanEqualTo):
- return self.value == other.value
- return False
- def __le__(self, other):
- if isinstance(other, ImplementsLessThanEqualTo):
- return self.value <= other.value
- return NotImplemented
- @functools.total_ordering
- class ImplementsGreaterThan:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsGreaterThan):
- return self.value == other.value
- return False
- def __gt__(self, other):
- if isinstance(other, ImplementsGreaterThan):
- return self.value > other.value
- return NotImplemented
- @functools.total_ordering
- class ImplementsGreaterThanEqualTo:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsGreaterThanEqualTo):
- return self.value == other.value
- return False
- def __ge__(self, other):
- if isinstance(other, ImplementsGreaterThanEqualTo):
- return self.value >= other.value
- return NotImplemented
- self.assertIs(ImplementsLessThan(1).__le__(1), NotImplemented)
- self.assertIs(ImplementsLessThan(1).__gt__(1), NotImplemented)
- self.assertIs(ImplementsLessThan(1).__ge__(1), NotImplemented)
- self.assertIs(ImplementsLessThanEqualTo(1).__lt__(1), NotImplemented)
- self.assertIs(ImplementsLessThanEqualTo(1).__gt__(1), NotImplemented)
- self.assertIs(ImplementsLessThanEqualTo(1).__ge__(1), NotImplemented)
- self.assertIs(ImplementsGreaterThan(1).__lt__(1), NotImplemented)
- self.assertIs(ImplementsGreaterThan(1).__gt__(1), NotImplemented)
- self.assertIs(ImplementsGreaterThan(1).__ge__(1), NotImplemented)
- self.assertIs(ImplementsGreaterThanEqualTo(1).__lt__(1), NotImplemented)
- self.assertIs(ImplementsGreaterThanEqualTo(1).__le__(1), NotImplemented)
- self.assertIs(ImplementsGreaterThanEqualTo(1).__gt__(1), NotImplemented)
- def test_type_error_when_not_implemented(self):
- # bug 10042; ensure stack overflow does not occur
- # when decorated types return NotImplemented
- @functools.total_ordering
- class ImplementsLessThan:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsLessThan):
- return self.value == other.value
- return False
- def __lt__(self, other):
- if isinstance(other, ImplementsLessThan):
- return self.value < other.value
- return NotImplemented
- @functools.total_ordering
- class ImplementsGreaterThan:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsGreaterThan):
- return self.value == other.value
- return False
- def __gt__(self, other):
- if isinstance(other, ImplementsGreaterThan):
- return self.value > other.value
- return NotImplemented
- @functools.total_ordering
- class ImplementsLessThanEqualTo:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsLessThanEqualTo):
- return self.value == other.value
- return False
- def __le__(self, other):
- if isinstance(other, ImplementsLessThanEqualTo):
- return self.value <= other.value
- return NotImplemented
- @functools.total_ordering
- class ImplementsGreaterThanEqualTo:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ImplementsGreaterThanEqualTo):
- return self.value == other.value
- return False
- def __ge__(self, other):
- if isinstance(other, ImplementsGreaterThanEqualTo):
- return self.value >= other.value
- return NotImplemented
- @functools.total_ordering
- class ComparatorNotImplemented:
- def __init__(self, value):
- self.value = value
- def __eq__(self, other):
- if isinstance(other, ComparatorNotImplemented):
- return self.value == other.value
- return False
- def __lt__(self, other):
- return NotImplemented
- with self.subTest("LT < 1"), self.assertRaises(TypeError):
- ImplementsLessThan(-1) < 1
- with self.subTest("LT < LE"), self.assertRaises(TypeError):
- ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
- with self.subTest("LT < GT"), self.assertRaises(TypeError):
- ImplementsLessThan(1) < ImplementsGreaterThan(1)
- with self.subTest("LE <= LT"), self.assertRaises(TypeError):
- ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
- with self.subTest("LE <= GE"), self.assertRaises(TypeError):
- ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
- with self.subTest("GT > GE"), self.assertRaises(TypeError):
- ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
- with self.subTest("GT > LT"), self.assertRaises(TypeError):
- ImplementsGreaterThan(5) > ImplementsLessThan(5)
- with self.subTest("GE >= GT"), self.assertRaises(TypeError):
- ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
- with self.subTest("GE >= LE"), self.assertRaises(TypeError):
- ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
- with self.subTest("GE when equal"):
- a = ComparatorNotImplemented(8)
- b = ComparatorNotImplemented(8)
- self.assertEqual(a, b)
- with self.assertRaises(TypeError):
- a >= b
- with self.subTest("LE when equal"):
- a = ComparatorNotImplemented(9)
- b = ComparatorNotImplemented(9)
- self.assertEqual(a, b)
- with self.assertRaises(TypeError):
- a <= b
- def test_pickle(self):
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- for name in '__lt__', '__gt__', '__le__', '__ge__':
- with self.subTest(method=name, proto=proto):
- method = getattr(Orderable_LT, name)
- method_copy = pickle.loads(pickle.dumps(method, proto))
- self.assertIs(method_copy, method)
- def test_total_ordering_for_metaclasses_issue_44605(self):
- @functools.total_ordering
- class SortableMeta(type):
- def __new__(cls, name, bases, ns):
- return super().__new__(cls, name, bases, ns)
- def __lt__(self, other):
- if not isinstance(other, SortableMeta):
- pass
- return self.__name__ < other.__name__
- def __eq__(self, other):
- if not isinstance(other, SortableMeta):
- pass
- return self.__name__ == other.__name__
- class B(metaclass=SortableMeta):
- pass
- class A(metaclass=SortableMeta):
- pass
- self.assertTrue(A < B)
- self.assertFalse(A > B)
@functools.total_ordering
class Orderable_LT:
    """Minimal total_ordering class that defines only ``__eq__`` and ``__lt__``.

    It lives at module level because TestTotalOrdering.test_pickle pickles
    its generated comparison methods.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
- class TestCache:
- # This tests that the pass-through is working as designed.
- # The underlying functionality is tested in TestLRU.
- def test_cache(self):
- @self.module.cache
- def fib(n):
- if n < 2:
- return n
- return fib(n-1) + fib(n-2)
- self.assertEqual([fib(n) for n in range(16)],
- [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
- fib.cache_clear()
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
- class TestLRU:
- def test_lru(self):
- def orig(x, y):
- return 3 * x + y
- f = self.module.lru_cache(maxsize=20)(orig)
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertEqual(maxsize, 20)
- self.assertEqual(currsize, 0)
- self.assertEqual(hits, 0)
- self.assertEqual(misses, 0)
- domain = range(5)
- for i in range(1000):
- x, y = choice(domain), choice(domain)
- actual = f(x, y)
- expected = orig(x, y)
- self.assertEqual(actual, expected)
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertTrue(hits > misses)
- self.assertEqual(hits + misses, 1000)
- self.assertEqual(currsize, 20)
- f.cache_clear() # test clearing
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertEqual(hits, 0)
- self.assertEqual(misses, 0)
- self.assertEqual(currsize, 0)
- f(x, y)
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertEqual(hits, 0)
- self.assertEqual(misses, 1)
- self.assertEqual(currsize, 1)
- # Test bypassing the cache
- self.assertIs(f.__wrapped__, orig)
- f.__wrapped__(x, y)
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertEqual(hits, 0)
- self.assertEqual(misses, 1)
- self.assertEqual(currsize, 1)
- # test size zero (which means "never-cache")
- @self.module.lru_cache(0)
- def f():
- nonlocal f_cnt
- f_cnt += 1
- return 20
- self.assertEqual(f.cache_info().maxsize, 0)
- f_cnt = 0
- for i in range(5):
- self.assertEqual(f(), 20)
- self.assertEqual(f_cnt, 5)
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertEqual(hits, 0)
- self.assertEqual(misses, 5)
- self.assertEqual(currsize, 0)
- # test size one
- @self.module.lru_cache(1)
- def f():
- nonlocal f_cnt
- f_cnt += 1
- return 20
- self.assertEqual(f.cache_info().maxsize, 1)
- f_cnt = 0
- for i in range(5):
- self.assertEqual(f(), 20)
- self.assertEqual(f_cnt, 1)
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertEqual(hits, 4)
- self.assertEqual(misses, 1)
- self.assertEqual(currsize, 1)
- # test size two
- @self.module.lru_cache(2)
- def f(x):
- nonlocal f_cnt
- f_cnt += 1
- return x*10
- self.assertEqual(f.cache_info().maxsize, 2)
- f_cnt = 0
- for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
- # * * * *
- self.assertEqual(f(x), x*10)
- self.assertEqual(f_cnt, 4)
- hits, misses, maxsize, currsize = f.cache_info()
- self.assertEqual(hits, 12)
- self.assertEqual(misses, 4)
- self.assertEqual(currsize, 2)
- def test_lru_no_args(self):
- @self.module.lru_cache
- def square(x):
- return x ** 2
- self.assertEqual(list(map(square, [10, 20, 10])),
- [100, 400, 100])
- self.assertEqual(square.cache_info().hits, 1)
- self.assertEqual(square.cache_info().misses, 2)
- self.assertEqual(square.cache_info().maxsize, 128)
- self.assertEqual(square.cache_info().currsize, 2)
- def test_lru_bug_35780(self):
- # C version of the lru_cache was not checking to see if
- # the user function call has already modified the cache
- # (this arises in recursive calls and in multi-threading).
- # This cause the cache to have orphan links not referenced
- # by the cache dictionary.
- once = True # Modified by f(x) below
- @self.module.lru_cache(maxsize=10)
- def f(x):
- nonlocal once
- rv = f'.{x}.'
- if x == 20 and once:
- once = False
- rv = f(x)
- return rv
- # Fill the cache
- for x in range(15):
- self.assertEqual(f(x), f'.{x}.')
- self.assertEqual(f.cache_info().currsize, 10)
- # Make a recursive call and make sure the cache remains full
- self.assertEqual(f(20), '.20.')
- self.assertEqual(f.cache_info().currsize, 10)
- def test_lru_bug_36650(self):
- # C version of lru_cache was treating a call with an empty **kwargs
- # dictionary as being distinct from a call with no keywords at all.
- # This did not result in an incorrect answer, but it did trigger
- # an unexpected cache miss.
- @self.module.lru_cache()
- def f(x):
- pass
- f(0)
- f(0, **{})
- self.assertEqual(f.cache_info().hits, 1)
- def test_lru_hash_only_once(self):
- # To protect against weird reentrancy bugs and to improve
- # efficiency when faced with slow __hash__ methods, the
- # LRU cache guarantees that it will only call __hash__
- # only once per use as an argument to the cached function.
- @self.module.lru_cache(maxsize=1)
- def f(x, y):
- return x * 3 + y
- # Simulate the integer 5
- mock_int = unittest.mock.Mock()
- mock_int.__mul__ = unittest.mock.Mock(return_value=15)
- mock_int.__hash__ = unittest.mock.Mock(return_value=999)
- # Add to cache: One use as an argument gives one call
- self.assertEqual(f(mock_int, 1), 16)
- self.assertEqual(mock_int.__hash__.call_count, 1)
- self.assertEqual(f.cache_info(), (0, 1, 1, 1))
- # Cache hit: One use as an argument gives one additional call
- self.assertEqual(f(mock_int, 1), 16)
- self.assertEqual(mock_int.__hash__.call_count, 2)
- self.assertEqual(f.cache_info(), (1, 1, 1, 1))
- # Cache eviction: No use as an argument gives no additional call
- self.assertEqual(f(6, 2), 20)
- self.assertEqual(mock_int.__hash__.call_count, 2)
- self.assertEqual(f.cache_info(), (1, 2, 1, 1))
- # Cache miss: One use as an argument gives one additional call
- self.assertEqual(f(mock_int, 1), 16)
- self.assertEqual(mock_int.__hash__.call_count, 3)
- self.assertEqual(f.cache_info(), (1, 3, 1, 1))
- def test_lru_reentrancy_with_len(self):
- # Test to make sure the LRU cache code isn't thrown-off by
- # caching the built-in len() function. Since len() can be
- # cached, we shouldn't use it inside the lru code itself.
- old_len = builtins.len
- try:
- builtins.len = self.module.lru_cache(4)(len)
- for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
- self.assertEqual(len('abcdefghijklmn'[:i]), i)
- finally:
- builtins.len = old_len
- def test_lru_star_arg_handling(self):
- # Test regression that arose in ea064ff3c10f
- @self.module.lru_cache()
- def f(*args):
- return args
- self.assertEqual(f(1, 2), (1, 2))
- self.assertEqual(f((1, 2)), ((1, 2),))
- def test_lru_type_error(self):
- # Regression test for issue #28653.
- # lru_cache was leaking when one of the arguments
- # wasn't cacheable.
- @self.module.lru_cache(maxsize=None)
- def infinite_cache(o):
- pass
- @self.module.lru_cache(maxsize=10)
- def limited_cache(o):
- pass
- with self.assertRaises(TypeError):
- infinite_cache([])
- with self.assertRaises(TypeError):
- limited_cache([])
- def test_lru_with_maxsize_none(self):
- @self.module.lru_cache(maxsize=None)
- def fib(n):
- if n < 2:
- return n
- return fib(n-1) + fib(n-2)
- self.assertEqual([fib(n) for n in range(16)],
- [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
- fib.cache_clear()
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
- def test_lru_with_maxsize_negative(self):
- @self.module.lru_cache(maxsize=-10)
- def eq(n):
- return n
- for i in (0, 1):
- self.assertEqual([eq(n) for n in range(150)], list(range(150)))
- self.assertEqual(eq.cache_info(),
- self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
- def test_lru_with_exceptions(self):
- # Verify that user_function exceptions get passed through without
- # creating a hard-to-read chained exception.
- # http://bugs.python.org/issue13177
- for maxsize in (None, 128):
- @self.module.lru_cache(maxsize)
- def func(i):
- return 'abc'[i]
- self.assertEqual(func(0), 'a')
- with self.assertRaises(IndexError) as cm:
- func(15)
- self.assertIsNone(cm.exception.__context__)
- # Verify that the previous exception did not result in a cached entry
- with self.assertRaises(IndexError):
- func(15)
- def test_lru_with_types(self):
- for maxsize in (None, 128):
- @self.module.lru_cache(maxsize=maxsize, typed=True)
- def square(x):
- return x * x
- self.assertEqual(square(3), 9)
- self.assertEqual(type(square(3)), type(9))
- self.assertEqual(square(3.0), 9.0)
- self.assertEqual(type(square(3.0)), type(9.0))
- self.assertEqual(square(x=3), 9)
- self.assertEqual(type(square(x=3)), type(9))
- self.assertEqual(square(x=3.0), 9.0)
- self.assertEqual(type(square(x=3.0)), type(9.0))
- self.assertEqual(square.cache_info().hits, 4)
- self.assertEqual(square.cache_info().misses, 4)
- def test_lru_cache_typed_is_not_recursive(self):
- cached = self.module.lru_cache(typed=True)(repr)
- self.assertEqual(cached(1), '1')
- self.assertEqual(cached(True), 'True')
- self.assertEqual(cached(1.0), '1.0')
- self.assertEqual(cached(0), '0')
- self.assertEqual(cached(False), 'False')
- self.assertEqual(cached(0.0), '0.0')
- self.assertEqual(cached((1,)), '(1,)')
- self.assertEqual(cached((True,)), '(1,)')
- self.assertEqual(cached((1.0,)), '(1,)')
- self.assertEqual(cached((0,)), '(0,)')
- self.assertEqual(cached((False,)), '(0,)')
- self.assertEqual(cached((0.0,)), '(0,)')
- class T(tuple):
- pass
- self.assertEqual(cached(T((1,))), '(1,)')
- self.assertEqual(cached(T((True,))), '(1,)')
- self.assertEqual(cached(T((1.0,))), '(1,)')
- self.assertEqual(cached(T((0,))), '(0,)')
- self.assertEqual(cached(T((False,))), '(0,)')
- self.assertEqual(cached(T((0.0,))), '(0,)')
- def test_lru_with_keyword_args(self):
- @self.module.lru_cache()
- def fib(n):
- if n < 2:
- return n
- return fib(n=n-1) + fib(n=n-2)
- self.assertEqual(
- [fib(n=number) for number in range(16)],
- [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
- )
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
- fib.cache_clear()
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
- def test_lru_with_keyword_args_maxsize_none(self):
- @self.module.lru_cache(maxsize=None)
- def fib(n):
- if n < 2:
- return n
- return fib(n=n-1) + fib(n=n-2)
- self.assertEqual([fib(n=number) for number in range(16)],
- [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
- fib.cache_clear()
- self.assertEqual(fib.cache_info(),
- self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
- def test_kwargs_order(self):
- # PEP 468: Preserving Keyword Argument Order
- @self.module.lru_cache(maxsize=10)
- def f(**kwargs):
- return list(kwargs.items())
- self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
- self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
- self.assertEqual(f.cache_info(),
- self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
- def test_lru_cache_decoration(self):
- def f(zomg: 'zomg_annotation'):
- """f doc string"""
- return 42
- g = self.module.lru_cache()(f)
- for attr in self.module.WRAPPER_ASSIGNMENTS:
- self.assertEqual(getattr(g, attr), getattr(f, attr))
@threading_helper.requires_working_threading()
def test_lru_cache_threaded(self):
    # Fill the cache from several threads at once, then race cache_clear()
    # against further calls, using a tiny switch interval to force
    # frequent thread switches.
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y
    f = self.module.lru_cache(maxsize=n*m)(orig)
    self.assertEqual(f.cache_info().currsize, 0)

    start = threading.Event()

    def full(k):
        # After the start signal, call f with key k m times.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    orig_si = sys.getswitchinterval()
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        threads = [threading.Thread(target=full, args=[k]) for k in range(n)]
        with threading_helper.start_threads(threads):
            start.set()

        info = f.cache_info()
        if self.module is py_functools:
            # XXX: only upper bounds are checked for this implementation;
            # why its counts can fall short of exact is not established.
            self.assertLessEqual(info.misses, n)
            self.assertLessEqual(info.hits, m*n - info.misses)
        else:
            self.assertEqual(info.misses, n)
            self.assertEqual(info.hits, m*n - info.misses)
        self.assertEqual(info.currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        threads = [threading.Thread(target=clear)]
        threads += [threading.Thread(target=full, args=[k]) for k in range(n)]
        start.clear()
        with threading_helper.start_threads(threads):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
@threading_helper.requires_working_threading()
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments: n worker threads call f(i)
    # in lockstep with the main thread, so each round records n misses and
    # adds exactly one cache entry.
    n, m = 5, 7
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)

    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x

    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))

    def worker():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)

    threads = [threading.Thread(target=worker) for _ in range(n)]
    with threading_helper.start_threads(threads):
        for i in range(m):
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
@threading_helper.requires_working_threading()
def test_lru_cache_threaded3(self):
    # Five threads call a slow function with overlapping arguments through
    # a two-entry cache; every call must still return the right value.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def check(i, x):
        with self.subTest(thread=i):
            self.assertEqual(f(x), 3 * x, i)

    threads = [threading.Thread(target=check, args=(i, v))
               for i, v in enumerate([1, 2, 2, 3, 2])]
    with threading_helper.start_threads(threads):
        pass
def test_need_for_rlock(self):
    # This will deadlock on an LRU cache that uses a regular lock
    @self.module.lru_cache(maxsize=10)
    def test_func(x):
        """Identity function whose cache is re-entered from DoubleEq.__eq__."""
        return x

    class DoubleEq:
        """Key whose __eq__ calls test_func again when x == 2."""
        def __init__(self, x):
            self.x = x

        def __hash__(self):
            return self.x

        def __eq__(self, other):
            # Comparing a "2" key happens inside the cache's own lookup, so
            # this nested call re-enters the cache from the same thread.
            if self.x == 2:
                test_func(DoubleEq(1))
            return self.x == other.x

    # Load the cache with the keys 1 and 2.
    for seed in (1, 2):
        test_func(DoubleEq(seed))
    # Looking up 2 again triggers the re-entrant __eq__ call.
    result = test_func(DoubleEq(2))
    # Verify the correct return value.
    self.assertEqual(result, DoubleEq(2))
def test_lru_method(self):
    # lru_cache applied to a method keys on ``self`` as well as the argument,
    # and all instances share the single cache stored on the class.
    class X(int):
        f_cnt = 0
        @self.module.lru_cache(2)
        def f(self, x):
            self.f_cnt += 1
            return x*10+self

    a, b, c = X(5), X(5), X(7)
    self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

    # (instance, its int value, call arguments,
    #  expected (a, b, c) call counters, expected cache_info afterwards)
    # a and b compare equal, so b can hit entries cached by a.
    rounds = [
        (a, 5, (1, 2, 2, 3, 1, 1, 1, 2, 3, 3), (6, 0, 0), (4, 6, 2, 2)),
        (b, 5, (1, 2, 1, 1, 1, 1, 3, 2, 2, 2), (6, 4, 0), (10, 10, 2, 2)),
        (c, 7, (2, 1, 1, 1, 1, 2, 1, 3, 2, 1), (6, 4, 5), (15, 15, 2, 2)),
    ]
    for obj, base, args, counters, info in rounds:
        for arg in args:
            self.assertEqual(obj.f(arg), arg*10 + base)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), counters)
        self.assertEqual(X.f.cache_info(), info)

    # Bound methods expose the same shared cache statistics.
    for obj in (a, b, c):
        self.assertEqual(obj.f.cache_info(), X.f.cache_info())
def test_pickle(self):
    # A pickle round trip of each cached callable must yield the very same
    # object, for every pickle protocol.
    owner = self.__class__
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth)
    for func in candidates:
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto, func=func):
                restored = pickle.loads(pickle.dumps(func, proto))
                self.assertIs(restored, func)
def test_copy(self):
    # copy.copy() of a cached callable (function, method, staticmethod, or
    # cached partial) must return the original object itself.
    owner = self.__class__

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth, cached_partial)
    for func in candidates:
        with self.subTest(func=func):
            self.assertIs(copy.copy(func), func)
def test_deepcopy(self):
    # copy.deepcopy() of a cached callable (function, method, staticmethod,
    # or cached partial) must return the original object itself.
    owner = self.__class__

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth, cached_partial)
    for func in candidates:
        with self.subTest(func=func):
            duplicate = copy.deepcopy(func)
            self.assertIs(duplicate, func)
def test_lru_cache_parameters(self):
    # cache_parameters() echoes back the maxsize/typed settings.
    cases = [
        ({'maxsize': 2}, {'maxsize': 2, "typed": False}),
        ({'maxsize': 1000, 'typed': True}, {'maxsize': 1000, "typed": True}),
    ]
    for options, expected in cases:
        @self.module.lru_cache(**options)
        def f():
            return 1
        self.assertEqual(f.cache_parameters(), expected)
def test_lru_cache_weakrefable(self):
    # Cached functions, methods and staticmethods accept weak references,
    # and those references go dead once the owners are collected.
    cache = self.module.lru_cache

    @cache
    def test_function(x):
        return x

    class A:
        @cache
        def test_method(self, x):
            return (self, x)

        @staticmethod
        @cache
        def test_staticmethod(x):
            return (self, x)

    # Build the weakrefs without keeping any extra strong reference around.
    refs = [weakref.ref(obj)
            for obj in (test_function, A.test_method, A.test_staticmethod)]
    for ref in refs:
        self.assertIsNotNone(ref())

    del A
    del test_function
    gc.collect()
    for ref in refs:
        self.assertIsNone(ref())
# Module-level cached functions: test_pickle expects pickling them to
# round-trip to these exact objects (see TestLRUPy / TestLRUC.cached_func).
@py_functools.lru_cache()
def py_cached_func(x, y):
    result = 3 * x + y
    return result

@c_functools.lru_cache()
def c_cached_func(x, y):
    result = 3 * x + y
    return result
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the TestLRU checks against ``py_functools``."""

    module = py_functools
    cached_func = (py_cached_func,)

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the TestLRU checks against ``c_functools``."""

    module = c_functools
    cached_func = (c_cached_func,)

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and functools.singledispatchmethod."""

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch walks the argument type's MRO: D(C, B) has MRO
        # D, C, B, A, object, so B's implementation wins over A's.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        # g.dispatch(object) is the undecorated base implementation, not the
        # object bound to the name g: @singledispatch returns a wrapper.
        self.assertIs(g.dispatch(object), g.dispatch(str))

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running with -OO (optimize level 2).
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch on exception classes from the decimal module.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        # Every loop therefore feeds each permutation (``haystack``) to
        # _compose_mro, so the assertions actually check that claim.
        c = collections.abc
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
        for haystack in permutations(bases):
            m = mro(collections.ChainMap, haystack)
            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(collections.defaultdict, haystack)
            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                                 c.Container, object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(collections.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(collections.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # Register progressively more specific ABCs and then concrete types;
        # each step may only change dispatch for the objects it is more
        # specific for.
        c = collections.abc
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(collections.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro places each supplied ABC next to the base that implies it:
        # Callable (X.__call__), Container (registered on C), Sized (B.__len__).
        c = collections.abc
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # MetaA gives the *classes* a __len__; dispatching on an AA instance
        # must still find the implementation registered for A.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict
        import weakref

        # A dict that records every read (get_ops) and write (set_ops).  It
        # replaces weakref.WeakKeyDictionary while g is created, so the test
        # can watch how singledispatch fills and clears its dispatch cache.
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering an implementation clears the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)

    def test_annotations(self):
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register
        def _(arg: collections.abc.Mapping):
            return "mapping"
        @i.register
        def _(arg: "collections.abc.Sequence"):
            return "sequence"
        self.assertEqual(i(None), "base")
        self.assertEqual(i({"a": 1}), "mapping")
        self.assertEqual(i([1, 2, 3]), "sequence")
        self.assertEqual(i((1, 2, 3)), "sequence")
        self.assertEqual(i("str"), "sequence")

        # Registering classes as callables doesn't work with annotations,
        # you need to pass the type explicitly.
        @i.register(str)
        class _:
            def __init__(self, arg):
                self.arg = arg

            def __eq__(self, other):
                return self.arg == other
        self.assertEqual(i("str"), "str")

    def test_method_register(self):
        # Each dispatched call must affect only the instance it was made on.
        class A:
            @functools.singledispatchmethod
            def t(self, arg):
                self.arg = "base"
            @t.register(int)
            def _(self, arg):
                self.arg = "int"
            @t.register(str)
            def _(self, arg):
                self.arg = "str"
        a = A()

        a.t(0)
        self.assertEqual(a.arg, "int")
        aa = A()
        self.assertFalse(hasattr(aa, 'arg'))
        a.t('')
        self.assertEqual(a.arg, "str")
        aa = A()
        self.assertFalse(hasattr(aa, 'arg'))
        a.t(0.0)
        self.assertEqual(a.arg, "base")
        aa = A()
        self.assertFalse(hasattr(aa, 'arg'))

    def test_staticmethod_register(self):
        class A:
            @functools.singledispatchmethod
            @staticmethod
            def t(arg):
                return arg
            @t.register(int)
            @staticmethod
            def _(arg):
                return isinstance(arg, int)
            @t.register(str)
            @staticmethod
            def _(arg):
                return isinstance(arg, str)
        self.assertTrue(A.t(0))
        self.assertTrue(A.t(''))
        self.assertEqual(A.t(0.0), 0.0)

    def test_classmethod_register(self):
        class A:
            def __init__(self, arg):
                self.arg = arg

            @functools.singledispatchmethod
            @classmethod
            def t(cls, arg):
                return cls("base")
            @t.register(int)
            @classmethod
            def _(cls, arg):
                return cls("int")
            @t.register(str)
            @classmethod
            def _(cls, arg):
                return cls("str")

        self.assertEqual(A.t(0).arg, "int")
        self.assertEqual(A.t('').arg, "str")
        self.assertEqual(A.t(0.0).arg, "base")

    def test_callable_register(self):
        # Implementations can also be registered after the class body,
        # through the class attribute.
        class A:
            def __init__(self, arg):
                self.arg = arg

            @functools.singledispatchmethod
            @classmethod
            def t(cls, arg):
                return cls("base")

        @A.t.register(int)
        @classmethod
        def _(cls, arg):
            return cls("int")
        @A.t.register(str)
        @classmethod
        def _(cls, arg):
            return cls("str")

        self.assertEqual(A.t(0).arg, "int")
        self.assertEqual(A.t('').arg, "str")
        self.assertEqual(A.t(0.0).arg, "base")

    def test_abstractmethod_register(self):
        class Abstract(metaclass=abc.ABCMeta):

            @functools.singledispatchmethod
            @abc.abstractmethod
            def add(self, x, y):
                pass

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)

        with self.assertRaises(TypeError):
            Abstract()

    def test_type_ann_register(self):
        class A:
            @functools.singledispatchmethod
            def t(self, arg):
                return "base"
            @t.register
            def _(self, arg: int):
                return "int"
            @t.register
            def _(self, arg: str):
                return "str"
        a = A()

        self.assertEqual(a.t(0), "int")
        self.assertEqual(a.t(''), "str")
        self.assertEqual(a.t(0.0), "base")

    def test_staticmethod_type_ann_register(self):
        class A:
            @functools.singledispatchmethod
            @staticmethod
            def t(arg):
                return arg
            @t.register
            @staticmethod
            def _(arg: int):
                return isinstance(arg, int)
            @t.register
            @staticmethod
            def _(arg: str):
                return isinstance(arg, str)
        self.assertTrue(A.t(0))
        self.assertTrue(A.t(''))
        self.assertEqual(A.t(0.0), 0.0)

    def test_classmethod_type_ann_register(self):
        class A:
            def __init__(self, arg):
                self.arg = arg

            @functools.singledispatchmethod
            @classmethod
            def t(cls, arg):
                return cls("base")
            @t.register
            @classmethod
            def _(cls, arg: int):
                return cls("int")
            @t.register
            @classmethod
            def _(cls, arg: str):
                return cls("str")

        self.assertEqual(A.t(0).arg, "int")
        self.assertEqual(A.t('').arg, "str")
        self.assertEqual(A.t(0.0).arg, "base")

    def test_method_wrapping_attributes(self):
        # __doc__, __annotations__ and __name__ must survive the wrapping for
        # plain, class and static methods, accessed via class and instance.
        class A:
            @functools.singledispatchmethod
            def func(self, arg: int) -> str:
                """My function docstring"""
                return str(arg)
            @functools.singledispatchmethod
            @classmethod
            def cls_func(cls, arg: int) -> str:
                """My function docstring"""
                return str(arg)
            @functools.singledispatchmethod
            @staticmethod
            def static_func(arg: int) -> str:
                """My function docstring"""
                return str(arg)

        for meth in (
            A.func,
            A().func,
            A.cls_func,
            A().cls_func,
            A.static_func,
            A().static_func
        ):
            with self.subTest(meth=meth):
                self.assertEqual(meth.__doc__, 'My function docstring')
                self.assertEqual(meth.__annotations__['arg'], int)

        self.assertEqual(A.func.__name__, 'func')
        self.assertEqual(A().func.__name__, 'func')
        self.assertEqual(A.cls_func.__name__, 'cls_func')
        self.assertEqual(A().cls_func.__name__, 'cls_func')
        self.assertEqual(A.static_func.__name__, 'static_func')
        self.assertEqual(A().static_func.__name__, 'static_func')

    def test_double_wrapped_methods(self):
        # Unwraps a classmethod and re-wraps the underlying function with
        # functools.wraps inside a new classmethod.
        def classmethod_friendly_decorator(func):
            wrapped = func.__func__
            @classmethod
            @functools.wraps(wrapped)
            def wrapper(*args, **kwargs):
                return wrapped(*args, **kwargs)
            return wrapper

        class WithoutSingleDispatch:
            @classmethod
            @contextlib.contextmanager
            def cls_context_manager(cls, arg: int) -> str:
                try:
                    yield str(arg)
                finally:
                    return 'Done'

            @classmethod_friendly_decorator
            @classmethod
            def decorated_classmethod(cls, arg: int) -> str:
                return str(arg)

        class WithSingleDispatch:
            @functools.singledispatchmethod
            @classmethod
            @contextlib.contextmanager
            def cls_context_manager(cls, arg: int) -> str:
                """My function docstring"""
                try:
                    yield str(arg)
                finally:
                    return 'Done'

            @functools.singledispatchmethod
            @classmethod_friendly_decorator
            @classmethod
            def decorated_classmethod(cls, arg: int) -> str:
                """My function docstring"""
                return str(arg)

        # These are sanity checks
        # to test the test itself is working as expected
        with WithoutSingleDispatch.cls_context_manager(5) as foo:
            without_single_dispatch_foo = foo

        with WithSingleDispatch.cls_context_manager(5) as foo:
            single_dispatch_foo = foo

        self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
        self.assertEqual(single_dispatch_foo, '5')

        self.assertEqual(
            WithoutSingleDispatch.decorated_classmethod(5),
            WithSingleDispatch.decorated_classmethod(5)
        )

        self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')

        # Behavioural checks now follow
        for method_name in ('cls_context_manager', 'decorated_classmethod'):
            with self.subTest(method=method_name):
                self.assertEqual(
                    getattr(WithSingleDispatch, method_name).__name__,
                    getattr(WithoutSingleDispatch, method_name).__name__
                )

                self.assertEqual(
                    getattr(WithSingleDispatch(), method_name).__name__,
                    getattr(WithoutSingleDispatch(), method_name).__name__
                )

        for meth in (
            WithSingleDispatch.cls_context_manager,
            WithSingleDispatch().cls_context_manager,
            WithSingleDispatch.decorated_classmethod,
            WithSingleDispatch().decorated_classmethod
        ):
            with self.subTest(meth=meth):
                self.assertEqual(meth.__doc__, 'My function docstring')
                self.assertEqual(meth.__annotations__['arg'], int)

        self.assertEqual(
            WithSingleDispatch.cls_context_manager.__name__,
            'cls_context_manager'
        )
        self.assertEqual(
            WithSingleDispatch().cls_context_manager.__name__,
            'cls_context_manager'
        )
        self.assertEqual(
            WithSingleDispatch.decorated_classmethod.__name__,
            'decorated_classmethod'
        )
        self.assertEqual(
            WithSingleDispatch().decorated_classmethod.__name__,
            'decorated_classmethod'
        )

    def test_invalid_registrations(self):
        msg_prefix = "Invalid first argument to `register()`: "
        msg_suffix = (
            ". Use either `@register(some_class)` or plain `@register` on an "
            "annotated function."
        )
        @functools.singledispatch
        def i(arg):
            return "base"
        with self.assertRaises(TypeError) as exc:
            @i.register(42)
            def _(arg):
                return "I annotated with a non-type"
        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))
        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg):
                return "I forgot to annotate"
        self.assertTrue(str(exc.exception).startswith(msg_prefix +
            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
        ))
        self.assertTrue(str(exc.exception).endswith(msg_suffix))

        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg: typing.Iterable[str]):
                # At runtime, dispatching on generics is impossible.
                # When registering implementations with singledispatch, avoid
                # types from `typing`. Instead, annotate with regular types
                # or ABCs.
                return "I annotated with a generic collection"
        self.assertTrue(str(exc.exception).startswith(
            "Invalid annotation for 'arg'."
        ))
        self.assertTrue(str(exc.exception).endswith(
            'typing.Iterable[str] is not a class.'
        ))

        with self.assertRaises(TypeError) as exc:
            @i.register
            def _(arg: typing.Union[int, typing.Iterable[str]]):
                return "Invalid Union"
        self.assertTrue(str(exc.exception).startswith(
            "Invalid annotation for 'arg'."
        ))
        self.assertTrue(str(exc.exception).endswith(
            'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
        ))

    def test_invalid_positional_argument(self):
        @functools.singledispatch
        def f(*args):
            pass
        msg = 'f requires at least 1 positional argument'
        with self.assertRaisesRegex(TypeError, msg):
            f()

    def test_union(self):
        @functools.singledispatch
        def f(arg):
            return "default"

        @f.register
        def _(arg: typing.Union[str, bytes]):
            return "typing.Union"

        @f.register
        def _(arg: int | float):
            return "types.UnionType"

        self.assertEqual(f([]), "default")
        self.assertEqual(f(""), "typing.Union")
        self.assertEqual(f(b""), "typing.Union")
        self.assertEqual(f(1), "types.UnionType")
        self.assertEqual(f(1.0), "types.UnionType")

    def test_union_conflict(self):
        @functools.singledispatch
        def f(arg):
            return "default"

        @f.register
        def _(arg: typing.Union[str, bytes]):
            return "typing.Union"

        @f.register
        def _(arg: int | str):
            return "types.UnionType"

        self.assertEqual(f([]), "default")
        self.assertEqual(f(""), "types.UnionType")  # last one wins
        self.assertEqual(f(b""), "typing.Union")
        self.assertEqual(f(1), "types.UnionType")

    def test_union_None(self):
        @functools.singledispatch
        def typing_union(arg):
            return "default"

        @typing_union.register
        def _(arg: typing.Union[str, None]):
            return "typing.Union"

        self.assertEqual(typing_union(1), "default")
        self.assertEqual(typing_union(""), "typing.Union")
        self.assertEqual(typing_union(None), "typing.Union")

        @functools.singledispatch
        def types_union(arg):
            return "default"

        @types_union.register
        def _(arg: int | None):
            return "types.UnionType"

        self.assertEqual(types_union(""), "default")
        self.assertEqual(types_union(1), "types.UnionType")
        self.assertEqual(types_union(None), "types.UnionType")

    def test_register_genericalias(self):
        @functools.singledispatch
        def f(arg):
            return "default"

        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(list[int], lambda arg: "types.GenericAlias")
        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(typing.List[int], lambda arg: "typing.GenericAlias")
        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")

        # The failed registrations must leave dispatch untouched.
        self.assertEqual(f([1]), "default")
        self.assertEqual(f([1.0]), "default")
        self.assertEqual(f(""), "default")
        self.assertEqual(f(b""), "default")

    def test_register_genericalias_decorator(self):
        @functools.singledispatch
        def f(arg):
            return "default"

        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(list[int])
        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(typing.List[int])
        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(list[int] | str)
        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
            f.register(typing.List[int] | str)

    def test_register_genericalias_annotation(self):
        @functools.singledispatch
        def f(arg):
            return "default"

        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
            @f.register
            def _(arg: list[int]):
                return "types.GenericAlias"
        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
            @f.register
            def _(arg: typing.List[float]):
                return "typing.GenericAlias"
        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
            @f.register
            def _(arg: list[int] | str):
                return "types.UnionType(types.GenericAlias)"
        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
            @f.register
            def _(arg: typing.List[float] | bytes):
                return "typing.Union[typing.GenericAlias]"

        # The failed registrations must leave dispatch untouched.
        self.assertEqual(f([1]), "default")
        self.assertEqual(f([1.0]), "default")
        self.assertEqual(f(""), "default")
        self.assertEqual(f(b""), "default")
class CachedCostItem:
    """Item whose ``cost`` is bumped once under a lock and then cached."""

    _cost = 1

    def __init__(self):
        # Reentrant lock guarding the increment performed by ``cost``.
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
class OptionallyCachedCostItem:
    """Expose one cost computation both uncached and cached.

    ``get_cost()`` recomputes on every call. ``cached_cost`` wraps the same
    function in a ``cached_property`` stored under a different attribute
    name than the function's own.
    """

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        bumped = self._cost + 1
        self._cost = bumped
        return bumped

    cached_cost = py_functools.cached_property(get_cost)
class CachedCostItemWait:
    """Like a locked cached cost item, but ``cost`` first waits on an event.

    The property body blocks for up to one second on ``event`` before it
    takes the lock and increments the counter.
    """

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Hold the caller until the event is set (or one second elapses).
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
class CachedCostItemWithSlots:
    """Slotted class (no per-instance ``__dict__``) carrying a cached_property.

    Accessing ``cost`` is expected to fail before the wrapped function ever
    runs, because there is no instance ``__dict__`` to cache into.
    """

    # Written as a one-element tuple; ``('_cost')`` would be a plain string,
    # which Python also accepts as a single slot name but reads like a tuple.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
class TestCachedProperty(unittest.TestCase):
    """Behavioural tests for the pure-Python ``cached_property``."""

    def test_cached(self):
        item = CachedCostItem()
        # First access computes 1 -> 2; the second must hit the cache (not 3).
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        obj = OptionallyCachedCostItem()
        self.assertEqual(obj.get_cost(), 2)
        self.assertEqual(obj.cached_cost, 3)  # first access: computed and stored
        self.assertEqual(obj.get_cost(), 4)   # uncached path keeps counting
        self.assertEqual(obj.cached_cost, 3)  # cached value is unchanged

    @threading_helper.requires_working_threading()
    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        n_threads = 3

        saved_interval = sys.getswitchinterval()
        # A tiny switch interval encourages the threads to interleave.
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(n_threads)]
            with threading_helper.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Starting from 1, a value of 2 means the body ran exactly once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # The "instance" here is the class MyClass; its __dict__ does not
        # accept item assignment, so the value cannot be cached.
        expected = "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Reject one cached_property bound to two names: ``a`` would never be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError raised by __set_name__ is the RuntimeError's context.
        actual = str(ctx.exception.__context__)
        expected = str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
        self.assertEqual(actual, expected)

    def test_reuse_same_name(self):
        """Sharing one cached_property across classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)  # each instance keeps its own cached value

    def test_set_name_not_called(self):
        prop = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attached after class creation, so __set_name__ never ran for it.
        Foo.cp = prop
        with self.assertRaisesRegex(
            TypeError,
            "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        descriptor = CachedCostItem.cost
        self.assertEqual(descriptor.__doc__, "The cost of the item.")
if __name__ == "__main__":
    # Running this module directly hands control to the unittest runner.
    unittest.main()
|