_algo.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233
  1. from __future__ import annotations
  2. import binascii
  3. import itertools
  4. import sys
  5. import threading
  6. from collections import defaultdict, deque
  7. from collections.abc import Callable, Iterable, Mapping, Sequence
  8. from copy import copy
  9. from functools import partial
  10. from hashlib import sha1
  11. from typing import (
  12. Any,
  13. Literal,
  14. NamedTuple,
  15. Protocol,
  16. cast,
  17. overload,
  18. )
  19. from langchain_core.callbacks import Callbacks
  20. from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager
  21. from langchain_core.runnables.config import RunnableConfig
  22. from langgraph.checkpoint.base import (
  23. BaseCheckpointSaver,
  24. ChannelVersions,
  25. Checkpoint,
  26. PendingWrite,
  27. V,
  28. )
  29. from langgraph.store.base import BaseStore
  30. from xxhash import xxh3_128_hexdigest
  31. from langgraph._internal._config import merge_configs, patch_config
  32. from langgraph._internal._constants import (
  33. CACHE_NS_WRITES,
  34. CONF,
  35. CONFIG_KEY_CHECKPOINT_ID,
  36. CONFIG_KEY_CHECKPOINT_MAP,
  37. CONFIG_KEY_CHECKPOINT_NS,
  38. CONFIG_KEY_CHECKPOINTER,
  39. CONFIG_KEY_READ,
  40. CONFIG_KEY_RESUME_MAP,
  41. CONFIG_KEY_RUNTIME,
  42. CONFIG_KEY_SCRATCHPAD,
  43. CONFIG_KEY_SEND,
  44. CONFIG_KEY_TASK_ID,
  45. ERROR,
  46. INTERRUPT,
  47. NO_WRITES,
  48. NS_END,
  49. NS_SEP,
  50. NULL_TASK_ID,
  51. PREVIOUS,
  52. PULL,
  53. PUSH,
  54. RESERVED,
  55. RESUME,
  56. RETURN,
  57. TASKS,
  58. )
  59. from langgraph._internal._scratchpad import PregelScratchpad
  60. from langgraph._internal._typing import EMPTY_SEQ, MISSING
  61. from langgraph.channels.base import BaseChannel
  62. from langgraph.channels.topic import Topic
  63. from langgraph.channels.untracked_value import UntrackedValue
  64. from langgraph.constants import TAG_HIDDEN
  65. from langgraph.managed.base import ManagedValueMapping
  66. from langgraph.pregel._call import get_runnable_for_task, identifier
  67. from langgraph.pregel._io import read_channels
  68. from langgraph.pregel._log import logger
  69. from langgraph.pregel._read import INPUT_CACHE_KEY_TYPE, PregelNode
  70. from langgraph.runtime import DEFAULT_RUNTIME, Runtime
  71. from langgraph.types import (
  72. All,
  73. CacheKey,
  74. CachePolicy,
  75. PregelExecutableTask,
  76. PregelTask,
  77. RetryPolicy,
  78. Send,
  79. )
# Signature of a channel-version generator: called with the current version
# (or None when no version exists yet) plus a second argument that this file
# always passes as None, and returns the next version to assign.
GetNextVersion = Callable[[V | None, None], V]
# BaseException.add_note() is only available on Python 3.11+; used to attach
# context to exceptions raised while preparing task input.
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
  82. class WritesProtocol(Protocol):
  83. """Protocol for objects containing writes to be applied to checkpoint.
  84. Implemented by PregelTaskWrites and PregelExecutableTask."""
  85. @property
  86. def path(self) -> tuple[str | int | tuple, ...]: ...
  87. @property
  88. def name(self) -> str: ...
  89. @property
  90. def writes(self) -> Sequence[tuple[str, Any]]: ...
  91. @property
  92. def triggers(self) -> Sequence[str]: ...
  93. class PregelTaskWrites(NamedTuple):
  94. """Simplest implementation of WritesProtocol, for usage with writes that
  95. don't originate from a runnable task, eg. graph input, update_state, etc."""
  96. path: tuple[str | int | tuple, ...]
  97. name: str
  98. writes: Sequence[tuple[str, Any]]
  99. triggers: Sequence[str]
  100. class Call:
  101. __slots__ = ("func", "input", "retry_policy", "cache_policy", "callbacks")
  102. func: Callable
  103. input: tuple[tuple[Any, ...], dict[str, Any]]
  104. retry_policy: Sequence[RetryPolicy] | None
  105. cache_policy: CachePolicy | None
  106. callbacks: Callbacks
  107. def __init__(
  108. self,
  109. func: Callable,
  110. input: tuple[tuple[Any, ...], dict[str, Any]],
  111. *,
  112. retry_policy: Sequence[RetryPolicy] | None,
  113. cache_policy: CachePolicy | None,
  114. callbacks: Callbacks,
  115. ) -> None:
  116. self.func = func
  117. self.input = input
  118. self.retry_policy = retry_policy
  119. self.cache_policy = cache_policy
  120. self.callbacks = callbacks
  121. def should_interrupt(
  122. checkpoint: Checkpoint,
  123. interrupt_nodes: All | Sequence[str],
  124. tasks: Iterable[PregelExecutableTask],
  125. ) -> list[PregelExecutableTask]:
  126. """Check if the graph should be interrupted based on current state."""
  127. version_type = type(next(iter(checkpoint["channel_versions"].values()), None))
  128. null_version = version_type() # type: ignore[misc]
  129. seen = checkpoint["versions_seen"].get(INTERRUPT, {})
  130. # interrupt if any channel has been updated since last interrupt
  131. any_updates_since_prev_interrupt = any(
  132. version > seen.get(chan, null_version) # type: ignore[operator]
  133. for chan, version in checkpoint["channel_versions"].items()
  134. )
  135. # and any triggered node is in interrupt_nodes list
  136. return (
  137. [
  138. task
  139. for task in tasks
  140. if (
  141. (
  142. not task.config
  143. or TAG_HIDDEN not in task.config.get("tags", EMPTY_SEQ)
  144. )
  145. if interrupt_nodes == "*"
  146. else task.name in interrupt_nodes
  147. )
  148. ]
  149. if any_updates_since_prev_interrupt
  150. else []
  151. )
  152. def local_read(
  153. scratchpad: PregelScratchpad,
  154. channels: Mapping[str, BaseChannel],
  155. managed: ManagedValueMapping,
  156. task: WritesProtocol,
  157. select: list[str] | str,
  158. fresh: bool = False,
  159. ) -> dict[str, Any] | Any:
  160. """Function injected under CONFIG_KEY_READ in task config, to read current state.
  161. Used by conditional edges to read a copy of the state with reflecting the writes
  162. from that node only."""
  163. updated: dict[str, list[Any]] = defaultdict(list)
  164. if isinstance(select, str):
  165. managed_keys = []
  166. for c, v in task.writes:
  167. if c == select:
  168. updated[c].append(v)
  169. else:
  170. managed_keys = [k for k in select if k in managed]
  171. select = [k for k in select if k not in managed]
  172. for c, v in task.writes:
  173. if c in select:
  174. updated[c].append(v)
  175. if fresh:
  176. # apply writes
  177. local_channels: dict[str, BaseChannel] = {}
  178. for k in channels:
  179. cc = channels[k].copy()
  180. cc.update(updated[k])
  181. local_channels[k] = cc
  182. # read fresh values
  183. values = read_channels(local_channels, select)
  184. else:
  185. values = read_channels(channels, select)
  186. if managed_keys:
  187. values.update({k: managed[k].get(scratchpad) for k in managed_keys})
  188. return values
  189. def increment(current: int | None, channel: None) -> int:
  190. """Default channel versioning function, increments the current int version."""
  191. return current + 1 if current is not None else 1
