# debug.py
  1. from __future__ import annotations
  2. from collections.abc import Iterable, Iterator, Mapping, Sequence
  3. from dataclasses import asdict
  4. from typing import Any
  5. from uuid import UUID
  6. from langchain_core.runnables import RunnableConfig
  7. from langgraph.checkpoint.base import CheckpointMetadata, PendingWrite
  8. from typing_extensions import TypedDict
  9. from langgraph._internal._config import patch_checkpoint_map
  10. from langgraph._internal._constants import (
  11. CONF,
  12. CONFIG_KEY_CHECKPOINT_NS,
  13. ERROR,
  14. INTERRUPT,
  15. NS_END,
  16. NS_SEP,
  17. RETURN,
  18. )
  19. from langgraph._internal._typing import MISSING
  20. from langgraph.channels.base import BaseChannel
  21. from langgraph.constants import TAG_HIDDEN
  22. from langgraph.pregel._io import read_channels
  23. from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot
# Public payload types exposed by this module.
__all__ = ("TaskPayload", "TaskResultPayload", "CheckpointTask", "CheckpointPayload")
class TaskPayload(TypedDict):
    """Payload of a "task" debug event (yielded by ``map_debug_tasks``)."""

    id: str  # id of the scheduled task
    name: str  # name of the task / node
    input: Any  # input passed to the task; shape depends on the graph
    triggers: list[str]  # copied from ``task.triggers`` — presumably triggering channels
class TaskResultPayload(TypedDict):
    """Payload of a "task_result" debug event (yielded by ``map_debug_task_results``)."""

    id: str  # id of the task that produced the result
    name: str  # name of the task / node
    error: str | None  # first ERROR write recorded for the task, if any
    interrupts: list[dict]  # interrupt values converted to dicts via ``dataclasses.asdict``
    result: dict[str, Any]  # channel writes folded by ``map_task_result_writes``
class CheckpointTask(TypedDict):
    """Task entry inside a ``CheckpointPayload``.

    NOTE(review): ``map_debug_checkpoint`` emits task dicts in three variants
    that omit some of these keys (e.g. ``result`` only when present, no
    ``interrupts`` on the error variant) — the declared shape is a superset.
    """

    id: str  # task id
    name: str  # task / node name
    error: str | None  # error recorded for the task, if any
    interrupts: list[dict]  # interrupts converted to dicts
    # Subgraph state, or a config signalling that subgraph checkpoints exist.
    state: StateSnapshot | RunnableConfig | None
class CheckpointPayload(TypedDict):
    """Payload of a "checkpoint" debug event (yielded by ``map_debug_checkpoint``)."""

    config: RunnableConfig | None  # config of this checkpoint (pregel keys removed)
    metadata: CheckpointMetadata  # checkpoint metadata, passed through unchanged
    values: dict[str, Any]  # current channel values (via ``read_channels``)
    next: list[str]  # names of the tasks scheduled next
    parent_config: RunnableConfig | None  # parent checkpoint config, if any
    tasks: list[CheckpointTask]  # per-task details (see CheckpointTask note)
