_io.py

from __future__ import annotations

from collections import Counter
from collections.abc import Iterator, Mapping, Sequence
from typing import Any, Literal

from langgraph._internal._constants import (
    ERROR,
    INTERRUPT,
    NULL_TASK_ID,
    RESUME,
    RETURN,
    TASKS,
)
from langgraph._internal._typing import EMPTY_SEQ, MISSING
from langgraph.channels.base import BaseChannel, EmptyChannelError
from langgraph.constants import START, TAG_HIDDEN
from langgraph.errors import InvalidUpdateError
from langgraph.pregel._log import logger
from langgraph.types import Command, PregelExecutableTask, Send


def read_channel(
    channels: Mapping[str, BaseChannel],
    chan: str,
    *,
    catch: bool = True,
) -> Any:
    try:
        return channels[chan].get()
    except EmptyChannelError:
        if catch:
            return None
        else:
            raise


def read_channels(
    channels: Mapping[str, BaseChannel],
    select: Sequence[str] | str,
    *,
    skip_empty: bool = True,
) -> dict[str, Any] | Any:
    if isinstance(select, str):
        return read_channel(channels, select)
    else:
        values: dict[str, Any] = {}
        for k in select:
            try:
                values[k] = read_channel(channels, k, catch=not skip_empty)
            except EmptyChannelError:
                pass
        return values
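
# Illustrative sketch (not part of this module): how read_channel / read_channels
# might be exercised. Assumes a LastValue channel from langgraph.channels, whose
# update() takes a sequence of writes.
#
#     from langgraph.channels import LastValue
#
#     channels = {"messages": LastValue(str), "summary": LastValue(str)}
#     channels["messages"].update(["hi"])
#     read_channel(channels, "messages")                 # -> "hi"
#     read_channel(channels, "summary")                  # -> None (empty, caught)
#     read_channels(channels, ["messages", "summary"])   # -> {"messages": "hi"}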


def map_command(cmd: Command) -> Iterator[tuple[str, str, Any]]:
    """Map a Command to a sequence of pending writes in the form (task_id, channel, value)."""
    if cmd.graph == Command.PARENT:
        raise InvalidUpdateError("There is no parent graph")
    if cmd.goto:
        if isinstance(cmd.goto, (tuple, list)):
            sends = cmd.goto
        else:
            sends = [cmd.goto]
        for send in sends:
            if isinstance(send, Send):
                yield (NULL_TASK_ID, TASKS, send)
            elif isinstance(send, str):
                yield (NULL_TASK_ID, f"branch:to:{send}", START)
            else:
                raise TypeError(
                    f"In Command.goto, expected Send/str, got {type(send).__name__}"
                )
    if cmd.resume is not None:
        yield (NULL_TASK_ID, RESUME, cmd.resume)
    if cmd.update:
        for k, v in cmd._update_as_tuples():
            yield (NULL_TASK_ID, k, v)
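
# Illustrative sketch (not part of this module): the (task_id, channel, value)
# writes a Command expands into, using only the public Command/Send types.
#
#     cmd = Command(goto=[Send("worker", {"x": 1}), "reviewer"], update={"count": 1})
#     list(map_command(cmd))
#     # -> [(NULL_TASK_ID, TASKS, Send("worker", {"x": 1})),
#     #     (NULL_TASK_ID, "branch:to:reviewer", START),
#     #     (NULL_TASK_ID, "count", 1)]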


def map_input(
    input_channels: str | Sequence[str],
    chunk: dict[str, Any] | Any | None,
) -> Iterator[tuple[str, Any]]:
    """Map input chunk to a sequence of pending writes in the form (channel, value)."""
    if chunk is None:
        return
    elif isinstance(input_channels, str):
        yield (input_channels, chunk)
    else:
        if not isinstance(chunk, dict):
            raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}")
        for k in chunk:
            if k in input_channels:
                yield (k, chunk[k])
            else:
                logger.warning(f"Input channel {k} not found in {input_channels}")
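
# Illustrative sketch (not part of this module): a single input channel passes
# the chunk through unchanged; a sequence of channels filters a dict chunk to
# known keys and logs a warning for unknown ones.
#
#     list(map_input("input", {"messages": []}))     # -> [("input", {"messages": []})]
#     list(map_input(["a", "b"], {"a": 1, "c": 3}))  # -> [("a", 1)], warns about "c"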


def map_output_values(
    output_channels: str | Sequence[str],
    pending_writes: Literal[True] | Sequence[tuple[str, Any]],
    channels: Mapping[str, BaseChannel],
) -> Iterator[dict[str, Any] | Any]:
    """Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
    if isinstance(output_channels, str):
        if pending_writes is True or any(
            chan == output_channels for chan, _ in pending_writes
        ):
            yield read_channel(channels, output_channels)
    else:
        if pending_writes is True or {
            c for c, _ in pending_writes if c in output_channels
        }:
            yield read_channels(channels, output_channels)
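
# Illustrative sketch (not part of this module): values are read back from the
# channels themselves; pending_writes only gates whether anything is emitted
# (True forces emission). Channels as in the read_channels sketch above.
#
#     list(map_output_values("messages", [("messages", "hi")], channels))
#     # -> ["hi"]
#     list(map_output_values(["messages"], [("other", 1)], channels))
#     # -> []  (no write touched an output channel)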


def map_output_updates(
    output_channels: str | Sequence[str],
    tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
    cached: bool = False,
) -> Iterator[dict[str, Any | dict[str, Any]]]:
    """Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
    output_tasks = [
        (t, ww)
        for t, ww in tasks
        if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
        and ww[0][0] != ERROR
        and ww[0][0] != INTERRUPT
    ]
    if not output_tasks:
        return
    updated: list[tuple[str, Any]] = []
    for task, writes in output_tasks:
        rtn = next((value for chan, value in writes if chan == RETURN), MISSING)
        if rtn is not MISSING:
            updated.append((task.name, rtn))
        elif isinstance(output_channels, str):
            updated.extend(
                (task.name, value) for chan, value in writes if chan == output_channels
            )
        elif any(chan in output_channels for chan, _ in writes):
            counts = Counter(chan for chan, _ in writes)
            if any(counts[chan] > 1 for chan in output_channels):
                updated.extend(
                    (
                        task.name,
                        {chan: value},
                    )
                    for chan, value in writes
                    if chan in output_channels
                )
            else:
                updated.append(
                    (
                        task.name,
                        {
                            chan: value
                            for chan, value in writes
                            if chan in output_channels
                        },
                    )
                )
    grouped: dict[str, Any] = {t.name: [] for t, _ in output_tasks}
    for node, value in updated:
        grouped[node].append(value)
    for node, value in grouped.items():
        if len(value) == 0:
            grouped[node] = None
        if len(value) == 1:
            grouped[node] = value[0]
    if cached:
        grouped["__metadata__"] = {"cached": cached}
    yield grouped
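
# Illustrative sketch (not part of this module): shape of the per-task
# "updates" chunk. SimpleNamespace stands in for PregelExecutableTask here,
# exposing only the .name/.config attributes this function reads.
#
#     from types import SimpleNamespace
#
#     task = SimpleNamespace(name="agent", config={})
#     writes = [("messages", ["hi"]), ("other", 1)]
#     next(map_output_updates(["messages"], [(task, writes)]))
#     # -> {"agent": {"messages": ["hi"]}}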