def apply_writes(
    checkpoint: Checkpoint,
    channels: Mapping[str, BaseChannel],
    tasks: Iterable[WritesProtocol],
    get_next_version: GetNextVersion | None,
    trigger_to_nodes: Mapping[str, Sequence[str]],
) -> set[str]:
    """Apply writes from a set of tasks (usually the tasks from a Pregel step)
    to the checkpoint and channels, and return the names of the channels that
    were updated.

    Mutates ``checkpoint["versions_seen"]`` and ``checkpoint["channel_versions"]``
    in place, and calls consume/update/finish on the channel objects.

    Args:
        checkpoint: The checkpoint to update.
        channels: The channels to update.
        tasks: The tasks to apply writes from.
        get_next_version: Optional function to determine the next version of a channel.
            When None, channel versions are left unchanged and no channel is
            reported as updated.
        trigger_to_nodes: Mapping of channel names to the set of nodes that can be triggered by updates to that channel.
            Only its keys are used here, to detect a (tentatively) final step.

    Returns:
        Set of channels that were updated in this step.
    """
    # sort tasks on path, to ensure deterministic order for update application
    # any path parts after the 3rd are ignored for sorting
    # (we use them for eg. task ids which aren't good for sorting)
    tasks = sorted(tasks, key=lambda t: task_path_str(t.path[:3]))
    # if no task has triggers this is applying writes from the null task only
    # so we don't do anything other than update the channels written to
    bump_step = any(t.triggers for t in tasks)
    # update seen versions: record, per task name, the current version of each
    # channel that triggered it
    for task in tasks:
        checkpoint["versions_seen"].setdefault(task.name, {}).update(
            {
                chan: checkpoint["channel_versions"][chan]
                for chan in task.triggers
                if chan in checkpoint["channel_versions"]
            }
        )
    # Compute one next version, derived from the highest version currently in
    # the checkpoint; every channel bumped below receives this same value
    if get_next_version is None:
        next_version = None
    else:
        next_version = get_next_version(
            (
                max(checkpoint["channel_versions"].values())
                if checkpoint["channel_versions"]
                else None
            ),
            None,
        )
    # Consume all channels that were read (reserved names and triggers that
    # are not known channels are skipped)
    for chan in {
        chan
        for task in tasks
        for chan in task.triggers
        if chan not in RESERVED and chan in channels
    }:
        if channels[chan].consume() and next_version is not None:
            checkpoint["channel_versions"][chan] = next_version
    # Group writes by channel
    pending_writes_by_channel: dict[str, list[Any]] = defaultdict(list)
    for task in tasks:
        for chan, val in task.writes:
            if chan in (NO_WRITES, PUSH, RESUME, INTERRUPT, RETURN, ERROR):
                # these special keys are not applied as channel writes here
                pass
            elif chan in channels:
                pending_writes_by_channel[chan].append(val)
            else:
                logger.warning(
                    f"Task {task.name} with path {task.path} wrote to unknown channel {chan}, ignoring it."
                )
    # Apply writes to channels
    updated_channels: set[str] = set()
    for chan, vals in pending_writes_by_channel.items():
        # always true: only known channels were grouped above
        if chan in channels:
            if channels[chan].update(vals) and next_version is not None:
                checkpoint["channel_versions"][chan] = next_version
                # unavailable channels can't trigger tasks, so don't add them
                if channels[chan].is_available():
                    updated_channels.add(chan)
    # Channels that weren't updated in this step are notified of a new step
    if bump_step:
        for chan in channels:
            if channels[chan].is_available() and chan not in updated_channels:
                if channels[chan].update(EMPTY_SEQ) and next_version is not None:
                    checkpoint["channel_versions"][chan] = next_version
                    # unavailable channels can't trigger tasks, so don't add them
                    if channels[chan].is_available():
                        updated_channels.add(chan)
    # If this is (tentatively) the last superstep, notify all channels of finish
    # ("last" = no updated channel is a key of trigger_to_nodes)
    if bump_step and updated_channels.isdisjoint(trigger_to_nodes):
        for chan in channels:
            if channels[chan].finish() and next_version is not None:
                checkpoint["channel_versions"][chan] = next_version
                # unavailable channels can't trigger tasks, so don't add them
                if channels[chan].is_available():
                    updated_channels.add(chan)
    # Return the set of channels updated in this step
    return updated_channels