# UUID namespace, presumably for deriving deterministic task ids (e.g. uuid5).
# NOTE(review): not referenced anywhere in this chunk — confirm usage elsewhere.
TASK_NAMESPACE = UUID("6ba7b831-9dad-11d1-80b4-00c04fd430c8")
  50. def map_debug_tasks(tasks: Iterable[PregelExecutableTask]) -> Iterator[TaskPayload]:
  51. """Produce "task" events for stream_mode=debug."""
  52. for task in tasks:
  53. if task.config is not None and TAG_HIDDEN in task.config.get("tags", []):
  54. continue
  55. yield {
  56. "id": task.id,
  57. "name": task.name,
  58. "input": task.input,
  59. "triggers": task.triggers,
  60. }
  61. def is_multiple_channel_write(value: Any) -> bool:
  62. """Return True if the payload already wraps multiple writes from the same channel."""
  63. return (
  64. isinstance(value, dict)
  65. and "$writes" in value
  66. and isinstance(value["$writes"], list)
  67. )
  68. def map_task_result_writes(writes: Sequence[tuple[str, Any]]) -> dict[str, Any]:
  69. """Folds task writes into a result dict and aggregates multiple writes to the same channel.
  70. If the channel contains a single write, we record the write in the result dict as `{channel: write}`
  71. If the channel contains multiple writes, we record the writes in the result dict as `{channel: {'$writes': [write1, write2, ...]}}`"""
  72. result: dict[str, Any] = {}
  73. for channel, value in writes:
  74. existing = result.get(channel)
  75. if existing is not None:
  76. channel_writes = (
  77. existing["$writes"]
  78. if is_multiple_channel_write(existing)
  79. else [existing]
  80. )
  81. channel_writes.append(value)
  82. result[channel] = {"$writes": channel_writes}
  83. else:
  84. result[channel] = value
  85. return result
  86. def map_debug_task_results(
  87. task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]],
  88. stream_keys: str | Sequence[str],
  89. ) -> Iterator[TaskResultPayload]:
  90. """Produce "task_result" events for stream_mode=debug."""
  91. stream_channels_list = (
  92. [stream_keys] if isinstance(stream_keys, str) else stream_keys
  93. )
  94. task, writes = task_tup
  95. yield {
  96. "id": task.id,
  97. "name": task.name,
  98. "error": next((w[1] for w in writes if w[0] == ERROR), None),
  99. "result": map_task_result_writes(
  100. [w for w in writes if w[0] in stream_channels_list or w[0] == RETURN]
  101. ),
  102. "interrupts": [
  103. asdict(v)
  104. for w in writes
  105. if w[0] == INTERRUPT
  106. for v in (w[1] if isinstance(w[1], Sequence) else [w[1]])
  107. ],
  108. }
  109. def rm_pregel_keys(config: RunnableConfig | None) -> RunnableConfig | None:
  110. """Remove pregel-specific keys from the config."""
  111. if config is None:
  112. return config
  113. return {
  114. "configurable": {
  115. k: v
  116. for k, v in config.get("configurable", {}).items()
  117. if not k.startswith("__pregel_")
  118. }
  119. }
  120. def map_debug_checkpoint(
  121. config: RunnableConfig,
  122. channels: Mapping[str, BaseChannel],
  123. stream_channels: str | Sequence[str],
  124. metadata: CheckpointMetadata,
  125. tasks: Iterable[PregelExecutableTask],
  126. pending_writes: list[PendingWrite],
  127. parent_config: RunnableConfig | None,
  128. output_keys: str | Sequence[str],
  129. ) -> Iterator[CheckpointPayload]:
  130. """Produce "checkpoint" events for stream_mode=debug."""
  131. parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
  132. task_states: dict[str, RunnableConfig | StateSnapshot] = {}
  133. for task in tasks:
  134. if not task.subgraphs:
  135. continue
  136. # assemble checkpoint_ns for this task
  137. task_ns = f"{task.name}{NS_END}{task.id}"
  138. if parent_ns:
  139. task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
  140. # set config as signal that subgraph checkpoints exist
  141. task_states[task.id] = {
  142. CONF: {
  143. "thread_id": config[CONF]["thread_id"],
  144. CONFIG_KEY_CHECKPOINT_NS: task_ns,
  145. }
  146. }
  147. yield {
  148. "config": rm_pregel_keys(patch_checkpoint_map(config, metadata)),
  149. "parent_config": rm_pregel_keys(patch_checkpoint_map(parent_config, metadata)),
  150. "values": read_channels(channels, stream_channels),
  151. "metadata": metadata,
  152. "next": [t.name for t in tasks],
  153. "tasks": [
  154. {
  155. "id": t.id,
  156. "name": t.name,
  157. "error": t.error,
  158. "state": t.state,
  159. }
  160. if t.error
  161. else {
  162. "id": t.id,
  163. "name": t.name,
  164. "result": t.result,
  165. "interrupts": tuple(asdict(i) for i in t.interrupts),
  166. "state": t.state,
  167. }
  168. if t.result
  169. else {
  170. "id": t.id,
  171. "name": t.name,
  172. "interrupts": tuple(asdict(i) for i in t.interrupts),
  173. "state": t.state,
  174. }
  175. for t in tasks_w_writes(tasks, pending_writes, task_states, output_keys)
  176. ],
  177. }
  178. def tasks_w_writes(
  179. tasks: Iterable[PregelTask | PregelExecutableTask],
  180. pending_writes: list[PendingWrite] | None,
  181. states: dict[str, RunnableConfig | StateSnapshot] | None,
  182. output_keys: str | Sequence[str],
  183. ) -> tuple[PregelTask, ...]:
  184. """Apply writes / subgraph states to tasks to be returned in a StateSnapshot."""
  185. pending_writes = pending_writes or []
  186. out: list[PregelTask] = []
  187. for task in tasks:
  188. rtn = next(
  189. (
  190. val
  191. for tid, chan, val in pending_writes
  192. if tid == task.id and chan == RETURN
  193. ),
  194. MISSING,
  195. )
  196. task_error = next(
  197. (exc for tid, n, exc in pending_writes if tid == task.id and n == ERROR),
  198. None,
  199. )
  200. task_interrupts = tuple(
  201. v
  202. for tid, n, vv in pending_writes
  203. if tid == task.id and n == INTERRUPT
  204. for v in (vv if isinstance(vv, Sequence) else [vv])
  205. )
  206. task_writes = [
  207. (chan, val)
  208. for tid, chan, val in pending_writes
  209. if tid == task.id and chan not in (ERROR, INTERRUPT, RETURN)
  210. ]
  211. if rtn is not MISSING:
  212. task_result = rtn
  213. elif isinstance(output_keys, str):
  214. # unwrap single channel writes to just the write value
  215. filtered_writes = [
  216. (chan, val) for chan, val in task_writes if chan == output_keys
  217. ]
  218. mapped_writes = map_task_result_writes(filtered_writes)
  219. task_result = mapped_writes.get(str(output_keys)) if mapped_writes else None
  220. else:
  221. if isinstance(output_keys, str):
  222. output_keys = [output_keys]
  223. # map task result writes to the desired output channels
  224. # repeateed writes to the same channel are aggregated into: {'$writes': [write1, write2, ...]}
  225. filtered_writes = [
  226. (chan, val) for chan, val in task_writes if chan in output_keys
  227. ]
  228. mapped_writes = map_task_result_writes(filtered_writes)
  229. task_result = mapped_writes if filtered_writes else {}
  230. has_writes = rtn is not MISSING or any(
  231. w[0] == task.id and w[1] not in (ERROR, INTERRUPT) for w in pending_writes
  232. )
  233. out.append(
  234. PregelTask(
  235. task.id,
  236. task.name,
  237. task.path,
  238. task_error,
  239. task_interrupts,
  240. states.get(task.id) if states else None,
  241. task_result if has_writes else None,
  242. )
  243. )
  244. return tuple(out)
# ANSI SGR color codes used by get_colored_text.
# Values are "<style>;<fg>" pairs, e.g. "0;31" = normal red, "1;30" = bold black ("gray").
COLOR_MAPPING = {
    "black": "0;30",
    "red": "0;31",
    "green": "0;32",
    "yellow": "0;33",
    "blue": "0;34",
    "magenta": "0;35",
    "cyan": "0;36",
    "white": "0;37",
    "gray": "1;30",
}
  256. def get_colored_text(text: str, color: str) -> str:
  257. """Get colored text."""
  258. return f"\033[1;3{COLOR_MAPPING[color]}m{text}\033[0m"
  259. def get_bolded_text(text: str) -> str:
  260. """Get bolded text."""
  261. return f"\033[1m{text}\033[0m"