_loop.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328
  1. from __future__ import annotations
  2. import asyncio
  3. import binascii
  4. import concurrent.futures
  5. from collections import defaultdict, deque
  6. from collections.abc import Callable, Iterator, Mapping, Sequence
  7. from contextlib import (
  8. AbstractAsyncContextManager,
  9. AbstractContextManager,
  10. AsyncExitStack,
  11. ExitStack,
  12. )
  13. from datetime import datetime, timezone
  14. from inspect import signature
  15. from types import TracebackType
  16. from typing import (
  17. Any,
  18. Literal,
  19. TypeVar,
  20. cast,
  21. )
  22. from langchain_core.callbacks import AsyncParentRunManager, ParentRunManager
  23. from langchain_core.runnables import RunnableConfig
  24. from langgraph.cache.base import BaseCache
  25. from langgraph.checkpoint.base import (
  26. WRITES_IDX_MAP,
  27. BaseCheckpointSaver,
  28. ChannelVersions,
  29. Checkpoint,
  30. CheckpointMetadata,
  31. CheckpointTuple,
  32. PendingWrite,
  33. )
  34. from langgraph.store.base import BaseStore
  35. from typing_extensions import ParamSpec, Self
  36. from langgraph._internal._config import patch_configurable
  37. from langgraph._internal._constants import (
  38. CONF,
  39. CONFIG_KEY_CHECKPOINT_ID,
  40. CONFIG_KEY_CHECKPOINT_MAP,
  41. CONFIG_KEY_CHECKPOINT_NS,
  42. CONFIG_KEY_RESUME_MAP,
  43. CONFIG_KEY_RESUMING,
  44. CONFIG_KEY_SCRATCHPAD,
  45. CONFIG_KEY_STREAM,
  46. CONFIG_KEY_TASK_ID,
  47. CONFIG_KEY_THREAD_ID,
  48. ERROR,
  49. INPUT,
  50. INTERRUPT,
  51. NS_END,
  52. NS_SEP,
  53. NULL_TASK_ID,
  54. PUSH,
  55. RESUME,
  56. TASKS,
  57. )
  58. from langgraph._internal._scratchpad import PregelScratchpad
  59. from langgraph._internal._typing import EMPTY_SEQ, MISSING
  60. from langgraph.channels.base import BaseChannel
  61. from langgraph.channels.untracked_value import UntrackedValue
  62. from langgraph.constants import TAG_HIDDEN
  63. from langgraph.errors import (
  64. EmptyInputError,
  65. GraphInterrupt,
  66. )
  67. from langgraph.managed.base import (
  68. ManagedValueMapping,
  69. ManagedValueSpec,
  70. )
  71. from langgraph.pregel._algo import (
  72. Call,
  73. GetNextVersion,
  74. PregelTaskWrites,
  75. apply_writes,
  76. checkpoint_null_version,
  77. increment,
  78. prepare_next_tasks,
  79. prepare_single_task,
  80. sanitize_untracked_values_in_send,
  81. should_interrupt,
  82. task_path_str,
  83. )
  84. from langgraph.pregel._checkpoint import (
  85. channels_from_checkpoint,
  86. copy_checkpoint,
  87. create_checkpoint,
  88. empty_checkpoint,
  89. )
  90. from langgraph.pregel._executor import (
  91. AsyncBackgroundExecutor,
  92. BackgroundExecutor,
  93. Submit,
  94. )
  95. from langgraph.pregel._io import (
  96. map_command,
  97. map_input,
  98. map_output_updates,
  99. map_output_values,
  100. read_channels,
  101. )
  102. from langgraph.pregel._read import PregelNode
  103. from langgraph.pregel._utils import get_new_channel_versions, is_xxh3_128_hexdigest
  104. from langgraph.pregel.debug import (
  105. map_debug_checkpoint,
  106. map_debug_task_results,
  107. map_debug_tasks,
  108. )
  109. from langgraph.pregel.protocol import StreamChunk, StreamProtocol
  110. from langgraph.types import (
  111. All,
  112. CachePolicy,
  113. Command,
  114. Durability,
  115. PregelExecutableTask,
  116. RetryPolicy,
  117. Send,
  118. StreamMode,
  119. )
  120. V = TypeVar("V")
  121. P = ParamSpec("P")
  122. WritesT = Sequence[tuple[str, Any]]
  def DuplexStream(*streams: StreamProtocol) -> StreamProtocol:
      """Combine several streams into one that fans chunks out by mode.

      A chunk is forwarded to every member stream whose ``modes`` contain the
      chunk's mode (``value[1]``). The returned stream advertises the union of
      all member modes.
      """

      def __call__(value: StreamChunk) -> None:
          for target in streams:
              if value[1] not in target.modes:
                  continue
              target(value)

      all_modes = set()
      for target in streams:
          all_modes.update(target.modes)
      return StreamProtocol(__call__, all_modes)
  class PregelLoop:
      """Shared state and superstep logic for one Pregel graph run.

      Holds the run's config, channels, current checkpoint and pending writes,
      and implements the per-superstep ``tick`` / ``after_tick`` cycle.

      Several attributes declared below (``submit``, ``channels``, ``managed``,
      ``checkpoint``, ``checkpointer_put_writes``,
      ``checkpointer_get_next_version``, ``_checkpointer_put_after_previous``,
      ...) are not assigned in ``__init__``. Presumably concrete subclasses set
      them up before the loop runs — confirm.
      """

      config: RunnableConfig
      store: BaseStore | None
      stream: StreamProtocol | None
      # current superstep number; tick() reports "out_of_steps" once step > stop
      step: int
      stop: int
      input: Any | None
      cache: BaseCache[WritesT] | None
      checkpointer: BaseCheckpointSaver | None
      nodes: Mapping[str, PregelNode]
      specs: Mapping[str, BaseChannel | ManagedValueSpec]
      input_keys: str | Sequence[str]
      output_keys: str | Sequence[str]
      stream_keys: str | Sequence[str]
      # True unless an explicit checkpoint id was requested in the config;
      # when True, pending writes are matched back onto freshly prepared tasks.
      # after_tick() forces it to True after the first superstep.
      skip_done_tasks: bool
      # True when the configurable carries a parent task id (CONFIG_KEY_TASK_ID)
      is_nested: bool
      manager: None | AsyncParentRunManager | ParentRunManager
      interrupt_after: All | Sequence[str]
      interrupt_before: All | Sequence[str]
      # "exit" defers submitting checkpoints/writes until the loop exits
      durability: Durability
      retry_policy: Sequence[RetryPolicy]
      cache_policy: CachePolicy | None
      # passed through to task preparation and apply_writes; presumably maps a
      # channel to the nodes it triggers — confirm
      trigger_to_nodes: Mapping[str, Sequence[str]]
      checkpointer_get_next_version: GetNextVersion
      # when not None, invoked through submit() to persist task writes
      checkpointer_put_writes: Callable[[RunnableConfig, WritesT, str], Any] | None
      # when True, the task path string is passed as an extra 4th argument to
      # checkpointer_put_writes
      checkpointer_put_writes_accepts_task_path: bool
      # invoked through submit() with the previous save's future (or None)
      # first, so checkpoints reach the checkpointer in order
      _checkpointer_put_after_previous: (
          Callable[
              [
                  concurrent.futures.Future | None,
                  RunnableConfig,
                  Checkpoint,
                  str,
                  ChannelVersions,
              ],
              Any,
          ]
          | None
      )
      _migrate_checkpoint: Callable[[Checkpoint], None] | None
      # schedules a call; _put_checkpoint keeps its return value as the
      # in-flight checkpoint-save future
      submit: Submit
      channels: Mapping[str, BaseChannel]
      managed: ManagedValueMapping
      checkpoint: Checkpoint
      # id of the last checkpoint known to be saved; _put_checkpoint skips the
      # exit-time save when it equals checkpoint["id"]
      checkpoint_id_saved: str
      # configured checkpoint namespace split on NS_SEP; () when none is set
      checkpoint_ns: tuple[str, ...]
      checkpoint_config: RunnableConfig
      checkpoint_metadata: CheckpointMetadata
      # (task_id, channel, value) triples not yet folded into channels
      checkpoint_pending_writes: list[PendingWrite]
      # channel versions as of the last submitted checkpoint; used to compute
      # the new versions sent with the next checkpoint
      checkpoint_previous_versions: dict[str, str | float | int]
      # config of the checkpoint saved before the current one, if any
      prev_checkpoint_config: RunnableConfig | None
      status: Literal[
          "input",
          "pending",
          "done",
          "interrupt_before",
          "interrupt_after",
          "out_of_steps",
      ]
      # tasks prepared for the current superstep, keyed by task id
      tasks: dict[str, PregelExecutableTask]
      # final output read from the output channels when the loop exits
      output: None | dict[str, Any] | Any = None
      # channels updated by the most recently applied writes
      updated_channels: set[str] | None = None
  190. # public
  def __init__(
      self,
      input: Any | None,
      *,
      stream: StreamProtocol | None,
      config: RunnableConfig,
      store: BaseStore | None,
      cache: BaseCache | None,
      checkpointer: BaseCheckpointSaver | None,
      nodes: Mapping[str, PregelNode],
      specs: Mapping[str, BaseChannel | ManagedValueSpec],
      input_keys: str | Sequence[str],
      output_keys: str | Sequence[str],
      stream_keys: str | Sequence[str],
      trigger_to_nodes: Mapping[str, Sequence[str]],
      durability: Durability,
      interrupt_after: All | Sequence[str] = EMPTY_SEQ,
      interrupt_before: All | Sequence[str] = EMPTY_SEQ,
      manager: None | AsyncParentRunManager | ParentRunManager = None,
      migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
      retry_policy: Sequence[RetryPolicy] = (),
      cache_policy: CachePolicy | None = None,
  ) -> None:
      """Store run parameters and derive the run/checkpoint configs.

      Only configuration is prepared here. ``channels``, ``checkpoint`` and
      the checkpointer callables are not assigned in this method.

      Derived state:
        - ``stream`` is combined (DuplexStream) with a parent stream found
          under CONFIG_KEY_STREAM in the configurable.
        - when a PregelScratchpad is present and its ``subgraph_counter()``
          returns a non-zero count, that count is appended to the checkpoint
          namespace.
        - for non-nested runs, any configured namespace and checkpoint id are
          cleared.
        - ``checkpoint_config`` pins the checkpoint id from
          CONFIG_KEY_CHECKPOINT_MAP when the namespace has an entry there.
        - a non-string thread id is converted to ``str``.
      """
      self.stream = stream
      self.config = config
      self.store = store
      self.step = 0
      self.stop = 0
      self.input = input
      self.checkpointer = checkpointer
      self.cache = cache
      self.nodes = nodes
      self.specs = specs
      self.input_keys = input_keys
      self.output_keys = output_keys
      self.stream_keys = stream_keys
      self.interrupt_after = interrupt_after
      self.interrupt_before = interrupt_before
      self.manager = manager
      self.is_nested = CONFIG_KEY_TASK_ID in self.config.get(CONF, {})
      # an explicitly requested checkpoint id disables matching of pending
      # writes to tasks on the first tick
      self.skip_done_tasks = CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
      self._migrate_checkpoint = migrate_checkpoint
      self.trigger_to_nodes = trigger_to_nodes
      self.retry_policy = retry_policy
      self.cache_policy = cache_policy
      self.durability = durability
      # forward chunks to both our own stream and the one found in the config
      if self.stream is not None and CONFIG_KEY_STREAM in config[CONF]:
          self.stream = DuplexStream(self.stream, config[CONF][CONFIG_KEY_STREAM])
      scratchpad: PregelScratchpad | None = config[CONF].get(CONFIG_KEY_SCRATCHPAD)
      if isinstance(scratchpad, PregelScratchpad):
          # if count is > 0, append to checkpoint_ns
          # if count is 0, leave as is
          if cnt := scratchpad.subgraph_counter():
              self.config = patch_configurable(
                  self.config,
                  {
                      CONFIG_KEY_CHECKPOINT_NS: NS_SEP.join(
                          (
                              config[CONF][CONFIG_KEY_CHECKPOINT_NS],
                              str(cnt),
                          )
                      )
                  },
              )
      # a non-nested run always starts from the root namespace
      if not self.is_nested and config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
          self.config = patch_configurable(
              self.config,
              {CONFIG_KEY_CHECKPOINT_NS: "", CONFIG_KEY_CHECKPOINT_ID: None},
          )
      # pin the checkpoint id recorded for this namespace, if any
      if (
          CONFIG_KEY_CHECKPOINT_MAP in self.config[CONF]
          and self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
          in self.config[CONF][CONFIG_KEY_CHECKPOINT_MAP]
      ):
          self.checkpoint_config = patch_configurable(
              self.config,
              {
                  CONFIG_KEY_CHECKPOINT_ID: self.config[CONF][
                      CONFIG_KEY_CHECKPOINT_MAP
                  ][self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]]
              },
          )
      else:
          self.checkpoint_config = self.config
      # normalize a non-string thread id to str
      if thread_id := self.checkpoint_config[CONF].get(CONFIG_KEY_THREAD_ID):
          if not isinstance(thread_id, str):
              self.checkpoint_config = patch_configurable(
                  self.checkpoint_config,
                  {CONFIG_KEY_THREAD_ID: str(thread_id)},
              )
      self.checkpoint_ns = (
          tuple(cast(str, self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP))
          if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
          else ()
      )
      self.prev_checkpoint_config = None
  def put_writes(self, task_id: str, writes: WritesT) -> None:
      """Put writes for a task, to be read by the next tick.

      Replaces previously recorded writes for ``task_id`` in
      ``checkpoint_pending_writes``. For the null task, only its writes to
      special (WRITES_IDX_MAP) channels are replaced; its other writes
      accumulate. Unless durability is "exit", the writes are also submitted
      to the checkpointer, minus UntrackedValue channel writes and with Send
      packets sanitized. Finally the writes go to ``output_writes`` once
      ``self.tasks`` exists.

      Args:
          task_id: id of the task that produced the writes.
          writes: (channel, value) pairs; the call is a no-op when empty.
      """
      if not writes:
          return
      # deduplicate writes to special channels, last write wins
      if all(w[0] in WRITES_IDX_MAP for w in writes):
          writes = list({w[0]: w for w in writes}.values())
      if task_id == NULL_TASK_ID:
          # writes for the null task are accumulated
          self.checkpoint_pending_writes = [
              w
              for w in self.checkpoint_pending_writes
              if w[0] != task_id or w[1] not in WRITES_IDX_MAP
          ]
          # persist the accumulated null-task writes together with the new ones
          writes_to_save: WritesT = [
              w[1:] for w in self.checkpoint_pending_writes if w[0] == task_id
          ] + list(writes)
      else:
          # remove existing writes for this task
          self.checkpoint_pending_writes = [
              w for w in self.checkpoint_pending_writes if w[0] != task_id
          ]
          writes_to_save = writes
      # check if any writes are to an UntrackedValue channel
      if any(
          isinstance(channel, UntrackedValue) for channel in self.channels.values()
      ):
          # we do not persist untracked values in checkpoints
          writes_to_save = [
              # sanitize UntrackedValues that are nested within Send packets
              (
                  (c, sanitize_untracked_values_in_send(v, self.channels))
                  if c == TASKS and isinstance(v, Send)
                  else (c, v)
              )
              for c, v in writes_to_save
              # dont persist UntrackedValue channel writes
              if not isinstance(self.specs.get(c), UntrackedValue)
          ]
      # save writes; the in-memory copy keeps untracked values, only the
      # persisted copy (writes_to_save) is filtered
      self.checkpoint_pending_writes.extend((task_id, c, v) for c, v in writes)
      # with durability == "exit", persistence is deferred until loop exit
      if self.durability != "exit" and self.checkpointer_put_writes is not None:
          config = patch_configurable(
              self.checkpoint_config,
              {
                  CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                      CONFIG_KEY_CHECKPOINT_NS, ""
                  ),
                  CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
              },
          )
          if self.checkpointer_put_writes_accepts_task_path:
              # self.tasks does not exist before the first tick
              if hasattr(self, "tasks"):
                  task = self.tasks.get(task_id)
              else:
                  task = None
              self.submit(
                  self.checkpointer_put_writes,
                  config,
                  writes_to_save,
                  task_id,
                  task_path_str(task.path) if task else "",
              )
          else:
              self.submit(
                  self.checkpointer_put_writes,
                  config,
                  writes_to_save,
                  task_id,
              )
      # output writes
      if hasattr(self, "tasks"):
          self.output_writes(task_id, writes)
  def _put_pending_writes(self) -> None:
      """Submit every pending write to the checkpointer, grouped by task id.

      The visible caller invokes this on loop exit only when durability is
      "exit", a mode in which ``put_writes`` skips submitting. The same
      persistence rule as ``put_writes`` is applied here:
        - writes to UntrackedValue channels are dropped;
        - Send packets on TASKS are sanitized of untracked values.
      A task left with no persistable writes is not submitted. When the
      checkpointer accepts a task path, one is always passed ("" when the task
      is unknown), matching ``put_writes``.
      """
      if self.checkpointer_put_writes is None:
          return
      if not self.checkpoint_pending_writes:
          return
      # patch config
      config = patch_configurable(
          self.checkpoint_config,
          {
              CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                  CONFIG_KEY_CHECKPOINT_NS, ""
              ),
              CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
          },
      )
      # in-memory pending writes still contain untracked values; strip them
      # before persisting, as put_writes does
      has_untracked = any(
          isinstance(channel, UntrackedValue) for channel in self.channels.values()
      )
      # group by task id
      by_task: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list)
      for task_id, channel, value in self.checkpoint_pending_writes:
          if has_untracked:
              # dont persist UntrackedValue channel writes
              if isinstance(self.specs.get(channel), UntrackedValue):
                  continue
              # sanitize UntrackedValues nested within Send packets
              if channel == TASKS and isinstance(value, Send):
                  value = sanitize_untracked_values_in_send(value, self.channels)
          by_task[task_id].append((channel, value))
      # submit writes to checkpointer; self.tasks does not exist before the
      # first tick
      tasks = getattr(self, "tasks", None)
      for task_id, writes in by_task.items():
          if self.checkpointer_put_writes_accepts_task_path:
              task = tasks.get(task_id) if tasks is not None else None
              self.submit(
                  self.checkpointer_put_writes,
                  config,
                  writes,
                  task_id,
                  task_path_str(task.path) if task else "",
              )
          else:
              self.submit(
                  self.checkpointer_put_writes,
                  config,
                  writes,
                  task_id,
              )
  def accept_push(
      self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
  ) -> PregelExecutableTask | None:
      """Accept a PUSH issued by ``task`` and register the task it spawns.

      Args:
          task: the running task that issued the PUSH.
          write_idx: position forwarded into the PUSH task path tuple.
          call: optional ``Call`` payload forwarded to task preparation.

      Returns:
          The newly prepared task, or None when ``prepare_single_task``
          produced nothing. A new task is stored in ``self.tasks`` and, when
          ``skip_done_tasks`` is set, matched against pending writes. The caller
          is expected to start it if it has not run before.
      """
      # the checkpoint id is a hyphenated hex string; pass its raw bytes along
      id_bytes = binascii.unhexlify(self.checkpoint["id"].replace("-", ""))
      null_ver = checkpoint_null_version(self.checkpoint)
      prepared = cast(
          PregelExecutableTask | None,
          prepare_single_task(
              (PUSH, task.path, write_idx, task.id, call),
              None,
              checkpoint=self.checkpoint,
              checkpoint_id_bytes=id_bytes,
              checkpoint_null_version=null_ver,
              pending_writes=self.checkpoint_pending_writes,
              processes=self.nodes,
              channels=self.channels,
              managed=self.managed,
              config=task.config,
              step=self.step,
              stop=self.stop,
              for_execution=True,
              store=self.store,
              checkpointer=self.checkpointer,
              manager=self.manager,
              retry_policy=self.retry_policy,
              cache_policy=self.cache_policy,
          ),
      )
      if not prepared:
          return None
      # debug output for the freshly created task
      self._emit("tasks", map_debug_tasks, [prepared])
      self.tasks[prepared.id] = prepared
      # reuse any writes this task already produced in a previous attempt
      if self.skip_done_tasks:
          self._match_writes({prepared.id: prepared})
      return prepared
  def tick(self) -> bool:
      """Execute a single iteration of the Pregel loop.

      Prepares the tasks for the current step and emits debug
      "checkpoints"/"tasks" output. Re-attaches previously recorded writes to
      tasks when ``skip_done_tasks`` is set, and streams those cached writes.

      Returns:
          True if more iterations are needed. False when the step limit is
          exceeded (status "out_of_steps") or no tasks remain (status "done").

      Raises:
          GraphInterrupt: when ``interrupt_before`` matches the prepared tasks
              (status "interrupt_before").
      """
      # check if iteration limit is reached
      if self.step > self.stop:
          self.status = "out_of_steps"
          return False
      # prepare next tasks
      self.tasks = prepare_next_tasks(
          self.checkpoint,
          self.checkpoint_pending_writes,
          self.nodes,
          self.channels,
          self.managed,
          self.config,
          self.step,
          self.stop,
          for_execution=True,
          manager=self.manager,
          store=self.store,
          checkpointer=self.checkpointer,
          trigger_to_nodes=self.trigger_to_nodes,
          updated_channels=self.updated_channels,
          retry_policy=self.retry_policy,
          cache_policy=self.cache_policy,
      )
      # produce debug output (only when a checkpointer save callable exists)
      if self._checkpointer_put_after_previous is not None:
          self._emit(
              "checkpoints",
              map_debug_checkpoint,
              {
                  **self.checkpoint_config,
                  CONF: {
                      **self.checkpoint_config[CONF],
                      CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
                  },
              },
              self.channels,
              self.stream_keys,
              self.checkpoint_metadata,
              self.tasks.values(),
              self.checkpoint_pending_writes,
              self.prev_checkpoint_config,
              self.output_keys,
          )
      # if no more tasks, we're done
      if not self.tasks:
          self.status = "done"
          return False
      # if there are pending writes from a previous loop, apply them
      if self.skip_done_tasks and self.checkpoint_pending_writes:
          self._match_writes(self.tasks)
      # before execution, check if we should interrupt
      if self.interrupt_before and should_interrupt(
          self.checkpoint, self.interrupt_before, self.tasks.values()
      ):
          self.status = "interrupt_before"
          raise GraphInterrupt()
      # produce debug output
      self._emit("tasks", map_debug_tasks, self.tasks.values())
      # print output for any tasks we applied previous writes to
      for task in self.tasks.values():
          if task.writes:
              self.output_writes(task.id, task.writes, cached=True)
      return True
  def after_tick(self) -> None:
      """Finish the current superstep once all its tasks have run.

      Applies the tasks' writes to the channels and emits "values" if an
      output channel changed. Clears pending writes, saves a "loop"
      checkpoint (which increments ``step``), then clears the RESUMING flag.

      Raises:
          GraphInterrupt: when ``interrupt_after`` matches the finished tasks
              (status "interrupt_after"); the checkpoint is saved first.
      """
      # finish superstep
      writes = [w for t in self.tasks.values() for w in t.writes]
      # all tasks have finished
      self.updated_channels = apply_writes(
          self.checkpoint,
          self.channels,
          self.tasks.values(),
          self.checkpointer_get_next_version,
          self.trigger_to_nodes,
      )
      # produce values output
      if not self.updated_channels.isdisjoint(
          (self.output_keys,)
          if isinstance(self.output_keys, str)
          else self.output_keys
      ):
          self._emit(
              "values", map_output_values, self.output_keys, writes, self.channels
          )
      # clear pending writes
      self.checkpoint_pending_writes.clear()
      # "not skip_done_tasks" only applies to first tick after resuming
      self.skip_done_tasks = True
      # save checkpoint
      self._put_checkpoint({"source": "loop"})
      # after execution, check if we should interrupt
      if self.interrupt_after and should_interrupt(
          self.checkpoint, self.interrupt_after, self.tasks.values()
      ):
          self.status = "interrupt_after"
          raise GraphInterrupt()
      # unset resuming flag (mutates the configurable dict in place)
      self.config[CONF].pop(CONFIG_KEY_RESUMING, None)
  def match_cached_writes(self) -> Sequence[PregelExecutableTask]:
      """Match prepared tasks against cached writes (synchronous variant).

      Always raises NotImplementedError here; presumably concrete loop
      subclasses override it — confirm.
      """
      raise NotImplementedError()
  async def amatch_cached_writes(self) -> Sequence[PregelExecutableTask]:
      """Match prepared tasks against cached writes (async variant).

      Always raises NotImplementedError here; presumably concrete loop
      subclasses override it — confirm.
      """
      raise NotImplementedError()
  543. # private
  def _match_writes(self, tasks: Mapping[str, PregelExecutableTask]) -> None:
      """Re-attach pending writes to the tasks in ``tasks`` that produced them.

      Writes to the ERROR, INTERRUPT and RESUME channels are never attached.
      Writes whose task id is not in ``tasks`` are ignored.
      """
      excluded = (ERROR, INTERRUPT, RESUME)
      for task_id, channel, value in self.checkpoint_pending_writes:
          if channel in excluded:
              continue
          owner = tasks.get(task_id)
          if owner:
              owner.writes.append((channel, value))
  def _pending_interrupts(self) -> set[str]:
      """Return the set of interrupt ids that are pending without corresponding resume values.

      Scans ``checkpoint_pending_writes``:
        - each INTERRUPT write maps its task id to the ``id`` of the first
          element of the written value;
        - each RESUME write marks its task id as resumed.
      Interrupt ids attached to any resumed task are excluded from the result.
      """
      interrupt_of_task: dict[str, str] = {}
      resumed_task_ids: set[str] = set()
      for tid, kind, payload in self.checkpoint_pending_writes:
          if kind == INTERRUPT:
              # the written value is a list of interrupts; only one is expected
              interrupt_of_task[tid] = payload[0].id
          elif kind == RESUME:
              resumed_task_ids.add(tid)
      answered = {
          interrupt_of_task[tid]
          for tid in interrupt_of_task.keys() & resumed_task_ids
      }
      return set(interrupt_of_task.values()) - answered
  def _first(
      self, *, input_keys: str | Sequence[str], updated_channels: set[str] | None
  ) -> set[str] | None:
      """Consume the run input before the first tick.

      Steps, in order:
        1. A ``Command`` input is mapped to pending writes. A map-style resume
           value is stored in the config under CONFIG_KEY_RESUME_MAP instead.
        2. Pending writes of the null task, if any, are applied to the
           channels.
        3. When resuming, the current channel versions are marked as seen under
           INTERRUPT and a "values" event is emitted.
        4. Otherwise the input is mapped onto the input channels, unfinished
           tasks from the previous checkpoint are discarded, and an "input"
           checkpoint is saved.
      Finally the RESUMING flag is recorded (non-nested runs) and status
      becomes "pending".

      Args:
          input_keys: channel(s) the input is mapped onto.
          updated_channels: channels updated so far; extended in place with
              channels touched by null-task writes.

      Returns:
          The updated-channels set: the one passed in, or, when fresh input
          was applied, the set returned by ``apply_writes`` for that input.

      Raises:
          RuntimeError: ``Command(resume=...)`` without a checkpointer, or a
              non-map resume value while several interrupts are pending.
          EmptyInputError: a ``Command`` that yields no writes and no resume
              map, or no input while not resuming and no RESUMING flag is set.
      """
      # resuming from previous checkpoint requires
      # - finding a previous checkpoint
      # - receiving None input (outer graph) or RESUMING flag (subgraph)
      configurable = self.config.get(CONF, {})
      is_resuming = bool(self.checkpoint["channel_versions"]) and bool(
          configurable.get(
              CONFIG_KEY_RESUMING,
              self.input is None
              or isinstance(self.input, Command)
              or (
                  not self.is_nested
                  and self.config.get("metadata", {}).get("run_id")
                  == self.checkpoint_metadata.get("run_id", MISSING)
              ),
          )
      )
      # map command to writes
      if isinstance(self.input, Command):
          # bound up front: it is read below even when no resume value was
          # given, and an unbound name would surface as UnboundLocalError
          # instead of EmptyInputError for an empty Command
          resume_is_map = False
          if (resume := self.input.resume) is not None:
              if not self.checkpointer:
                  raise RuntimeError(
                      "Cannot use Command(resume=...) without checkpointer"
                  )
              # a dict whose keys all look like interrupt ids (xxh3-128 hex
              # digests) resumes specific interrupts
              resume_is_map = isinstance(resume, dict) and all(
                  is_xxh3_128_hexdigest(k) for k in resume
              )
              if resume_is_map:
                  self.config[CONF][CONFIG_KEY_RESUME_MAP] = resume
              elif len(self._pending_interrupts()) > 1:
                  raise RuntimeError(
                      "When there are multiple pending interrupts, you must specify the interrupt id when resuming. "
                      "Docs: https://docs.langchain.com/oss/python/langgraph/add-human-in-the-loop#resume-multiple-interrupts-with-one-invocation."
                  )
          writes: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list)
          # group writes by task ID; map-style resume values travel through the
          # config instead of RESUME writes
          for tid, c, v in map_command(cmd=self.input):
              if not (c == RESUME and resume_is_map):
                  writes[tid].append((c, v))
          if not writes and not resume_is_map:
              raise EmptyInputError("Received empty Command input")
          # save writes
          for tid, ws in writes.items():
              self.put_writes(tid, ws)
      # apply NULL writes
      if null_writes := [
          w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID
      ]:
          null_updated_channels = apply_writes(
              self.checkpoint,
              self.channels,
              [PregelTaskWrites((), INPUT, null_writes, [])],
              self.checkpointer_get_next_version,
              self.trigger_to_nodes,
          )
          if updated_channels is not None:
              updated_channels.update(null_updated_channels)
      # proceed past previous checkpoint
      if is_resuming:
          self.checkpoint["versions_seen"].setdefault(INTERRUPT, {})
          for k in self.channels:
              if k in self.checkpoint["channel_versions"]:
                  version = self.checkpoint["channel_versions"][k]
                  self.checkpoint["versions_seen"][INTERRUPT][k] = version
          # produce values output
          self._emit(
              "values", map_output_values, self.output_keys, True, self.channels
          )
      # map inputs to channel updates
      elif input_writes := deque(map_input(input_keys, self.input)):
          # discard any unfinished tasks from previous checkpoint
          discard_tasks = prepare_next_tasks(
              self.checkpoint,
              self.checkpoint_pending_writes,
              self.nodes,
              self.channels,
              self.managed,
              self.config,
              self.step,
              self.stop,
              for_execution=True,
              store=None,
              checkpointer=None,
              manager=None,
              updated_channels=updated_channels,
          )
          # apply input writes
          updated_channels = apply_writes(
              self.checkpoint,
              self.channels,
              [
                  *discard_tasks.values(),
                  PregelTaskWrites((), INPUT, input_writes, []),
              ],
              self.checkpointer_get_next_version,
              self.trigger_to_nodes,
          )
          # save input checkpoint
          self.updated_channels = updated_channels
          self._put_checkpoint({"source": "input"})
      elif CONFIG_KEY_RESUMING not in configurable:
          raise EmptyInputError(f"Received no input for {input_keys}")
      # update config
      if not self.is_nested:
          self.config = patch_configurable(
              self.config, {CONFIG_KEY_RESUMING: is_resuming}
          )
      # set flag
      self.status = "pending"
      return updated_channels
  def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
      """Create the next checkpoint from the channels and possibly persist it.

      Passing ``self.checkpoint_metadata`` itself marks the exit-time save:
        - nothing happens if the current checkpoint id was already saved;
        - otherwise the checkpoint keeps its id and ``step`` is not
          incremented.
      Any other ``metadata`` gets ``step``/``parents`` filled in. The
      checkpoint is then created with ``id=None`` (presumably a fresh id —
      confirm in ``create_checkpoint``) and ``self.step`` is incremented.

      The checkpoint is handed to the checkpointer through ``submit`` only
      when a checkpointer save callable exists and either this is the exit
      save or durability is not "exit".
      """
      # assign step and parents
      exiting = metadata is self.checkpoint_metadata
      if exiting and self.checkpoint["id"] == self.checkpoint_id_saved:
          # checkpoint already saved
          return
      if not exiting:
          metadata["step"] = self.step
          metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {})
          self.checkpoint_metadata = metadata
      # do checkpoint?
      do_checkpoint = self._checkpointer_put_after_previous is not None and (
          exiting or self.durability != "exit"
      )
      # create new checkpoint (channels are only passed when it will be saved)
      self.checkpoint = create_checkpoint(
          self.checkpoint,
          self.channels if do_checkpoint else None,
          self.step,
          id=self.checkpoint["id"] if exiting else None,
          updated_channels=self.updated_channels,
      )
      # sanitize TASK channel in the checkpoint before saving (durability=="exit");
      # note this runs whenever an UntrackedValue channel exists
      if TASKS in self.checkpoint["channel_values"] and any(
          isinstance(channel, UntrackedValue) for channel in self.channels.values()
      ):
          sanitized_tasks = [
              sanitize_untracked_values_in_send(value, self.channels)
              if isinstance(value, Send)
              else value
              for value in self.checkpoint["channel_values"][TASKS]
          ]
          self.checkpoint["channel_values"][TASKS] = sanitized_tasks
      # only persist when do_checkpoint (a checkpointer exists); the second
      # condition is redundant at runtime
      if do_checkpoint and self._checkpointer_put_after_previous is not None:
          # remember the previously saved checkpoint's config, if it had an id
          self.prev_checkpoint_config = (
              self.checkpoint_config
              if CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF]
              and self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID]
              else None
          )
          self.checkpoint_config = {
              **self.checkpoint_config,
              CONF: {
                  **self.checkpoint_config[CONF],
                  CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                      CONFIG_KEY_CHECKPOINT_NS, ""
                  ),
              },
          }
          # only versions changed since the last submitted checkpoint are sent
          channel_versions = self.checkpoint["channel_versions"].copy()
          new_versions = get_new_channel_versions(
              self.checkpoint_previous_versions, channel_versions
          )
          self.checkpoint_previous_versions = channel_versions
          # save it, without blocking
          # if there's a previous checkpoint save in progress, wait for it
          # ensuring checkpointers receive checkpoints in order
          self._put_checkpoint_fut = self.submit(
              self._checkpointer_put_after_previous,
              getattr(self, "_put_checkpoint_fut", None),
              self.checkpoint_config,
              copy_checkpoint(self.checkpoint),
              self.checkpoint_metadata,
              new_versions,
          )
          # subsequent writes/checkpoints refer to the checkpoint just submitted
          self.checkpoint_config = {
              **self.checkpoint_config,
              CONF: {
                  **self.checkpoint_config[CONF],
                  CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
              },
          }
      if not exiting:
          # increment step
          self.step += 1
  def _suppress_interrupt(
      self,
      exc_type: type[BaseException] | None,
      exc_value: BaseException | None,
      traceback: TracebackType | None,
  ) -> bool | None:
      """Finish the run on exit and decide whether to swallow the exception.

      With ``durability == "exit"``, the current checkpoint and pending writes
      are persisted here for:
        - top-level graphs;
        - nested graphs exiting with an exception;
        - nested graphs whose namespace contains no NS_END part.

      For non-nested runs a GraphInterrupt is suppressed:
        - a final "values" event is emitted with the tasks' pending writes
          applied;
        - an "updates" event carrying the INTERRUPT payload is emitted when
          the exception carries no interrupts;
        - ``self.output`` is read from the output channels.

      Returns:
          True to suppress the exception (GraphInterrupt in a non-nested run);
          None otherwise. ``self.output`` is also set on a clean exit.
      """
      # persist current checkpoint and writes
      if self.durability == "exit" and (
          # if it's a top graph
          not self.is_nested
          # or a nested graph with error or interrupt
          or exc_value is not None
          # or a nested graph with checkpointer=True
          or all(NS_END not in part for part in self.checkpoint_ns)
      ):
          self._put_checkpoint(self.checkpoint_metadata)
          self._put_pending_writes()
      # suppress interrupt
      suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested
      if suppress:
          interrupt_exc = cast(GraphInterrupt, exc_value)
          # emit one last "values" event, with pending writes applied
          if (
              hasattr(self, "tasks")
              and self.checkpoint_pending_writes
              and any(task.writes for task in self.tasks.values())
          ):
              updated_channels = apply_writes(
                  self.checkpoint,
                  self.channels,
                  self.tasks.values(),
                  self.checkpointer_get_next_version,
                  self.trigger_to_nodes,
              )
              if not updated_channels.isdisjoint(
                  (self.output_keys,)
                  if isinstance(self.output_keys, str)
                  else self.output_keys
              ):
                  self._emit(
                      "values",
                      map_output_values,
                      self.output_keys,
                      [w for t in self.tasks.values() for w in t.writes],
                      self.channels,
                  )
          # emit INTERRUPT if exception is empty (otherwise emitted by put_writes).
          # ``args`` itself may be empty, so fall back to () instead of
          # indexing it unconditionally (which raised IndexError).
          interrupts = interrupt_exc.args[0] if interrupt_exc.args else ()
          if not interrupts:
              self._emit("updates", lambda: iter([{INTERRUPT: interrupts}]))
          # save final output
          self.output = read_channels(self.channels, self.output_keys)
          # suppress interrupt
          return True
      elif exc_type is None:
          # save final output
          self.output = read_channels(self.channels, self.output_keys)
      return None
  823. def _emit(
  824. self,
  825. mode: StreamMode,
  826. values: Callable[P, Iterator[Any]],
  827. *args: P.args,
  828. **kwargs: P.kwargs,
  829. ) -> None:
  830. if self.stream is None:
  831. return
  832. debug_remap = mode in ("checkpoints", "tasks") and "debug" in self.stream.modes
  833. if mode not in self.stream.modes and not debug_remap:
  834. return
  835. for v in values(*args, **kwargs):
  836. if mode in self.stream.modes:
  837. self.stream((self.checkpoint_ns, mode, v))
  838. # "debug" mode is "checkpoints" or "tasks" with a wrapper dict
  839. if debug_remap:
  840. self.stream(
  841. (
  842. self.checkpoint_ns,
  843. "debug",
  844. {
  845. "step": self.step - 1
  846. if mode == "checkpoints"
  847. else self.step,
  848. "timestamp": datetime.now(timezone.utc).isoformat(),
  849. "type": "checkpoint"
  850. if mode == "checkpoints"
  851. else "task_result"
  852. if "result" in v
  853. else "task",
  854. "payload": v,
  855. },
  856. )
  857. )
  858. def output_writes(
  859. self, task_id: str, writes: WritesT, *, cached: bool = False
  860. ) -> None:
  861. if task := self.tasks.get(task_id):
  862. if task.config is not None and TAG_HIDDEN in task.config.get(
  863. "tags", EMPTY_SEQ
  864. ):
  865. return
  866. if writes[0][0] == INTERRUPT:
  867. # in loop.py we append a bool to the PUSH task paths to indicate
  868. # whether or not a call was present. If so,
  869. # we don't emit the interrupt as it'll be emitted by the parent
  870. if task.path[0] == PUSH and task.path[-1] is True:
  871. return
  872. interrupts = [
  873. {
  874. INTERRUPT: tuple(
  875. v
  876. for w in writes
  877. if w[0] == INTERRUPT
  878. for v in (w[1] if isinstance(w[1], Sequence) else (w[1],))
  879. )
  880. }
  881. ]
  882. stream_modes = self.stream.modes if self.stream else []
  883. if "updates" in stream_modes:
  884. self._emit("updates", lambda: iter(interrupts))
  885. if "values" in stream_modes:
  886. current_values = read_channels(self.channels, self.output_keys)
  887. # self.output_keys is a sequence, stream chunk contains entire state and interrupts
  888. if isinstance(current_values, dict):
  889. current_values[INTERRUPT] = interrupts[0][INTERRUPT]
  890. self._emit("values", lambda: iter([current_values]))
  891. # self.output_keys is a string, stream chunk contains only interrupts
  892. else:
  893. self._emit("values", lambda: iter(interrupts))
  894. elif writes[0][0] != ERROR:
  895. self._emit(
  896. "updates",
  897. map_output_updates,
  898. self.output_keys,
  899. [(task, writes)],
  900. cached,
  901. )
  902. if not cached:
  903. self._emit(
  904. "tasks",
  905. map_debug_task_results,
  906. (task, writes),
  907. self.stream_keys,
  908. )
  909. class SyncPregelLoop(PregelLoop, AbstractContextManager):
  910. def __init__(
  911. self,
  912. input: Any | None,
  913. *,
  914. stream: StreamProtocol | None,
  915. config: RunnableConfig,
  916. store: BaseStore | None,
  917. cache: BaseCache | None,
  918. checkpointer: BaseCheckpointSaver | None,
  919. nodes: Mapping[str, PregelNode],
  920. specs: Mapping[str, BaseChannel | ManagedValueSpec],
  921. trigger_to_nodes: Mapping[str, Sequence[str]],
  922. durability: Durability,
  923. manager: None | AsyncParentRunManager | ParentRunManager = None,
  924. interrupt_after: All | Sequence[str] = EMPTY_SEQ,
  925. interrupt_before: All | Sequence[str] = EMPTY_SEQ,
  926. input_keys: str | Sequence[str] = EMPTY_SEQ,
  927. output_keys: str | Sequence[str] = EMPTY_SEQ,
  928. stream_keys: str | Sequence[str] = EMPTY_SEQ,
  929. migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
  930. retry_policy: Sequence[RetryPolicy] = (),
  931. cache_policy: CachePolicy | None = None,
  932. ) -> None:
  933. super().__init__(
  934. input,
  935. stream=stream,
  936. config=config,
  937. checkpointer=checkpointer,
  938. cache=cache,
  939. store=store,
  940. nodes=nodes,
  941. specs=specs,
  942. input_keys=input_keys,
  943. output_keys=output_keys,
  944. stream_keys=stream_keys,
  945. interrupt_after=interrupt_after,
  946. interrupt_before=interrupt_before,
  947. manager=manager,
  948. migrate_checkpoint=migrate_checkpoint,
  949. trigger_to_nodes=trigger_to_nodes,
  950. retry_policy=retry_policy,
  951. cache_policy=cache_policy,
  952. durability=durability,
  953. )
  954. self.stack = ExitStack()
  955. if checkpointer:
  956. self.checkpointer_get_next_version = checkpointer.get_next_version
  957. self.checkpointer_put_writes = checkpointer.put_writes
  958. self.checkpointer_put_writes_accepts_task_path = (
  959. signature(checkpointer.put_writes).parameters.get("task_path")
  960. is not None
  961. )
  962. else:
  963. self.checkpointer_get_next_version = increment
  964. self._checkpointer_put_after_previous = None # type: ignore[assignment]
  965. self.checkpointer_put_writes = None
  966. self.checkpointer_put_writes_accepts_task_path = False
  967. def _checkpointer_put_after_previous(
  968. self,
  969. prev: concurrent.futures.Future | None,
  970. config: RunnableConfig,
  971. checkpoint: Checkpoint,
  972. metadata: CheckpointMetadata,
  973. new_versions: ChannelVersions,
  974. ) -> RunnableConfig:
  975. try:
  976. if prev is not None:
  977. prev.result()
  978. finally:
  979. cast(BaseCheckpointSaver, self.checkpointer).put(
  980. config, checkpoint, metadata, new_versions
  981. )
  982. def match_cached_writes(self) -> Sequence[PregelExecutableTask]:
  983. if self.cache is None:
  984. return ()
  985. matched: list[PregelExecutableTask] = []
  986. if cached := {
  987. (t.cache_key.ns, t.cache_key.key): t
  988. for t in self.tasks.values()
  989. if t.cache_key and not t.writes
  990. }:
  991. for key, values in self.cache.get(tuple(cached)).items():
  992. task = cached[key]
  993. task.writes.extend(values)
  994. matched.append(task)
  995. return matched
  996. def accept_push(
  997. self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
  998. ) -> PregelExecutableTask | None:
  999. if pushed := super().accept_push(task, write_idx, call):
  1000. for task in self.match_cached_writes():
  1001. self.output_writes(task.id, task.writes, cached=True)
  1002. return pushed
  1003. def put_writes(self, task_id: str, writes: WritesT) -> None:
  1004. """Put writes for a task, to be read by the next tick."""
  1005. super().put_writes(task_id, writes)
  1006. if not writes or self.cache is None or not hasattr(self, "tasks"):
  1007. return
  1008. task = self.tasks.get(task_id)
  1009. if task is None or task.cache_key is None:
  1010. return
  1011. self.submit(
  1012. self.cache.set,
  1013. {
  1014. (task.cache_key.ns, task.cache_key.key): (
  1015. task.writes,
  1016. task.cache_key.ttl,
  1017. )
  1018. },
  1019. )
  1020. # context manager
  1021. def __enter__(self) -> Self:
  1022. if self.checkpointer:
  1023. saved = self.checkpointer.get_tuple(self.checkpoint_config)
  1024. else:
  1025. saved = None
  1026. if saved is None:
  1027. saved = CheckpointTuple(
  1028. self.checkpoint_config, empty_checkpoint(), {"step": -2}, None, []
  1029. )
  1030. elif self._migrate_checkpoint is not None:
  1031. self._migrate_checkpoint(saved.checkpoint)
  1032. self.checkpoint_config = {
  1033. **self.checkpoint_config,
  1034. **saved.config,
  1035. CONF: {
  1036. CONFIG_KEY_CHECKPOINT_NS: "",
  1037. **self.checkpoint_config.get(CONF, {}),
  1038. **saved.config.get(CONF, {}),
  1039. },
  1040. }
  1041. self.prev_checkpoint_config = saved.parent_config
  1042. self.checkpoint_id_saved = saved.checkpoint["id"]
  1043. self.checkpoint = saved.checkpoint
  1044. self.checkpoint_metadata = saved.metadata
  1045. self.checkpoint_pending_writes = (
  1046. [(str(tid), k, v) for tid, k, v in saved.pending_writes]
  1047. if saved.pending_writes is not None
  1048. else []
  1049. )
  1050. self.submit = self.stack.enter_context(BackgroundExecutor(self.config))
  1051. self.channels, self.managed = channels_from_checkpoint(
  1052. self.specs, self.checkpoint
  1053. )
  1054. self.stack.push(self._suppress_interrupt)
  1055. self.status = "input"
  1056. self.step = self.checkpoint_metadata["step"] + 1
  1057. self.stop = self.step + self.config["recursion_limit"] + 1
  1058. self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
  1059. self.updated_channels = self._first(
  1060. input_keys=self.input_keys,
  1061. updated_channels=set(self.checkpoint.get("updated_channels")) # type: ignore[arg-type]
  1062. if self.checkpoint.get("updated_channels")
  1063. else None,
  1064. )
  1065. return self
  1066. def __exit__(
  1067. self,
  1068. exc_type: type[BaseException] | None,
  1069. exc_value: BaseException | None,
  1070. traceback: TracebackType | None,
  1071. ) -> bool | None:
  1072. # unwind stack
  1073. return self.stack.__exit__(exc_type, exc_value, traceback)
  1074. class AsyncPregelLoop(PregelLoop, AbstractAsyncContextManager):
  1075. def __init__(
  1076. self,
  1077. input: Any | None,
  1078. *,
  1079. stream: StreamProtocol | None,
  1080. config: RunnableConfig,
  1081. store: BaseStore | None,
  1082. cache: BaseCache | None,
  1083. checkpointer: BaseCheckpointSaver | None,
  1084. nodes: Mapping[str, PregelNode],
  1085. specs: Mapping[str, BaseChannel | ManagedValueSpec],
  1086. trigger_to_nodes: Mapping[str, Sequence[str]],
  1087. durability: Durability,
  1088. interrupt_after: All | Sequence[str] = EMPTY_SEQ,
  1089. interrupt_before: All | Sequence[str] = EMPTY_SEQ,
  1090. manager: None | AsyncParentRunManager | ParentRunManager = None,
  1091. input_keys: str | Sequence[str] = EMPTY_SEQ,
  1092. output_keys: str | Sequence[str] = EMPTY_SEQ,
  1093. stream_keys: str | Sequence[str] = EMPTY_SEQ,
  1094. migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
  1095. retry_policy: Sequence[RetryPolicy] = (),
  1096. cache_policy: CachePolicy | None = None,
  1097. ) -> None:
  1098. super().__init__(
  1099. input,
  1100. stream=stream,
  1101. config=config,
  1102. checkpointer=checkpointer,
  1103. cache=cache,
  1104. store=store,
  1105. nodes=nodes,
  1106. specs=specs,
  1107. input_keys=input_keys,
  1108. output_keys=output_keys,
  1109. stream_keys=stream_keys,
  1110. interrupt_after=interrupt_after,
  1111. interrupt_before=interrupt_before,
  1112. manager=manager,
  1113. migrate_checkpoint=migrate_checkpoint,
  1114. trigger_to_nodes=trigger_to_nodes,
  1115. retry_policy=retry_policy,
  1116. cache_policy=cache_policy,
  1117. durability=durability,
  1118. )
  1119. self.stack = AsyncExitStack()
  1120. if checkpointer:
  1121. self.checkpointer_get_next_version = checkpointer.get_next_version
  1122. self.checkpointer_put_writes = checkpointer.aput_writes
  1123. self.checkpointer_put_writes_accepts_task_path = (
  1124. signature(checkpointer.aput_writes).parameters.get("task_path")
  1125. is not None
  1126. )
  1127. else:
  1128. self.checkpointer_get_next_version = increment
  1129. self._checkpointer_put_after_previous = None # type: ignore[assignment]
  1130. self.checkpointer_put_writes = None
  1131. self.checkpointer_put_writes_accepts_task_path = False
  1132. async def _checkpointer_put_after_previous(
  1133. self,
  1134. prev: asyncio.Task | None,
  1135. config: RunnableConfig,
  1136. checkpoint: Checkpoint,
  1137. metadata: CheckpointMetadata,
  1138. new_versions: ChannelVersions,
  1139. ) -> RunnableConfig:
  1140. try:
  1141. if prev is not None:
  1142. await prev
  1143. finally:
  1144. await cast(BaseCheckpointSaver, self.checkpointer).aput(
  1145. config, checkpoint, metadata, new_versions
  1146. )
  1147. async def amatch_cached_writes(self) -> Sequence[PregelExecutableTask]:
  1148. if self.cache is None:
  1149. return []
  1150. matched: list[PregelExecutableTask] = []
  1151. if cached := {
  1152. (t.cache_key.ns, t.cache_key.key): t
  1153. for t in self.tasks.values()
  1154. if t.cache_key and not t.writes
  1155. }:
  1156. for key, values in (await self.cache.aget(tuple(cached))).items():
  1157. task = cached[key]
  1158. task.writes.extend(values)
  1159. matched.append(task)
  1160. return matched
  1161. async def aaccept_push(
  1162. self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
  1163. ) -> PregelExecutableTask | None:
  1164. if pushed := super().accept_push(task, write_idx, call):
  1165. for task in await self.amatch_cached_writes():
  1166. self.output_writes(task.id, task.writes, cached=True)
  1167. return pushed
  1168. def put_writes(self, task_id: str, writes: WritesT) -> None:
  1169. """Put writes for a task, to be read by the next tick."""
  1170. super().put_writes(task_id, writes)
  1171. if not writes or self.cache is None or not hasattr(self, "tasks"):
  1172. return
  1173. task = self.tasks.get(task_id)
  1174. if task is None or task.cache_key is None:
  1175. return
  1176. if writes[0][0] in (INTERRUPT, ERROR):
  1177. # only cache successful tasks
  1178. return
  1179. self.submit(
  1180. self.cache.aset,
  1181. {
  1182. (task.cache_key.ns, task.cache_key.key): (
  1183. task.writes,
  1184. task.cache_key.ttl,
  1185. )
  1186. },
  1187. )
  1188. # context manager
  1189. async def __aenter__(self) -> Self:
  1190. if self.checkpointer:
  1191. saved = await self.checkpointer.aget_tuple(self.checkpoint_config)
  1192. else:
  1193. saved = None
  1194. if saved is None:
  1195. saved = CheckpointTuple(
  1196. self.checkpoint_config, empty_checkpoint(), {"step": -2}, None, []
  1197. )
  1198. elif self._migrate_checkpoint is not None:
  1199. self._migrate_checkpoint(saved.checkpoint)
  1200. self.checkpoint_config = {
  1201. **self.checkpoint_config,
  1202. **saved.config,
  1203. CONF: {
  1204. CONFIG_KEY_CHECKPOINT_NS: "",
  1205. **self.checkpoint_config.get(CONF, {}),
  1206. **saved.config.get(CONF, {}),
  1207. },
  1208. }
  1209. self.prev_checkpoint_config = saved.parent_config
  1210. self.checkpoint_id_saved = saved.checkpoint["id"]
  1211. self.checkpoint = saved.checkpoint
  1212. self.checkpoint_metadata = saved.metadata
  1213. self.checkpoint_pending_writes = (
  1214. [(str(tid), k, v) for tid, k, v in saved.pending_writes]
  1215. if saved.pending_writes is not None
  1216. else []
  1217. )
  1218. self.submit = await self.stack.enter_async_context(
  1219. AsyncBackgroundExecutor(self.config)
  1220. )
  1221. self.channels, self.managed = channels_from_checkpoint(
  1222. self.specs, self.checkpoint
  1223. )
  1224. self.stack.push(self._suppress_interrupt)
  1225. self.status = "input"
  1226. self.step = self.checkpoint_metadata["step"] + 1
  1227. self.stop = self.step + self.config["recursion_limit"] + 1
  1228. self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
  1229. self.updated_channels = self._first(
  1230. input_keys=self.input_keys,
  1231. updated_channels=set(self.checkpoint.get("updated_channels")) # type: ignore[arg-type]
  1232. if self.checkpoint.get("updated_channels")
  1233. else None,
  1234. )
  1235. return self
  1236. async def __aexit__(
  1237. self,
  1238. exc_type: type[BaseException] | None,
  1239. exc_value: BaseException | None,
  1240. traceback: TracebackType | None,
  1241. ) -> bool | None:
  1242. # unwind stack
  1243. exit_task = asyncio.create_task(
  1244. self.stack.__aexit__(exc_type, exc_value, traceback)
  1245. )
  1246. try:
  1247. return await exit_task
  1248. except asyncio.CancelledError as e:
  1249. # Bubble up the exit task upon cancellation to permit the API
  1250. # consumer to await it before e.g., reusing the DB connection.
  1251. e.args = (*e.args, exit_task)
  1252. raise