# Overload for planning only (``for_execution=False``): returns lightweight
# PregelTask descriptors, so store/checkpointer/manager/cache_policy must be None.
@overload
def prepare_next_tasks(
    checkpoint: Checkpoint,
    pending_writes: list[PendingWrite],
    processes: Mapping[str, PregelNode],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    *,
    for_execution: Literal[False],
    store: Literal[None] = None,
    checkpointer: Literal[None] = None,
    manager: Literal[None] = None,
    trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
    updated_channels: set[str] | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
    cache_policy: Literal[None] = None,
) -> dict[str, PregelTask]: ...
# Overload for execution (``for_execution=True``): returns runnable
# PregelExecutableTask objects; store, checkpointer and manager are required
# keyword arguments here (each may still be None).
@overload
def prepare_next_tasks(
    checkpoint: Checkpoint,
    pending_writes: list[PendingWrite],
    processes: Mapping[str, PregelNode],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    *,
    for_execution: Literal[True],
    store: BaseStore | None,
    checkpointer: BaseCheckpointSaver | None,
    manager: None | ParentRunManager | AsyncParentRunManager,
    trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
    updated_channels: set[str] | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
    cache_policy: CachePolicy | None = None,
) -> dict[str, PregelExecutableTask]: ...
  328. def prepare_next_tasks(
  329. checkpoint: Checkpoint,
  330. pending_writes: list[PendingWrite],
  331. processes: Mapping[str, PregelNode],
  332. channels: Mapping[str, BaseChannel],
  333. managed: ManagedValueMapping,
  334. config: RunnableConfig,
  335. step: int,
  336. stop: int,
  337. *,
  338. for_execution: bool,
  339. store: BaseStore | None = None,
  340. checkpointer: BaseCheckpointSaver | None = None,
  341. manager: None | ParentRunManager | AsyncParentRunManager = None,
  342. trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
  343. updated_channels: set[str] | None = None,
  344. retry_policy: Sequence[RetryPolicy] = (),
  345. cache_policy: CachePolicy | None = None,
  346. ) -> dict[str, PregelTask] | dict[str, PregelExecutableTask]:
  347. """Prepare the set of tasks that will make up the next Pregel step.
  348. Args:
  349. checkpoint: The current checkpoint.
  350. pending_writes: The list of pending writes.
  351. processes: The mapping of process names to PregelNode instances.
  352. channels: The mapping of channel names to BaseChannel instances.
  353. managed: The mapping of managed value names to functions.
  354. config: The `Runnable` configuration.
  355. step: The current step.
  356. for_execution: Whether the tasks are being prepared for execution.
  357. store: An instance of BaseStore to make it available for usage within tasks.
  358. checkpointer: `Checkpointer` instance used for saving checkpoints.
  359. manager: The parent run manager to use for the tasks.
  360. trigger_to_nodes: Optional: Mapping of channel names to the set of nodes
  361. that are can be triggered by that channel.
  362. updated_channels: Optional. Set of channel names that have been updated during
  363. the previous step. Using in conjunction with trigger_to_nodes to speed
  364. up the process of determining which nodes should be triggered in the next
  365. step.
  366. Returns:
  367. A dictionary of tasks to be executed. The keys are the task ids and the values
  368. are the tasks themselves. This is the union of all PUSH tasks (Sends)
  369. and PULL tasks (nodes triggered by edges).
  370. """
  371. input_cache: dict[INPUT_CACHE_KEY_TYPE, Any] = {}
  372. checkpoint_id_bytes = binascii.unhexlify(checkpoint["id"].replace("-", ""))
  373. null_version = checkpoint_null_version(checkpoint)
  374. tasks: list[PregelTask | PregelExecutableTask] = []
  375. # Consume pending tasks
  376. tasks_channel = cast(Topic[Send] | None, channels.get(TASKS))
  377. if tasks_channel and tasks_channel.is_available():
  378. for idx, _ in enumerate(tasks_channel.get()):
  379. if task := prepare_single_task(
  380. (PUSH, idx),
  381. None,
  382. checkpoint=checkpoint,
  383. checkpoint_id_bytes=checkpoint_id_bytes,
  384. checkpoint_null_version=null_version,
  385. pending_writes=pending_writes,
  386. processes=processes,
  387. channels=channels,
  388. managed=managed,
  389. config=config,
  390. step=step,
  391. stop=stop,
  392. for_execution=for_execution,
  393. store=store,
  394. checkpointer=checkpointer,
  395. manager=manager,
  396. input_cache=input_cache,
  397. cache_policy=cache_policy,
  398. retry_policy=retry_policy,
  399. ):
  400. tasks.append(task)
  401. # This section is an optimization that allows which nodes will be active
  402. # during the next step.
  403. # When there's information about:
  404. # 1. Which channels were updated in the previous step
  405. # 2. Which nodes are triggered by which channels
  406. # Then we can determine which nodes should be triggered in the next step
  407. # without having to cycle through all nodes.
  408. if updated_channels and trigger_to_nodes:
  409. triggered_nodes: set[str] = set()
  410. # Get all nodes that have triggers associated with an updated channel
  411. for channel in updated_channels:
  412. if node_ids := trigger_to_nodes.get(channel):
  413. triggered_nodes.update(node_ids)
  414. # Sort the nodes to ensure deterministic order
  415. candidate_nodes: Iterable[str] = sorted(triggered_nodes)
  416. elif not checkpoint["channel_versions"]:
  417. candidate_nodes = ()
  418. else:
  419. candidate_nodes = processes.keys()
  420. # Check if any processes should be run in next step
  421. # If so, prepare the values to be passed to them
  422. for name in candidate_nodes:
  423. if task := prepare_single_task(
  424. (PULL, name),
  425. None,
  426. checkpoint=checkpoint,
  427. checkpoint_id_bytes=checkpoint_id_bytes,
  428. checkpoint_null_version=null_version,
  429. pending_writes=pending_writes,
  430. processes=processes,
  431. channels=channels,
  432. managed=managed,
  433. config=config,
  434. step=step,
  435. stop=stop,
  436. for_execution=for_execution,
  437. store=store,
  438. checkpointer=checkpointer,
  439. manager=manager,
  440. input_cache=input_cache,
  441. cache_policy=cache_policy,
  442. retry_policy=retry_policy,
  443. ):
  444. tasks.append(task)
  445. return {t.id: t for t in tasks}
