# test_named_expressions.py
import unittest

# Module-level name used by the global-scope tests in NamedExpressionScopeTest:
# one test rebinds it through ``global`` and resets it to None afterwards,
# another asserts that it is still None.
GLOBAL_VAR = None
  3. class NamedExpressionInvalidTest(unittest.TestCase):
  4. def test_named_expression_invalid_01(self):
  5. code = """x := 0"""
  6. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  7. exec(code, {}, {})
  8. def test_named_expression_invalid_02(self):
  9. code = """x = y := 0"""
  10. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  11. exec(code, {}, {})
  12. def test_named_expression_invalid_03(self):
  13. code = """y := f(x)"""
  14. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  15. exec(code, {}, {})
  16. def test_named_expression_invalid_04(self):
  17. code = """y0 = y1 := f(x)"""
  18. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  19. exec(code, {}, {})
  20. def test_named_expression_invalid_06(self):
  21. code = """((a, b) := (1, 2))"""
  22. with self.assertRaisesRegex(SyntaxError, "cannot use assignment expressions with tuple"):
  23. exec(code, {}, {})
  24. def test_named_expression_invalid_07(self):
  25. code = """def spam(a = b := 42): pass"""
  26. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  27. exec(code, {}, {})
  28. def test_named_expression_invalid_08(self):
  29. code = """def spam(a: b := 42 = 5): pass"""
  30. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  31. exec(code, {}, {})
  32. def test_named_expression_invalid_09(self):
  33. code = """spam(a=b := 'c')"""
  34. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  35. exec(code, {}, {})
  36. def test_named_expression_invalid_10(self):
  37. code = """spam(x = y := f(x))"""
  38. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  39. exec(code, {}, {})
  40. def test_named_expression_invalid_11(self):
  41. code = """spam(a=1, b := 2)"""
  42. with self.assertRaisesRegex(SyntaxError,
  43. "positional argument follows keyword argument"):
  44. exec(code, {}, {})
  45. def test_named_expression_invalid_12(self):
  46. code = """spam(a=1, (b := 2))"""
  47. with self.assertRaisesRegex(SyntaxError,
  48. "positional argument follows keyword argument"):
  49. exec(code, {}, {})
  50. def test_named_expression_invalid_13(self):
  51. code = """spam(a=1, (b := 2))"""
  52. with self.assertRaisesRegex(SyntaxError,
  53. "positional argument follows keyword argument"):
  54. exec(code, {}, {})
  55. def test_named_expression_invalid_14(self):
  56. code = """(x := lambda: y := 1)"""
  57. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  58. exec(code, {}, {})
  59. def test_named_expression_invalid_15(self):
  60. code = """(lambda: x := 1)"""
  61. with self.assertRaisesRegex(SyntaxError,
  62. "cannot use assignment expressions with lambda"):
  63. exec(code, {}, {})
  64. def test_named_expression_invalid_16(self):
  65. code = "[i + 1 for i in i := [1,2]]"
  66. with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
  67. exec(code, {}, {})
  68. def test_named_expression_invalid_17(self):
  69. code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]"
  70. with self.assertRaisesRegex(SyntaxError,
  71. "did you forget parentheses around the comprehension target?"):
  72. exec(code, {}, {})
  73. def test_named_expression_invalid_in_class_body(self):
  74. code = """class Foo():
  75. [(42, 1 + ((( j := i )))) for i in range(5)]
  76. """
  77. with self.assertRaisesRegex(SyntaxError,
  78. "assignment expression within a comprehension cannot be used in a class body"):
  79. exec(code, {}, {})
  80. def test_named_expression_invalid_rebinding_list_comprehension_iteration_variable(self):
  81. cases = [
  82. ("Local reuse", 'i', "[i := 0 for i in range(5)]"),
  83. ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"),
  84. ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"),
  85. ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"),
  86. ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"),
  87. ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"),
  88. ("Unreachable nested reuse", 'i',
  89. "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"),
  90. ]
  91. for case, target, code in cases:
  92. msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
  93. with self.subTest(case=case):
  94. with self.assertRaisesRegex(SyntaxError, msg):
  95. exec(code, {}, {})
  96. def test_named_expression_invalid_rebinding_list_comprehension_inner_loop(self):
  97. cases = [
  98. ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"),
  99. ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"),
  100. ]
  101. for case, target, code in cases:
  102. msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
  103. with self.subTest(case=case):
  104. with self.assertRaisesRegex(SyntaxError, msg):
  105. exec(code, {}) # Module scope
  106. with self.assertRaisesRegex(SyntaxError, msg):
  107. exec(code, {}, {}) # Class scope
  108. with self.assertRaisesRegex(SyntaxError, msg):
  109. exec(f"lambda: {code}", {}) # Function scope
  110. def test_named_expression_invalid_list_comprehension_iterable_expression(self):
  111. cases = [
  112. ("Top level", "[i for i in (i := range(5))]"),
  113. ("Inside tuple", "[i for i in (2, 3, i := range(5))]"),
  114. ("Inside list", "[i for i in [2, 3, i := range(5)]]"),
  115. ("Different name", "[i for i in (j := range(5))]"),
  116. ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"),
  117. ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"),
  118. ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"),
  119. ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"),
  120. ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"),
  121. ]
  122. msg = "assignment expression cannot be used in a comprehension iterable expression"
  123. for case, code in cases:
  124. with self.subTest(case=case):
  125. with self.assertRaisesRegex(SyntaxError, msg):
  126. exec(code, {}) # Module scope
  127. with self.assertRaisesRegex(SyntaxError, msg):
  128. exec(code, {}, {}) # Class scope
  129. with self.assertRaisesRegex(SyntaxError, msg):
  130. exec(f"lambda: {code}", {}) # Function scope
  131. def test_named_expression_invalid_rebinding_set_comprehension_iteration_variable(self):
  132. cases = [
  133. ("Local reuse", 'i', "{i := 0 for i in range(5)}"),
  134. ("Nested reuse", 'j', "{{(j := 0) for i in range(5)} for j in range(5)}"),
  135. ("Reuse inner loop target", 'j', "{(j := 0) for i in range(5) for j in range(5)}"),
  136. ("Unpacking reuse", 'i', "{i := 0 for i, j in {(0, 1)}}"),
  137. ("Reuse in loop condition", 'i', "{i+1 for i in range(5) if (i := 0)}"),
  138. ("Unreachable reuse", 'i', "{False or (i:=0) for i in range(5)}"),
  139. ("Unreachable nested reuse", 'i',
  140. "{(i, j) for i in range(5) for j in range(5) if True or (i:=10)}"),
  141. ]
  142. for case, target, code in cases:
  143. msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
  144. with self.subTest(case=case):
  145. with self.assertRaisesRegex(SyntaxError, msg):
  146. exec(code, {}, {})
  147. def test_named_expression_invalid_rebinding_set_comprehension_inner_loop(self):
  148. cases = [
  149. ("Inner reuse", 'j', "{i for i in range(5) if (j := 0) for j in range(5)}"),
  150. ("Inner unpacking reuse", 'j', "{i for i in range(5) if (j := 0) for j, k in {(0, 1)}}"),
  151. ]
  152. for case, target, code in cases:
  153. msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
  154. with self.subTest(case=case):
  155. with self.assertRaisesRegex(SyntaxError, msg):
  156. exec(code, {}) # Module scope
  157. with self.assertRaisesRegex(SyntaxError, msg):
  158. exec(code, {}, {}) # Class scope
  159. with self.assertRaisesRegex(SyntaxError, msg):
  160. exec(f"lambda: {code}", {}) # Function scope
  161. def test_named_expression_invalid_set_comprehension_iterable_expression(self):
  162. cases = [
  163. ("Top level", "{i for i in (i := range(5))}"),
  164. ("Inside tuple", "{i for i in (2, 3, i := range(5))}"),
  165. ("Inside list", "{i for i in {2, 3, i := range(5)}}"),
  166. ("Different name", "{i for i in (j := range(5))}"),
  167. ("Lambda expression", "{i for i in (lambda:(j := range(5)))()}"),
  168. ("Inner loop", "{i for i in range(5) for j in (i := range(5))}"),
  169. ("Nested comprehension", "{i for i in {j for j in (k := range(5))}}"),
  170. ("Nested comprehension condition", "{i for i in {j for j in range(5) if (j := True)}}"),
  171. ("Nested comprehension body", "{i for i in {(j := True) for j in range(5)}}"),
  172. ]
  173. msg = "assignment expression cannot be used in a comprehension iterable expression"
  174. for case, code in cases:
  175. with self.subTest(case=case):
  176. with self.assertRaisesRegex(SyntaxError, msg):
  177. exec(code, {}) # Module scope
  178. with self.assertRaisesRegex(SyntaxError, msg):
  179. exec(code, {}, {}) # Class scope
  180. with self.assertRaisesRegex(SyntaxError, msg):
  181. exec(f"lambda: {code}", {}) # Function scope
  182. class NamedExpressionAssignmentTest(unittest.TestCase):
  183. def test_named_expression_assignment_01(self):
  184. (a := 10)
  185. self.assertEqual(a, 10)
  186. def test_named_expression_assignment_02(self):
  187. a = 20
  188. (a := a)
  189. self.assertEqual(a, 20)
  190. def test_named_expression_assignment_03(self):
  191. (total := 1 + 2)
  192. self.assertEqual(total, 3)
  193. def test_named_expression_assignment_04(self):
  194. (info := (1, 2, 3))
  195. self.assertEqual(info, (1, 2, 3))
  196. def test_named_expression_assignment_05(self):
  197. (x := 1, 2)
  198. self.assertEqual(x, 1)
  199. def test_named_expression_assignment_06(self):
  200. (z := (y := (x := 0)))
  201. self.assertEqual(x, 0)
  202. self.assertEqual(y, 0)
  203. self.assertEqual(z, 0)
  204. def test_named_expression_assignment_07(self):
  205. (loc := (1, 2))
  206. self.assertEqual(loc, (1, 2))
  207. def test_named_expression_assignment_08(self):
  208. if spam := "eggs":
  209. self.assertEqual(spam, "eggs")
  210. else: self.fail("variable was not assigned using named expression")
  211. def test_named_expression_assignment_09(self):
  212. if True and (spam := True):
  213. self.assertTrue(spam)
  214. else: self.fail("variable was not assigned using named expression")
  215. def test_named_expression_assignment_10(self):
  216. if (match := 10) == 10:
  217. pass
  218. else: self.fail("variable was not assigned using named expression")
  219. def test_named_expression_assignment_11(self):
  220. def spam(a):
  221. return a
  222. input_data = [1, 2, 3]
  223. res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]
  224. self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
  225. def test_named_expression_assignment_12(self):
  226. def spam(a):
  227. return a
  228. res = [[y := spam(x), x/y] for x in range(1, 5)]
  229. self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])
  230. def test_named_expression_assignment_13(self):
  231. length = len(lines := [1, 2])
  232. self.assertEqual(length, 2)
  233. self.assertEqual(lines, [1,2])
  234. def test_named_expression_assignment_14(self):
  235. """
  236. Where all variables are positive integers, and a is at least as large
  237. as the n'th root of x, this algorithm returns the floor of the n'th
  238. root of x (and roughly doubling the number of accurate bits per
  239. iteration):
  240. """
  241. a = 9
  242. n = 2
  243. x = 3
  244. while a > (d := x // a**(n-1)):
  245. a = ((n-1)*a + d) // n
  246. self.assertEqual(a, 1)
  247. def test_named_expression_assignment_15(self):
  248. while a := False:
  249. pass # This will not run
  250. self.assertEqual(a, False)
  251. def test_named_expression_assignment_16(self):
  252. a, b = 1, 2
  253. fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)}
  254. self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21})
  255. def test_named_expression_assignment_17(self):
  256. a = [1]
  257. element = a[b:=0]
  258. self.assertEqual(b, 0)
  259. self.assertEqual(element, a[0])
  260. def test_named_expression_assignment_18(self):
  261. class TwoDimensionalList:
  262. def __init__(self, two_dimensional_list):
  263. self.two_dimensional_list = two_dimensional_list
  264. def __getitem__(self, index):
  265. return self.two_dimensional_list[index[0]][index[1]]
  266. a = TwoDimensionalList([[1], [2]])
  267. element = a[b:=0, c:=0]
  268. self.assertEqual(b, 0)
  269. self.assertEqual(c, 0)
  270. self.assertEqual(element, a.two_dimensional_list[b][c])
  271. class NamedExpressionScopeTest(unittest.TestCase):
  272. def test_named_expression_scope_01(self):
  273. code = """def spam():
  274. (a := 5)
  275. print(a)"""
  276. with self.assertRaisesRegex(NameError, "name 'a' is not defined"):
  277. exec(code, {}, {})
  278. def test_named_expression_scope_02(self):
  279. total = 0
  280. partial_sums = [total := total + v for v in range(5)]
  281. self.assertEqual(partial_sums, [0, 1, 3, 6, 10])
  282. self.assertEqual(total, 10)
  283. def test_named_expression_scope_03(self):
  284. containsOne = any((lastNum := num) == 1 for num in [1, 2, 3])
  285. self.assertTrue(containsOne)
  286. self.assertEqual(lastNum, 1)
  287. def test_named_expression_scope_04(self):
  288. def spam(a):
  289. return a
  290. res = [[y := spam(x), x/y] for x in range(1, 5)]
  291. self.assertEqual(y, 4)
  292. def test_named_expression_scope_05(self):
  293. def spam(a):
  294. return a
  295. input_data = [1, 2, 3]
  296. res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]
  297. self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
  298. self.assertEqual(y, 3)
  299. def test_named_expression_scope_06(self):
  300. res = [[spam := i for i in range(3)] for j in range(2)]
  301. self.assertEqual(res, [[0, 1, 2], [0, 1, 2]])
  302. self.assertEqual(spam, 2)
  303. def test_named_expression_scope_07(self):
  304. len(lines := [1, 2])
  305. self.assertEqual(lines, [1, 2])
  306. def test_named_expression_scope_08(self):
  307. def spam(a):
  308. return a
  309. def eggs(b):
  310. return b * 2
  311. res = [spam(a := eggs(b := h)) for h in range(2)]
  312. self.assertEqual(res, [0, 2])
  313. self.assertEqual(a, 2)
  314. self.assertEqual(b, 1)
  315. def test_named_expression_scope_09(self):
  316. def spam(a):
  317. return a
  318. def eggs(b):
  319. return b * 2
  320. res = [spam(a := eggs(a := h)) for h in range(2)]
  321. self.assertEqual(res, [0, 2])
  322. self.assertEqual(a, 2)
  323. def test_named_expression_scope_10(self):
  324. res = [b := [a := 1 for i in range(2)] for j in range(2)]
  325. self.assertEqual(res, [[1, 1], [1, 1]])
  326. self.assertEqual(a, 1)
  327. self.assertEqual(b, [1, 1])
  328. def test_named_expression_scope_11(self):
  329. res = [j := i for i in range(5)]
  330. self.assertEqual(res, [0, 1, 2, 3, 4])
  331. self.assertEqual(j, 4)
  332. def test_named_expression_scope_17(self):
  333. b = 0
  334. res = [b := i + b for i in range(5)]
  335. self.assertEqual(res, [0, 1, 3, 6, 10])
  336. self.assertEqual(b, 10)
  337. def test_named_expression_scope_18(self):
  338. def spam(a):
  339. return a
  340. res = spam(b := 2)
  341. self.assertEqual(res, 2)
  342. self.assertEqual(b, 2)
  343. def test_named_expression_scope_19(self):
  344. def spam(a):
  345. return a
  346. res = spam((b := 2))
  347. self.assertEqual(res, 2)
  348. self.assertEqual(b, 2)
  349. def test_named_expression_scope_20(self):
  350. def spam(a):
  351. return a
  352. res = spam(a=(b := 2))
  353. self.assertEqual(res, 2)
  354. self.assertEqual(b, 2)
  355. def test_named_expression_scope_21(self):
  356. def spam(a, b):
  357. return a + b
  358. res = spam(c := 2, b=1)
  359. self.assertEqual(res, 3)
  360. self.assertEqual(c, 2)
  361. def test_named_expression_scope_22(self):
  362. def spam(a, b):
  363. return a + b
  364. res = spam((c := 2), b=1)
  365. self.assertEqual(res, 3)
  366. self.assertEqual(c, 2)
  367. def test_named_expression_scope_23(self):
  368. def spam(a, b):
  369. return a + b
  370. res = spam(b=(c := 2), a=1)
  371. self.assertEqual(res, 3)
  372. self.assertEqual(c, 2)
  373. def test_named_expression_scope_24(self):
  374. a = 10
  375. def spam():
  376. nonlocal a
  377. (a := 20)
  378. spam()
  379. self.assertEqual(a, 20)
  380. def test_named_expression_scope_25(self):
  381. ns = {}
  382. code = """a = 10
  383. def spam():
  384. global a
  385. (a := 20)
  386. spam()"""
  387. exec(code, ns, {})
  388. self.assertEqual(ns["a"], 20)
  389. def test_named_expression_variable_reuse_in_comprehensions(self):
  390. # The compiler is expected to raise syntax error for comprehension
  391. # iteration variables, but should be fine with rebinding of other
  392. # names (e.g. globals, nonlocals, other assignment expressions)
  393. # The cases are all defined to produce the same expected result
  394. # Each comprehension is checked at both function scope and module scope
  395. rebinding = "[x := i for i in range(3) if (x := i) or not x]"
  396. filter_ref = "[x := i for i in range(3) if x or not x]"
  397. body_ref = "[x for i in range(3) if (x := i) or not x]"
  398. nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]"
  399. cases = [
  400. ("Rebind global", f"x = 1; result = {rebinding}"),
  401. ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"),
  402. ("Filter global", f"x = 1; result = {filter_ref}"),
  403. ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"),
  404. ("Body global", f"x = 1; result = {body_ref}"),
  405. ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"),
  406. ("Nested global", f"x = 1; result = {nested_ref}"),
  407. ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"),
  408. ]
  409. for case, code in cases:
  410. with self.subTest(case=case):
  411. ns = {}
  412. exec(code, ns)
  413. self.assertEqual(ns["x"], 2)
  414. self.assertEqual(ns["result"], [0, 1, 2])
  415. def test_named_expression_global_scope(self):
  416. sentinel = object()
  417. global GLOBAL_VAR
  418. def f():
  419. global GLOBAL_VAR
  420. [GLOBAL_VAR := sentinel for _ in range(1)]
  421. self.assertEqual(GLOBAL_VAR, sentinel)
  422. try:
  423. f()
  424. self.assertEqual(GLOBAL_VAR, sentinel)
  425. finally:
  426. GLOBAL_VAR = None
  427. def test_named_expression_global_scope_no_global_keyword(self):
  428. sentinel = object()
  429. def f():
  430. GLOBAL_VAR = None
  431. [GLOBAL_VAR := sentinel for _ in range(1)]
  432. self.assertEqual(GLOBAL_VAR, sentinel)
  433. f()
  434. self.assertEqual(GLOBAL_VAR, None)
  435. def test_named_expression_nonlocal_scope(self):
  436. sentinel = object()
  437. def f():
  438. nonlocal_var = None
  439. def g():
  440. nonlocal nonlocal_var
  441. [nonlocal_var := sentinel for _ in range(1)]
  442. g()
  443. self.assertEqual(nonlocal_var, sentinel)
  444. f()
  445. def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self):
  446. sentinel = object()
  447. def f():
  448. nonlocal_var = None
  449. def g():
  450. [nonlocal_var := sentinel for _ in range(1)]
  451. g()
  452. self.assertEqual(nonlocal_var, None)
  453. f()
  454. def test_named_expression_scope_in_genexp(self):
  455. a = 1
  456. b = [1, 2, 3, 4]
  457. genexp = (c := i + a for i in b)
  458. self.assertNotIn("c", locals())
  459. for idx, elem in enumerate(genexp):
  460. self.assertEqual(elem, b[idx] + a)
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()