# _runner.py
from __future__ import annotations

import asyncio
import concurrent.futures
import inspect
import threading
import time
import weakref
from collections.abc import (
    AsyncIterator,
    Awaitable,
    Callable,
    Iterable,
    Iterator,
    Sequence,
)
from functools import partial
from typing import (
    Any,
    Generic,
    TypeVar,
    cast,
)

from langchain_core.callbacks import Callbacks

from langgraph._internal._constants import (
    CONF,
    CONFIG_KEY_CALL,
    CONFIG_KEY_SCRATCHPAD,
    ERROR,
    INTERRUPT,
    NO_WRITES,
    RESUME,
    RETURN,
)
from langgraph._internal._future import chain_future, run_coroutine_threadsafe
from langgraph._internal._scratchpad import PregelScratchpad
from langgraph._internal._typing import MISSING
from langgraph.constants import TAG_HIDDEN
from langgraph.errors import GraphBubbleUp, GraphInterrupt
from langgraph.pregel._algo import Call
from langgraph.pregel._executor import Submit
from langgraph.pregel._retry import arun_with_retry, run_with_retry
from langgraph.types import (
    CachePolicy,
    PregelExecutableTask,
    RetryPolicy,
)
  47. F = TypeVar("F", concurrent.futures.Future, asyncio.Future)
  48. E = TypeVar("E", threading.Event, asyncio.Event)
  49. # List of filenames to exclude from exception traceback
  50. # Note: Frames will be removed if they are the last frame in traceback, recursively
  51. EXCLUDED_FRAME_FNAMES = (
  52. "langgraph/pregel/retry.py",
  53. "langgraph/pregel/runner.py",
  54. "langgraph/pregel/executor.py",
  55. "langgraph/utils/runnable.py",
  56. "langchain_core/runnables/config.py",
  57. "concurrent/futures/thread.py",
  58. "concurrent/futures/_base.py",
  59. )
  60. SKIP_RERAISE_SET: weakref.WeakSet[concurrent.futures.Future | asyncio.Future] = (
  61. weakref.WeakSet()
  62. )
  63. class FuturesDict(Generic[F, E], dict[F, PregelExecutableTask | None]):
  64. event: E
  65. callback: weakref.ref[Callable[[PregelExecutableTask, BaseException | None], None]]
  66. counter: int
  67. done: set[F]
  68. lock: threading.Lock
  69. def __init__(
  70. self,
  71. event: E,
  72. callback: weakref.ref[
  73. Callable[[PregelExecutableTask, BaseException | None], None]
  74. ],
  75. future_type: type[F],
  76. # used for generic typing, newer py supports FutureDict[...](...)
  77. ) -> None:
  78. super().__init__()
  79. self.lock = threading.Lock()
  80. self.event = event
  81. self.callback = callback
  82. self.counter = 0
  83. self.done: set[F] = set()
  84. def __setitem__(
  85. self,
  86. key: F,
  87. value: PregelExecutableTask | None,
  88. ) -> None:
  89. super().__setitem__(key, value) # type: ignore[index]
  90. if value is not None:
  91. with self.lock:
  92. self.event.clear()
  93. self.counter += 1
  94. key.add_done_callback(partial(self.on_done, value))
  95. def on_done(
  96. self,
  97. task: PregelExecutableTask,
  98. fut: F,
  99. ) -> None:
  100. try:
  101. if cb := self.callback():
  102. cb(task, _exception(fut))
  103. finally:
  104. with self.lock:
  105. self.done.add(fut)
  106. self.counter -= 1
  107. if self.counter == 0 or _should_stop_others(self.done):
  108. self.event.set()
  109. class PregelRunner:
  110. """Responsible for executing a set of Pregel tasks concurrently, committing
  111. their writes, yielding control to caller when there is output to emit, and
  112. interrupting other tasks if appropriate."""
  113. def __init__(
  114. self,
  115. *,
  116. submit: weakref.ref[Submit],
  117. put_writes: weakref.ref[Callable[[str, Sequence[tuple[str, Any]]], None]],
  118. use_astream: bool = False,
  119. node_finished: Callable[[str], None] | None = None,
  120. ) -> None:
  121. self.submit = submit
  122. self.put_writes = put_writes
  123. self.use_astream = use_astream
  124. self.node_finished = node_finished
  125. def tick(
  126. self,
  127. tasks: Iterable[PregelExecutableTask],
  128. *,
  129. reraise: bool = True,
  130. timeout: float | None = None,
  131. retry_policy: Sequence[RetryPolicy] | None = None,
  132. get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None,
  133. schedule_task: Callable[
  134. [PregelExecutableTask, int, Call | None],
  135. PregelExecutableTask | None,
  136. ],
  137. ) -> Iterator[None]:
  138. tasks = tuple(tasks)
  139. futures = FuturesDict(
  140. callback=weakref.WeakMethod(self.commit),
  141. event=threading.Event(),
  142. future_type=concurrent.futures.Future,
  143. )
  144. # give control back to the caller
  145. yield
  146. # fast path if single task with no timeout and no waiter
  147. if len(tasks) == 0:
  148. return
  149. elif len(tasks) == 1 and timeout is None and get_waiter is None:
  150. t = tasks[0]
  151. try:
  152. run_with_retry(
  153. t,
  154. retry_policy,
  155. configurable={
  156. CONFIG_KEY_CALL: partial(
  157. _call,
  158. weakref.ref(t),
  159. retry_policy=retry_policy,
  160. futures=weakref.ref(futures),
  161. schedule_task=schedule_task,
  162. submit=self.submit,
  163. ),
  164. },
  165. )
  166. self.commit(t, None)
  167. except Exception as exc:
  168. self.commit(t, exc)
  169. if reraise and futures:
  170. # will be re-raised after futures are done
  171. fut: concurrent.futures.Future = concurrent.futures.Future()
  172. fut.set_exception(exc)
  173. futures.done.add(fut)
  174. elif reraise:
  175. if tb := exc.__traceback__:
  176. while tb.tb_next is not None and any(
  177. tb.tb_frame.f_code.co_filename.endswith(name)
  178. for name in EXCLUDED_FRAME_FNAMES
  179. ):
  180. tb = tb.tb_next
  181. exc.__traceback__ = tb
  182. raise
  183. if not futures: # maybe `t` scheduled another task
  184. return
  185. else:
  186. tasks = () # don't reschedule this task
  187. # add waiter task if requested
  188. if get_waiter is not None:
  189. futures[get_waiter()] = None
  190. # schedule tasks
  191. for t in tasks:
  192. fut = self.submit()( # type: ignore[misc]
  193. run_with_retry,
  194. t,
  195. retry_policy,
  196. configurable={
  197. CONFIG_KEY_CALL: partial(
  198. _call,
  199. weakref.ref(t),
  200. retry_policy=retry_policy,
  201. futures=weakref.ref(futures),
  202. schedule_task=schedule_task,
  203. submit=self.submit,
  204. ),
  205. },
  206. __reraise_on_exit__=reraise,
  207. )
  208. futures[fut] = t
  209. # execute tasks, and wait for one to fail or all to finish.
  210. # each task is independent from all other concurrent tasks
  211. # yield updates/debug output as each task finishes
  212. end_time = timeout + time.monotonic() if timeout else None
  213. while len(futures) > (1 if get_waiter is not None else 0):
  214. done, inflight = concurrent.futures.wait(
  215. futures,
  216. return_when=concurrent.futures.FIRST_COMPLETED,
  217. timeout=(max(0, end_time - time.monotonic()) if end_time else None),
  218. )
  219. if not done:
  220. break # timed out
  221. for fut in done:
  222. task = futures.pop(fut)
  223. if task is None:
  224. # waiter task finished, schedule another
  225. if inflight and get_waiter is not None:
  226. futures[get_waiter()] = None
  227. else:
  228. # remove references to loop vars
  229. del fut, task
  230. # maybe stop other tasks
  231. if _should_stop_others(done):
  232. break
  233. # give control back to the caller
  234. yield
  235. # wait for done callbacks
  236. futures.event.wait(
  237. timeout=(max(0, end_time - time.monotonic()) if end_time else None)
  238. )
  239. # give control back to the caller
  240. yield
  241. # panic on failure or timeout
  242. try:
  243. _panic_or_proceed(
  244. futures.done.union(f for f, t in futures.items() if t is not None),
  245. panic=reraise,
  246. )
  247. except Exception as exc:
  248. if tb := exc.__traceback__:
  249. while tb.tb_next is not None and any(
  250. tb.tb_frame.f_code.co_filename.endswith(name)
  251. for name in EXCLUDED_FRAME_FNAMES
  252. ):
  253. tb = tb.tb_next
  254. exc.__traceback__ = tb
  255. raise
  256. async def atick(
  257. self,
  258. tasks: Iterable[PregelExecutableTask],
  259. *,
  260. reraise: bool = True,
  261. timeout: float | None = None,
  262. retry_policy: Sequence[RetryPolicy] | None = None,
  263. get_waiter: Callable[[], asyncio.Future[None]] | None = None,
  264. schedule_task: Callable[
  265. [PregelExecutableTask, int, Call | None],
  266. Awaitable[PregelExecutableTask | None],
  267. ],
  268. ) -> AsyncIterator[None]:
  269. try:
  270. loop = asyncio.get_event_loop()
  271. except RuntimeError:
  272. loop = asyncio.new_event_loop()
  273. asyncio.set_event_loop(loop)
  274. tasks = tuple(tasks)
  275. futures = FuturesDict(
  276. callback=weakref.WeakMethod(self.commit),
  277. event=asyncio.Event(),
  278. future_type=asyncio.Future,
  279. )
  280. # give control back to the caller
  281. yield
  282. # fast path if single task with no waiter and no timeout
  283. if len(tasks) == 0:
  284. return
  285. elif len(tasks) == 1 and get_waiter is None and timeout is None:
  286. t = tasks[0]
  287. try:
  288. await arun_with_retry(
  289. t,
  290. retry_policy,
  291. stream=self.use_astream,
  292. configurable={
  293. CONFIG_KEY_CALL: partial(
  294. _acall,
  295. weakref.ref(t),
  296. stream=self.use_astream,
  297. retry_policy=retry_policy,
  298. futures=weakref.ref(futures),
  299. schedule_task=schedule_task,
  300. submit=self.submit,
  301. loop=loop,
  302. ),
  303. },
  304. )
  305. self.commit(t, None)
  306. except Exception as exc:
  307. self.commit(t, exc)
  308. if reraise and futures:
  309. # will be re-raised after futures are done
  310. fut: asyncio.Future = loop.create_future()
  311. fut.set_exception(exc)
  312. futures.done.add(fut)
  313. elif reraise:
  314. if tb := exc.__traceback__:
  315. while tb.tb_next is not None and any(
  316. tb.tb_frame.f_code.co_filename.endswith(name)
  317. for name in EXCLUDED_FRAME_FNAMES
  318. ):
  319. tb = tb.tb_next
  320. exc.__traceback__ = tb
  321. raise
  322. if not futures: # maybe `t` scheduled another task
  323. return
  324. else:
  325. tasks = () # don't reschedule this task
  326. # add waiter task if requested
  327. if get_waiter is not None:
  328. futures[get_waiter()] = None
  329. # schedule tasks
  330. for t in tasks:
  331. fut = cast(
  332. asyncio.Future,
  333. self.submit()( # type: ignore[misc]
  334. arun_with_retry,
  335. t,
  336. retry_policy,
  337. stream=self.use_astream,
  338. configurable={
  339. CONFIG_KEY_CALL: partial(
  340. _acall,
  341. weakref.ref(t),
  342. retry_policy=retry_policy,
  343. stream=self.use_astream,
  344. futures=weakref.ref(futures),
  345. schedule_task=schedule_task,
  346. submit=self.submit,
  347. loop=loop,
  348. ),
  349. },
  350. __name__=t.name,
  351. __cancel_on_exit__=True,
  352. __reraise_on_exit__=reraise,
  353. ),
  354. )
  355. futures[fut] = t
  356. # execute tasks, and wait for one to fail or all to finish.
  357. # each task is independent from all other concurrent tasks
  358. # yield updates/debug output as each task finishes
  359. end_time = timeout + loop.time() if timeout else None
  360. while len(futures) > (1 if get_waiter is not None else 0):
  361. done, inflight = await asyncio.wait(
  362. futures,
  363. return_when=asyncio.FIRST_COMPLETED,
  364. timeout=(max(0, end_time - loop.time()) if end_time else None),
  365. )
  366. if not done:
  367. break # timed out
  368. for fut in done:
  369. task = futures.pop(fut)
  370. if task is None:
  371. # waiter task finished, schedule another
  372. if inflight and get_waiter is not None:
  373. futures[get_waiter()] = None
  374. else:
  375. # remove references to loop vars
  376. del fut, task
  377. # maybe stop other tasks
  378. if _should_stop_others(done):
  379. break
  380. # give control back to the caller
  381. yield
  382. # wait for done callbacks
  383. await asyncio.wait_for(
  384. futures.event.wait(),
  385. timeout=(max(0, end_time - loop.time()) if end_time else None),
  386. )
  387. # give control back to the caller
  388. yield
  389. # cancel waiter task
  390. for fut in futures:
  391. fut.cancel()
  392. # panic on failure or timeout
  393. try:
  394. _panic_or_proceed(
  395. futures.done.union(f for f, t in futures.items() if t is not None),
  396. timeout_exc_cls=asyncio.TimeoutError,
  397. panic=reraise,
  398. )
  399. except Exception as exc:
  400. if tb := exc.__traceback__:
  401. while tb.tb_next is not None and any(
  402. tb.tb_frame.f_code.co_filename.endswith(name)
  403. for name in EXCLUDED_FRAME_FNAMES
  404. ):
  405. tb = tb.tb_next
  406. exc.__traceback__ = tb
  407. raise
  408. def commit(
  409. self,
  410. task: PregelExecutableTask,
  411. exception: BaseException | None,
  412. ) -> None:
  413. if isinstance(exception, asyncio.CancelledError):
  414. # for cancelled tasks, also save error in task,
  415. # so loop can finish super-step
  416. task.writes.append((ERROR, exception))
  417. self.put_writes()(task.id, task.writes) # type: ignore[misc]
  418. elif exception:
  419. if isinstance(exception, GraphInterrupt):
  420. # save interrupt to checkpointer
  421. if exception.args[0]:
  422. writes = [(INTERRUPT, exception.args[0])]
  423. if resumes := [w for w in task.writes if w[0] == RESUME]:
  424. writes.extend(resumes)
  425. self.put_writes()(task.id, writes) # type: ignore[misc]
  426. elif isinstance(exception, GraphBubbleUp):
  427. # exception will be raised in _panic_or_proceed
  428. pass
  429. else:
  430. # save error to checkpointer
  431. task.writes.append((ERROR, exception))
  432. self.put_writes()(task.id, task.writes) # type: ignore[misc]
  433. else:
  434. if self.node_finished and (
  435. task.config is None or TAG_HIDDEN not in task.config.get("tags", [])
  436. ):
  437. self.node_finished(task.name)
  438. if not task.writes:
  439. # add no writes marker
  440. task.writes.append((NO_WRITES, None))
  441. # save task writes to checkpointer
  442. self.put_writes()(task.id, task.writes) # type: ignore[misc]
  443. def _should_stop_others(
  444. done: set[F],
  445. ) -> bool:
  446. """Check if any task failed, if so, cancel all other tasks.
  447. GraphInterrupts are not considered failures."""
  448. for fut in done:
  449. if fut.cancelled():
  450. continue
  451. elif exc := fut.exception():
  452. if not isinstance(exc, GraphBubbleUp) and fut not in SKIP_RERAISE_SET:
  453. return True
  454. return False
  455. def _exception(
  456. fut: concurrent.futures.Future[Any] | asyncio.Future[Any],
  457. ) -> BaseException | None:
  458. """Return the exception from a future, without raising CancelledError."""
  459. if fut.cancelled():
  460. if isinstance(fut, asyncio.Future):
  461. return asyncio.CancelledError()
  462. else:
  463. return concurrent.futures.CancelledError()
  464. else:
  465. return fut.exception()
  466. def _panic_or_proceed(
  467. futs: set[concurrent.futures.Future] | set[asyncio.Future],
  468. *,
  469. timeout_exc_cls: type[Exception] = TimeoutError,
  470. panic: bool = True,
  471. ) -> None:
  472. """Cancel remaining tasks if any failed, re-raise exception if panic is True."""
  473. done: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
  474. inflight: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
  475. for fut in futs:
  476. if fut.cancelled():
  477. continue
  478. elif fut.done():
  479. done.add(fut)
  480. else:
  481. inflight.add(fut)
  482. interrupts: list[GraphInterrupt] = []
  483. while done:
  484. # if any task failed
  485. fut = done.pop()
  486. if exc := _exception(fut):
  487. # cancel all pending tasks
  488. while inflight:
  489. inflight.pop().cancel()
  490. # raise the exception
  491. if panic:
  492. if isinstance(exc, GraphInterrupt):
  493. # collect interrupts
  494. interrupts.append(exc)
  495. elif fut not in SKIP_RERAISE_SET:
  496. raise exc
  497. # raise combined interrupts
  498. if interrupts:
  499. raise GraphInterrupt(tuple(i for exc in interrupts for i in exc.args[0]))
  500. if inflight:
  501. # if we got here means we timed out
  502. while inflight:
  503. # cancel all pending tasks
  504. inflight.pop().cancel()
  505. # raise timeout error
  506. raise timeout_exc_cls("Timed out")
  507. def _call(
  508. task: weakref.ref[PregelExecutableTask],
  509. func: Callable[[Any], Awaitable[Any] | Any],
  510. input: Any,
  511. *,
  512. retry_policy: Sequence[RetryPolicy] | None = None,
  513. cache_policy: CachePolicy | None = None,
  514. callbacks: Callbacks = None,
  515. futures: weakref.ref[FuturesDict],
  516. schedule_task: Callable[
  517. [PregelExecutableTask, int, Call | None], PregelExecutableTask | None
  518. ],
  519. submit: weakref.ref[Submit],
  520. ) -> concurrent.futures.Future[Any]:
  521. if inspect.iscoroutinefunction(func):
  522. raise RuntimeError("In an sync context async tasks cannot be called")
  523. fut: concurrent.futures.Future | None = None
  524. # schedule PUSH tasks, collect futures
  525. scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
  526. # schedule the next task, if the callback returns one
  527. if next_task := schedule_task(
  528. task(), # type: ignore[arg-type]
  529. scratchpad.call_counter(),
  530. Call(
  531. func,
  532. input,
  533. retry_policy=retry_policy,
  534. cache_policy=cache_policy,
  535. callbacks=callbacks,
  536. ),
  537. ):
  538. if fut := next(
  539. (
  540. f
  541. for f, t in futures().items() # type: ignore[union-attr]
  542. if t is not None and t == next_task.id
  543. ),
  544. None,
  545. ):
  546. # if the parent task was retried,
  547. # the next task might already be running
  548. pass
  549. elif next_task.writes:
  550. # if it already ran, return the result
  551. fut = concurrent.futures.Future()
  552. ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
  553. if ret is not MISSING:
  554. fut.set_result(ret)
  555. elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
  556. fut.set_exception(
  557. exc if isinstance(exc, BaseException) else Exception(exc)
  558. )
  559. else:
  560. fut.set_result(None)
  561. else:
  562. # schedule the next task
  563. fut = submit()( # type: ignore[misc]
  564. run_with_retry,
  565. next_task,
  566. retry_policy,
  567. configurable={
  568. CONFIG_KEY_CALL: partial(
  569. _call,
  570. weakref.ref(next_task),
  571. futures=futures,
  572. retry_policy=retry_policy,
  573. callbacks=callbacks,
  574. schedule_task=schedule_task,
  575. submit=submit,
  576. ),
  577. },
  578. __reraise_on_exit__=False,
  579. # starting a new task in the next tick ensures
  580. # updates from this tick are committed/streamed first
  581. __next_tick__=True,
  582. )
  583. # exceptions for call() tasks are raised into the parent task
  584. # so we should not re-raise at the end of the tick
  585. SKIP_RERAISE_SET.add(fut)
  586. futures()[fut] = next_task # type: ignore[index]
  587. fut = cast(asyncio.Future | concurrent.futures.Future, fut)
  588. # return a chained future to ensure commit() callback is called
  589. # before the returned future is resolved, to ensure stream order etc
  590. return chain_future(fut, concurrent.futures.Future())
  591. def _acall(
  592. task: weakref.ref[PregelExecutableTask],
  593. func: Callable[[Any], Awaitable[Any] | Any],
  594. input: Any,
  595. *,
  596. retry_policy: Sequence[RetryPolicy] | None = None,
  597. cache_policy: CachePolicy | None = None,
  598. callbacks: Callbacks = None,
  599. # injected dependencies
  600. futures: weakref.ref[FuturesDict],
  601. schedule_task: Callable[
  602. [PregelExecutableTask, int, Call | None],
  603. Awaitable[PregelExecutableTask | None],
  604. ],
  605. submit: weakref.ref[Submit],
  606. loop: asyncio.AbstractEventLoop,
  607. stream: bool = False,
  608. ) -> asyncio.Future[Any] | concurrent.futures.Future[Any]:
  609. # return a chained future to ensure commit() callback is called
  610. # before the returned future is resolved, to ensure stream order etc
  611. try:
  612. in_async = asyncio.current_task() is not None
  613. except RuntimeError:
  614. in_async = False
  615. # if in async context return an async future, otherwise return a sync future
  616. if in_async:
  617. fut: asyncio.Future[Any] | concurrent.futures.Future[Any] = asyncio.Future(
  618. loop=loop
  619. )
  620. else:
  621. fut = concurrent.futures.Future()
  622. # schedule the next task
  623. run_coroutine_threadsafe(
  624. _acall_impl(
  625. fut,
  626. task,
  627. func,
  628. input,
  629. retry_policy=retry_policy,
  630. cache_policy=cache_policy,
  631. callbacks=callbacks,
  632. futures=futures,
  633. schedule_task=schedule_task,
  634. submit=submit,
  635. loop=loop,
  636. stream=stream,
  637. ),
  638. loop,
  639. lazy=False,
  640. )
  641. return fut
  642. async def _acall_impl(
  643. destination: asyncio.Future[Any] | concurrent.futures.Future[Any],
  644. task: weakref.ref[PregelExecutableTask],
  645. func: Callable[[Any], Awaitable[Any] | Any],
  646. input: Any,
  647. *,
  648. retry_policy: Sequence[RetryPolicy] | None = None,
  649. cache_policy: CachePolicy | None = None,
  650. callbacks: Callbacks = None,
  651. # injected dependencies
  652. futures: weakref.ref[FuturesDict[asyncio.Future, asyncio.Event]],
  653. schedule_task: Callable[
  654. [PregelExecutableTask, int, Call | None],
  655. Awaitable[PregelExecutableTask | None],
  656. ],
  657. submit: weakref.ref[Submit],
  658. loop: asyncio.AbstractEventLoop,
  659. stream: bool = False,
  660. ) -> None:
  661. try:
  662. fut: asyncio.Future | None = None
  663. # schedule PUSH tasks, collect futures
  664. scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
  665. # schedule the next task, if the callback returns one
  666. if next_task := await schedule_task(
  667. task(), # type: ignore[arg-type]
  668. scratchpad.call_counter(),
  669. Call(
  670. func,
  671. input,
  672. retry_policy=retry_policy,
  673. cache_policy=cache_policy,
  674. callbacks=callbacks,
  675. ),
  676. ):
  677. if fut := next(
  678. (
  679. f
  680. for f, t in futures().items() # type: ignore[union-attr]
  681. if t is not None and t == next_task.id
  682. ),
  683. None,
  684. ):
  685. # if the parent task was retried,
  686. # the next task might already be running
  687. pass
  688. elif next_task.writes:
  689. # if it already ran, return the result
  690. fut = asyncio.Future(loop=loop)
  691. ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
  692. if ret is not MISSING:
  693. fut.set_result(ret)
  694. elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
  695. fut.set_exception(
  696. exc if isinstance(exc, BaseException) else Exception(exc)
  697. )
  698. else:
  699. fut.set_result(None)
  700. futures()[fut] = next_task # type: ignore[index]
  701. else:
  702. # schedule the next task
  703. fut = cast(
  704. asyncio.Future,
  705. submit()( # type: ignore[misc]
  706. arun_with_retry,
  707. next_task,
  708. retry_policy,
  709. stream=stream,
  710. configurable={
  711. CONFIG_KEY_CALL: partial(
  712. _acall,
  713. weakref.ref(next_task),
  714. stream=stream,
  715. futures=futures,
  716. schedule_task=schedule_task,
  717. submit=submit,
  718. loop=loop,
  719. ),
  720. },
  721. __name__=next_task.name,
  722. __cancel_on_exit__=True,
  723. __reraise_on_exit__=False,
  724. # starting a new task in the next tick ensures
  725. # updates from this tick are committed/streamed first
  726. __next_tick__=True,
  727. ),
  728. )
  729. # exceptions for call() tasks are raised into the parent task
  730. # so we should not re-raise at the end of the tick
  731. SKIP_RERAISE_SET.add(fut)
  732. futures()[fut] = next_task # type: ignore[index]
  733. if fut is not None:
  734. chain_future(fut, destination)
  735. else:
  736. destination.set_exception(RuntimeError("Task not scheduled"))
  737. except Exception as exc:
  738. destination.set_exception(exc)