# Triggers tuple assigned to every functional-API push task
# (see prepare_push_task_functional).
PUSH_TRIGGER = (PUSH,)
  447. class _TaskIDFn(Protocol):
  448. def __call__(self, namespace: bytes, *parts: str | bytes) -> str:
  449. pass
  450. def prepare_single_task(
  451. task_path: tuple[Any, ...],
  452. task_id_checksum: str | None,
  453. *,
  454. checkpoint: Checkpoint,
  455. checkpoint_id_bytes: bytes,
  456. checkpoint_null_version: V | None,
  457. pending_writes: list[PendingWrite],
  458. processes: Mapping[str, PregelNode],
  459. channels: Mapping[str, BaseChannel],
  460. managed: ManagedValueMapping,
  461. config: RunnableConfig,
  462. step: int,
  463. stop: int,
  464. for_execution: bool,
  465. store: BaseStore | None = None,
  466. checkpointer: BaseCheckpointSaver | None = None,
  467. manager: None | ParentRunManager | AsyncParentRunManager = None,
  468. input_cache: dict[INPUT_CACHE_KEY_TYPE, Any] | None = None,
  469. cache_policy: CachePolicy | None = None,
  470. retry_policy: Sequence[RetryPolicy] = (),
  471. ) -> None | PregelTask | PregelExecutableTask:
  472. """Prepares a single task for the next Pregel step, given a task path, which
  473. uniquely identifies a PUSH or PULL task within the graph."""
  474. configurable = config.get(CONF, {})
  475. parent_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
  476. task_id_func = _xxhash_str if checkpoint["v"] > 1 else _uuid5_str
  477. if task_path[0] == PUSH and isinstance(task_path[-1], Call):
  478. return prepare_push_task_functional(
  479. cast(tuple[str, tuple, int, str, Call], task_path),
  480. task_id_checksum,
  481. checkpoint=checkpoint,
  482. checkpoint_id_bytes=checkpoint_id_bytes,
  483. pending_writes=pending_writes,
  484. channels=channels,
  485. managed=managed,
  486. config=config,
  487. step=step,
  488. stop=stop,
  489. for_execution=for_execution,
  490. store=store,
  491. checkpointer=checkpointer,
  492. manager=manager,
  493. cache_policy=cache_policy,
  494. retry_policy=retry_policy,
  495. parent_ns=parent_ns,
  496. task_id_func=task_id_func,
  497. )
  498. elif task_path[0] == PUSH:
  499. return prepare_push_task_send(
  500. cast(tuple[str, tuple], task_path),
  501. task_id_checksum,
  502. checkpoint=checkpoint,
  503. checkpoint_id_bytes=checkpoint_id_bytes,
  504. pending_writes=pending_writes,
  505. channels=channels,
  506. managed=managed,
  507. config=config,
  508. step=step,
  509. processes=processes,
  510. stop=stop,
  511. for_execution=for_execution,
  512. store=store,
  513. checkpointer=checkpointer,
  514. manager=manager,
  515. cache_policy=cache_policy,
  516. retry_policy=retry_policy,
  517. parent_ns=parent_ns,
  518. task_id_func=task_id_func,
  519. )
  520. elif task_path[0] == PULL:
  521. # (PULL, node name)
  522. name = cast(str, task_path[1])
  523. if name not in processes:
  524. return
  525. proc = processes[name]
  526. if checkpoint_null_version is None:
  527. return
  528. # If any of the channels read by this process were updated
  529. if _triggers(
  530. channels,
  531. checkpoint["channel_versions"],
  532. checkpoint["versions_seen"].get(name),
  533. checkpoint_null_version,
  534. proc,
  535. ):
  536. triggers = tuple(sorted(proc.triggers))
  537. # create task id
  538. checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
  539. task_id = task_id_func(
  540. checkpoint_id_bytes,
  541. checkpoint_ns,
  542. str(step),
  543. name,
  544. PULL,
  545. *triggers,
  546. )
  547. task_checkpoint_ns = f"{checkpoint_ns}{NS_END}{task_id}"
  548. # create scratchpad
  549. scratchpad = _scratchpad(
  550. config[CONF].get(CONFIG_KEY_SCRATCHPAD),
  551. pending_writes,
  552. task_id,
  553. xxh3_128_hexdigest(task_checkpoint_ns.encode()),
  554. config[CONF].get(CONFIG_KEY_RESUME_MAP),
  555. step,
  556. stop,
  557. )
  558. # create task input
  559. try:
  560. val = _proc_input(
  561. proc,
  562. managed,
  563. channels,
  564. for_execution=for_execution,
  565. input_cache=input_cache,
  566. scratchpad=scratchpad,
  567. )
  568. if val is MISSING:
  569. return
  570. except Exception as exc:
  571. if SUPPORTS_EXC_NOTES:
  572. exc.add_note(
  573. f"Before task with name '{name}' and path '{task_path[:3]}'"
  574. )
  575. raise
  576. metadata = {
  577. "langgraph_step": step,
  578. "langgraph_node": name,
  579. "langgraph_triggers": triggers,
  580. "langgraph_path": task_path[:3],
  581. "langgraph_checkpoint_ns": task_checkpoint_ns,
  582. }
  583. if task_id_checksum is not None:
  584. assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"
  585. if for_execution:
  586. if node := proc.node:
  587. if proc.metadata:
  588. metadata.update(proc.metadata)
  589. writes: deque[tuple[str, Any]] = deque()
  590. cache_policy = proc.cache_policy or cache_policy
  591. if cache_policy:
  592. args_key = cache_policy.key_func(val)
  593. cache_key = CacheKey(
  594. (
  595. CACHE_NS_WRITES,
  596. (identifier(proc) or "__dynamic__"),
  597. name,
  598. ),
  599. xxh3_128_hexdigest(
  600. (
  601. args_key.encode()
  602. if isinstance(args_key, str)
  603. else args_key
  604. ),
  605. ),
  606. cache_policy.ttl,
  607. )
  608. else:
  609. cache_key = None
  610. runtime = cast(
  611. Runtime, configurable.get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME)
  612. )
  613. runtime = runtime.override(
  614. previous=checkpoint["channel_values"].get(PREVIOUS, None),
  615. store=store,
  616. )
  617. additional_config = {
  618. "metadata": metadata,
  619. "tags": proc.tags,
  620. }
  621. return PregelExecutableTask(
  622. name,
  623. val,
  624. node,
  625. writes,
  626. patch_config(
  627. merge_configs(
  628. config, cast(RunnableConfig, additional_config)
  629. ),
  630. run_name=name,
  631. callbacks=(
  632. manager.get_child(f"graph:step:{step}")
  633. if manager
  634. else None
  635. ),
  636. configurable={
  637. CONFIG_KEY_TASK_ID: task_id,
  638. # deque.extend is thread-safe
  639. CONFIG_KEY_SEND: writes.extend,
  640. CONFIG_KEY_READ: partial(
  641. local_read,
  642. scratchpad,
  643. channels,
  644. managed,
  645. PregelTaskWrites(
  646. task_path[:3],
  647. name,
  648. writes,
  649. triggers,
  650. ),
  651. ),
  652. CONFIG_KEY_CHECKPOINTER: (
  653. checkpointer
  654. or configurable.get(CONFIG_KEY_CHECKPOINTER)
  655. ),
  656. CONFIG_KEY_CHECKPOINT_MAP: {
  657. **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
  658. parent_ns: checkpoint["id"],
  659. },
  660. CONFIG_KEY_CHECKPOINT_ID: None,
  661. CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
  662. CONFIG_KEY_SCRATCHPAD: scratchpad,
  663. CONFIG_KEY_RUNTIME: runtime,
  664. },
  665. ),
  666. triggers,
  667. proc.retry_policy or retry_policy,
  668. cache_key,
  669. task_id,
  670. task_path[:3],
  671. writers=proc.flat_writers,
  672. subgraphs=proc.subgraphs,
  673. )
  674. else:
  675. return PregelTask(task_id, name, task_path[:3])
  676. def prepare_push_task_functional(
  677. task_path: tuple[str, tuple, int, str, Call],
  678. # (PUSH, parent task path, idx of PUSH write, id of parent task, Call)
  679. task_id_checksum: str | None,
  680. *,
  681. checkpoint: Checkpoint,
  682. checkpoint_id_bytes: bytes,
  683. pending_writes: list[PendingWrite],
  684. channels: Mapping[str, BaseChannel],
  685. managed: ManagedValueMapping,
  686. config: RunnableConfig,
  687. step: int,
  688. stop: int,
  689. for_execution: bool,
  690. store: BaseStore | None = None,
  691. checkpointer: BaseCheckpointSaver | None = None,
  692. manager: None | ParentRunManager | AsyncParentRunManager = None,
  693. cache_policy: CachePolicy | None = None,
  694. retry_policy: Sequence[RetryPolicy] = (),
  695. parent_ns: str,
  696. # namespace: bytes, *parts: str | bytes
  697. task_id_func: _TaskIDFn,
  698. ) -> PregelTask | PregelExecutableTask:
  699. """Prepare a push task with an attached caller. Used for the functional API."""
  700. configurable = config.get(CONF, {})
  701. call = task_path[-1]
  702. proc_ = get_runnable_for_task(call.func)
  703. name = proc_.name
  704. if name is None:
  705. raise ValueError("`call` functions must have a `__name__` attribute")
  706. # create task id
  707. triggers: Sequence[str] = PUSH_TRIGGER
  708. checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
  709. task_id = task_id_func(
  710. checkpoint_id_bytes,
  711. checkpoint_ns,
  712. str(step),
  713. name,
  714. PUSH,
  715. task_path_str(task_path[1]),
  716. str(task_path[2]),
  717. )
  718. task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
  719. # we append True to the task path to indicate that a call is being
  720. # made, so we should not return interrupts from this task (responsibility lies with the parent)
  721. in_progress_task_path = (*task_path[:3], True)
  722. metadata = {
  723. "langgraph_step": step,
  724. "langgraph_node": name,
  725. "langgraph_triggers": triggers,
  726. "langgraph_path": in_progress_task_path,
  727. "langgraph_checkpoint_ns": task_checkpoint_ns,
  728. }
  729. if task_id_checksum is not None:
  730. assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"
  731. if for_execution:
  732. writes: deque[tuple[str, Any]] = deque()
  733. cache_policy = call.cache_policy or cache_policy
  734. if cache_policy:
  735. args_key = cache_policy.key_func(*call.input[0], **call.input[1])
  736. cache_key: CacheKey | None = CacheKey(
  737. (
  738. CACHE_NS_WRITES,
  739. (identifier(call.func) or "__dynamic__"),
  740. ),
  741. xxh3_128_hexdigest(
  742. args_key.encode() if isinstance(args_key, str) else args_key,
  743. ),
  744. cache_policy.ttl,
  745. )
  746. else:
  747. cache_key = None
  748. scratchpad = _scratchpad(
  749. configurable.get(CONFIG_KEY_SCRATCHPAD),
  750. pending_writes,
  751. task_id,
  752. xxh3_128_hexdigest(task_checkpoint_ns.encode()),
  753. configurable.get(CONFIG_KEY_RESUME_MAP),
  754. step,
  755. stop,
  756. )
  757. runtime = cast(Runtime, configurable.get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME))
  758. runtime = runtime.override(store=store)
  759. return PregelExecutableTask(
  760. name,
  761. call.input,
  762. proc_,
  763. writes,
  764. patch_config(
  765. merge_configs(config, {"metadata": metadata}),
  766. run_name=name,
  767. callbacks=call.callbacks
  768. or (manager.get_child(f"graph:step:{step}") if manager else None),
  769. configurable={
  770. CONFIG_KEY_TASK_ID: task_id,
  771. # deque.extend is thread-safe
  772. CONFIG_KEY_SEND: writes.extend,
  773. CONFIG_KEY_READ: partial(
  774. local_read,
  775. scratchpad,
  776. channels,
  777. managed,
  778. PregelTaskWrites(in_progress_task_path, name, writes, triggers),
  779. ),
  780. CONFIG_KEY_CHECKPOINTER: (
  781. checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER)
  782. ),
  783. CONFIG_KEY_CHECKPOINT_MAP: {
  784. **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
  785. parent_ns: checkpoint["id"],
  786. },
  787. CONFIG_KEY_CHECKPOINT_ID: None,
  788. CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
  789. CONFIG_KEY_SCRATCHPAD: scratchpad,
  790. CONFIG_KEY_RUNTIME: runtime,
  791. },
  792. ),
  793. triggers,
  794. call.retry_policy or retry_policy,
  795. cache_key,
  796. task_id,
  797. in_progress_task_path,
  798. )
  799. else:
  800. return PregelTask(task_id, name, in_progress_task_path)
  801. def prepare_push_task_send(
  802. task_path: tuple[str, tuple],
  803. # (PUSH, parent task path)
  804. task_id_checksum: str | None,
  805. *,
  806. checkpoint: Checkpoint,
  807. checkpoint_id_bytes: bytes,
  808. pending_writes: list[PendingWrite],
  809. channels: Mapping[str, BaseChannel],
  810. managed: ManagedValueMapping,
  811. config: RunnableConfig,
  812. step: int,
  813. stop: int,
  814. for_execution: bool,
  815. store: BaseStore | None = None,
  816. checkpointer: BaseCheckpointSaver | None = None,
  817. manager: None | ParentRunManager | AsyncParentRunManager = None,
  818. cache_policy: CachePolicy | None = None,
  819. retry_policy: Sequence[RetryPolicy] = (),
  820. parent_ns: str,
  821. task_id_func: _TaskIDFn,
  822. processes: Mapping[str, PregelNode],
  823. ) -> PregelTask | PregelExecutableTask | None:
  824. if len(task_path) == 2:
  825. # SEND tasks, executed in superstep n+1
  826. # (PUSH, idx of pending send)
  827. idx = cast(int, task_path[1])
  828. if not channels[TASKS].is_available():
  829. return
  830. sends: Sequence[Send] = channels[TASKS].get()
  831. if idx < 0 or idx >= len(sends):
  832. return
  833. packet = sends[idx]
  834. if not isinstance(packet, Send):
  835. logger.warning(
  836. f"Ignoring invalid packet type {type(packet)} in pending sends"
  837. )
  838. return
  839. if packet.node not in processes:
  840. logger.warning(f"Ignoring unknown node name {packet.node} in pending sends")
  841. return
  842. # find process
  843. proc = processes[packet.node]
  844. proc_node = proc.node
  845. if proc_node is None:
  846. return
  847. # create task id
  848. triggers = PUSH_TRIGGER
  849. checkpoint_ns = (
  850. f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
  851. )
  852. task_id = task_id_func(
  853. checkpoint_id_bytes,
  854. checkpoint_ns,
  855. str(step),
  856. packet.node,
  857. PUSH,
  858. str(idx),
  859. )
  860. else:
  861. logger.warning(f"Ignoring invalid PUSH task path {task_path}")
  862. return
  863. configurable = config.get(CONF, {})
  864. task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
  865. # we append False to the task path to indicate that a call is not being made
  866. # so we should return interrupts from this task
  867. translated_task_path = (*task_path[:3], False)
  868. metadata = {
  869. "langgraph_step": step,
  870. "langgraph_node": packet.node,
  871. "langgraph_triggers": triggers,
  872. "langgraph_path": translated_task_path,
  873. "langgraph_checkpoint_ns": task_checkpoint_ns,
  874. }
  875. if task_id_checksum is not None:
  876. assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"
  877. if for_execution:
  878. if proc.metadata:
  879. metadata.update(proc.metadata)
  880. writes: deque[tuple[str, Any]] = deque()
  881. cache_policy = proc.cache_policy or cache_policy
  882. if cache_policy:
  883. args_key = cache_policy.key_func(packet.arg)
  884. cache_key = CacheKey(
  885. (
  886. CACHE_NS_WRITES,
  887. (identifier(proc) or "__dynamic__"),
  888. packet.node,
  889. ),
  890. xxh3_128_hexdigest(
  891. args_key.encode() if isinstance(args_key, str) else args_key,
  892. ),
  893. cache_policy.ttl,
  894. )
  895. else:
  896. cache_key = None
  897. scratchpad = _scratchpad(
  898. config[CONF].get(CONFIG_KEY_SCRATCHPAD),
  899. pending_writes,
  900. task_id,
  901. xxh3_128_hexdigest(task_checkpoint_ns.encode()),
  902. config[CONF].get(CONFIG_KEY_RESUME_MAP),
  903. step,
  904. stop,
  905. )
  906. runtime = cast(Runtime, configurable.get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME))
  907. runtime = runtime.override(
  908. store=store, previous=checkpoint["channel_values"].get(PREVIOUS, None)
  909. )
  910. additional_config: RunnableConfig = {
  911. "metadata": metadata,
  912. "tags": proc.tags,
  913. }
  914. return PregelExecutableTask(
  915. packet.node,
  916. packet.arg,
  917. proc_node,
  918. writes,
  919. patch_config(
  920. merge_configs(config, additional_config),
  921. run_name=packet.node,
  922. callbacks=(
  923. manager.get_child(f"graph:step:{step}") if manager else None
  924. ),
  925. configurable={
  926. CONFIG_KEY_TASK_ID: task_id,
  927. # deque.extend is thread-safe
  928. CONFIG_KEY_SEND: writes.extend,
  929. CONFIG_KEY_READ: partial(
  930. local_read,
  931. scratchpad,
  932. channels,
  933. managed,
  934. PregelTaskWrites(
  935. translated_task_path, packet.node, writes, triggers
  936. ),
  937. ),
  938. CONFIG_KEY_CHECKPOINTER: (
  939. checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER)
  940. ),
  941. CONFIG_KEY_CHECKPOINT_MAP: {
  942. **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
  943. parent_ns: checkpoint["id"],
  944. },
  945. CONFIG_KEY_CHECKPOINT_ID: None,
  946. CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
  947. CONFIG_KEY_SCRATCHPAD: scratchpad,
  948. CONFIG_KEY_RUNTIME: runtime,
  949. },
  950. ),
  951. triggers,
  952. proc.retry_policy or retry_policy,
  953. cache_key,
  954. task_id,
  955. translated_task_path,
  956. writers=proc.flat_writers,
  957. subgraphs=proc.subgraphs,
  958. )
  959. else:
  960. return PregelTask(task_id, packet.node, translated_task_path)
  961. def checkpoint_null_version(
  962. checkpoint: Checkpoint,
  963. ) -> V | None:
  964. """Get the null version for the checkpoint, if available."""
  965. for version in checkpoint["channel_versions"].values():
  966. return type(version)()
  967. return None
  968. def _triggers(
  969. channels: Mapping[str, BaseChannel],
  970. versions: ChannelVersions,
  971. seen: ChannelVersions | None,
  972. null_version: V,
  973. proc: PregelNode,
  974. ) -> bool:
  975. if seen is None:
  976. for chan in proc.triggers:
  977. if channels[chan].is_available():
  978. return True
  979. else:
  980. for chan in proc.triggers:
  981. if channels[chan].is_available() and versions.get( # type: ignore[operator]
  982. chan, null_version
  983. ) > seen.get(chan, null_version):
  984. return True
  985. return False
  986. def _scratchpad(
  987. parent_scratchpad: PregelScratchpad | None,
  988. pending_writes: list[PendingWrite],
  989. task_id: str,
  990. namespace_hash: str,
  991. resume_map: dict[str, Any] | None,
  992. step: int,
  993. stop: int,
  994. ) -> PregelScratchpad:
  995. if len(pending_writes) > 0:
  996. # find global resume value
  997. for w in pending_writes:
  998. if w[0] == NULL_TASK_ID and w[1] == RESUME:
  999. null_resume_write = w
  1000. break
  1001. else:
  1002. # None cannot be used as a resume value, because it would be difficult to
  1003. # distinguish from missing when used over http
  1004. null_resume_write = None
  1005. # find task-specific resume value
  1006. for w in pending_writes:
  1007. if w[0] == task_id and w[1] == RESUME:
  1008. task_resume_write = w[2]
  1009. if not isinstance(task_resume_write, list):
  1010. task_resume_write = [task_resume_write]
  1011. break
  1012. else:
  1013. task_resume_write = []
  1014. del w
  1015. # find namespace and task-specific resume value
  1016. if resume_map and namespace_hash in resume_map:
  1017. mapped_resume_write = resume_map[namespace_hash]
  1018. task_resume_write.append(mapped_resume_write)
  1019. else:
  1020. null_resume_write = None
  1021. task_resume_write = []
  1022. def get_null_resume(consume: bool = False) -> Any:
  1023. if null_resume_write is None:
  1024. if parent_scratchpad is not None:
  1025. return parent_scratchpad.get_null_resume(consume)
  1026. return None
  1027. if consume:
  1028. try:
  1029. pending_writes.remove(null_resume_write)
  1030. return null_resume_write[2]
  1031. except ValueError:
  1032. return None
  1033. return null_resume_write[2]
  1034. # using itertools.count as an atomic counter (+= 1 is not thread-safe)
  1035. return PregelScratchpad(
  1036. step=step,
  1037. stop=stop,
  1038. # call
  1039. call_counter=LazyAtomicCounter(),
  1040. # interrupt
  1041. interrupt_counter=LazyAtomicCounter(),
  1042. resume=task_resume_write,
  1043. get_null_resume=get_null_resume,
  1044. # subgraph
  1045. subgraph_counter=LazyAtomicCounter(),
  1046. )
  1047. def _proc_input(
  1048. proc: PregelNode,
  1049. managed: ManagedValueMapping,
  1050. channels: Mapping[str, BaseChannel],
  1051. *,
  1052. for_execution: bool,
  1053. scratchpad: PregelScratchpad,
  1054. input_cache: dict[INPUT_CACHE_KEY_TYPE, Any] | None,
  1055. ) -> Any:
  1056. """Prepare input for a PULL task, based on the process's channels and triggers."""
  1057. # if in cache return shallow copy
  1058. if input_cache is not None and proc.input_cache_key in input_cache:
  1059. return copy(input_cache[proc.input_cache_key])
  1060. # If all trigger channels subscribed by this process are not empty
  1061. # then invoke the process with the values of all non-empty channels
  1062. if isinstance(proc.channels, list):
  1063. val: dict[str, Any] = {}
  1064. for chan in proc.channels:
  1065. if chan in channels:
  1066. if channels[chan].is_available():
  1067. val[chan] = channels[chan].get()
  1068. else:
  1069. val[chan] = managed[chan].get(scratchpad)
  1070. elif isinstance(proc.channels, str):
  1071. if proc.channels in channels:
  1072. if channels[proc.channels].is_available():
  1073. val = channels[proc.channels].get()
  1074. else:
  1075. return MISSING
  1076. else:
  1077. return MISSING
  1078. else:
  1079. raise RuntimeError(
  1080. f"Invalid channels type, expected list or dict, got {proc.channels}"
  1081. )
  1082. # If the process has a mapper, apply it to the value
  1083. if for_execution and proc.mapper is not None:
  1084. val = proc.mapper(val)
  1085. # Cache the input value
  1086. if input_cache is not None:
  1087. input_cache[proc.input_cache_key] = val
  1088. return val
  1089. def _uuid5_str(namespace: bytes, *parts: str | bytes) -> str:
  1090. """Generate a UUID from the SHA-1 hash of a namespace and str parts."""
  1091. sha = sha1(namespace, usedforsecurity=False)
  1092. sha.update(b"".join(p.encode() if isinstance(p, str) else p for p in parts))
  1093. hex = sha.hexdigest()
  1094. return f"{hex[:8]}-{hex[8:12]}-{hex[12:16]}-{hex[16:20]}-{hex[20:32]}"
  1095. def _xxhash_str(namespace: bytes, *parts: str | bytes) -> str:
  1096. """Generate a UUID from the XXH3 hash of a namespace and str parts."""
  1097. hex = xxh3_128_hexdigest(
  1098. namespace + b"".join(p.encode() if isinstance(p, str) else p for p in parts)
  1099. )
  1100. return f"{hex[:8]}-{hex[8:12]}-{hex[12:16]}-{hex[16:20]}-{hex[20:32]}"
  1101. def task_path_str(tup: str | int | tuple) -> str:
  1102. """Generate a string representation of the task path."""
  1103. return (
  1104. f"~{', '.join(task_path_str(x) for x in tup)}"
  1105. if isinstance(tup, (tuple, list))
  1106. else f"{tup:010d}"
  1107. if isinstance(tup, int)
  1108. else str(tup)
  1109. )
  1110. LAZY_ATOMIC_COUNTER_LOCK = threading.Lock()
  1111. class LazyAtomicCounter:
  1112. __slots__ = ("_counter",)
  1113. _counter: Callable[[], int] | None
  1114. def __init__(self) -> None:
  1115. self._counter = None
  1116. def __call__(self) -> int:
  1117. if self._counter is None:
  1118. with LAZY_ATOMIC_COUNTER_LOCK:
  1119. if self._counter is None:
  1120. self._counter = itertools.count(0).__next__
  1121. return self._counter()
  1122. def sanitize_untracked_values_in_send(
  1123. packet: Send, channels: Mapping[str, BaseChannel]
  1124. ) -> Send:
  1125. """Pop any values belonging to UntrackedValue channels in Send.arg for safe checkpointing.
  1126. Send is often called with state to be passed to the dest node, which may contain
  1127. UntrackedValues at the top level. Send is not typed and arg may be a nested dict."""
  1128. if not isinstance(packet.arg, dict):
  1129. # Command
  1130. return packet
  1131. # top level keys should be the channel names
  1132. sanitized_arg = {
  1133. k: v
  1134. for k, v in packet.arg.items()
  1135. if not isinstance(channels.get(k), UntrackedValue)
  1136. }
  1137. return Send(node=packet.node, arg=sanitized_arg)