_branch.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. from __future__ import annotations
  2. from collections.abc import Awaitable, Callable, Hashable, Sequence
  3. from inspect import (
  4. isfunction,
  5. ismethod,
  6. signature,
  7. )
  8. from itertools import zip_longest
  9. from types import FunctionType
  10. from typing import (
  11. Any,
  12. Literal,
  13. NamedTuple,
  14. cast,
  15. get_args,
  16. get_origin,
  17. get_type_hints,
  18. )
  19. from langchain_core.runnables import (
  20. Runnable,
  21. RunnableConfig,
  22. RunnableLambda,
  23. )
  24. from langgraph._internal._runnable import (
  25. RunnableCallable,
  26. )
  27. from langgraph.constants import END, START
  28. from langgraph.errors import InvalidUpdateError
  29. from langgraph.pregel._write import PASSTHROUGH, ChannelWrite, ChannelWriteEntry
  30. from langgraph.types import Send
  31. _Writer = Callable[
  32. [Sequence[str | Send], bool],
  33. Sequence[ChannelWriteEntry | Send],
  34. ]
  35. def _get_branch_path_input_schema(
  36. path: Callable[..., Hashable | Sequence[Hashable]]
  37. | Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
  38. | Runnable[Any, Hashable | Sequence[Hashable]],
  39. ) -> type[Any] | None:
  40. input = None
  41. # detect input schema annotation in the branch callable
  42. try:
  43. callable_: (
  44. Callable[..., Hashable | Sequence[Hashable]]
  45. | Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
  46. | None
  47. ) = None
  48. if isinstance(path, (RunnableCallable, RunnableLambda)):
  49. if isfunction(path.func) or ismethod(path.func):
  50. callable_ = path.func
  51. elif (callable_method := getattr(path.func, "__call__", None)) and ismethod(
  52. callable_method
  53. ):
  54. callable_ = callable_method
  55. elif isfunction(path.afunc) or ismethod(path.afunc):
  56. callable_ = path.afunc
  57. elif (
  58. callable_method := getattr(path.afunc, "__call__", None)
  59. ) and ismethod(callable_method):
  60. callable_ = callable_method
  61. elif callable(path):
  62. callable_ = path
  63. if callable_ is not None and (hints := get_type_hints(callable_)):
  64. first_parameter_name = next(
  65. iter(signature(cast(FunctionType, callable_)).parameters.keys())
  66. )
  67. if input_hint := hints.get(first_parameter_name):
  68. if isinstance(input_hint, type) and get_type_hints(input_hint):
  69. input = input_hint
  70. except (TypeError, StopIteration):
  71. pass
  72. return input
  73. class BranchSpec(NamedTuple):
  74. path: Runnable[Any, Hashable | list[Hashable]]
  75. ends: dict[Hashable, str] | None
  76. input_schema: type[Any] | None = None
  77. @classmethod
  78. def from_path(
  79. cls,
  80. path: Runnable[Any, Hashable | list[Hashable]],
  81. path_map: dict[Hashable, str] | list[str] | None,
  82. infer_schema: bool = False,
  83. ) -> BranchSpec:
  84. # coerce path_map to a dictionary
  85. path_map_: dict[Hashable, str] | None = None
  86. try:
  87. if isinstance(path_map, dict):
  88. path_map_ = path_map.copy()
  89. elif isinstance(path_map, list):
  90. path_map_ = {name: name for name in path_map}
  91. else:
  92. # find func
  93. func: Callable | None = None
  94. if isinstance(path, (RunnableCallable, RunnableLambda)):
  95. func = path.func or path.afunc
  96. if func is not None:
  97. # find callable method
  98. if (cal := getattr(path, "__call__", None)) and ismethod(cal):
  99. func = cal
  100. # get the return type
  101. if rtn_type := get_type_hints(func).get("return"):
  102. if get_origin(rtn_type) is Literal:
  103. path_map_ = {name: name for name in get_args(rtn_type)}
  104. except Exception:
  105. pass
  106. # infer input schema
  107. input_schema = _get_branch_path_input_schema(path) if infer_schema else None
  108. # create branch
  109. return cls(path=path, ends=path_map_, input_schema=input_schema)
  110. def run(
  111. self,
  112. writer: _Writer,
  113. reader: Callable[[RunnableConfig], Any] | None = None,
  114. ) -> RunnableCallable:
  115. return ChannelWrite.register_writer(
  116. RunnableCallable(
  117. func=self._route,
  118. afunc=self._aroute,
  119. writer=writer,
  120. reader=reader,
  121. name=None,
  122. trace=False,
  123. ),
  124. list(
  125. zip_longest(
  126. writer([e for e in self.ends.values()], True),
  127. [str(la) for la, e in self.ends.items()],
  128. )
  129. )
  130. if self.ends
  131. else None,
  132. )
  133. def _route(
  134. self,
  135. input: Any,
  136. config: RunnableConfig,
  137. *,
  138. reader: Callable[[RunnableConfig], Any] | None,
  139. writer: _Writer,
  140. ) -> Runnable:
  141. if reader:
  142. value = reader(config)
  143. # passthrough additional keys from node to branch
  144. # only doable when using dict states
  145. if (
  146. isinstance(value, dict)
  147. and isinstance(input, dict)
  148. and self.input_schema is None
  149. ):
  150. value = {**input, **value}
  151. else:
  152. value = input
  153. result = self.path.invoke(value, config)
  154. return self._finish(writer, input, result, config)
  155. async def _aroute(
  156. self,
  157. input: Any,
  158. config: RunnableConfig,
  159. *,
  160. reader: Callable[[RunnableConfig], Any] | None,
  161. writer: _Writer,
  162. ) -> Runnable:
  163. if reader:
  164. value = reader(config)
  165. # passthrough additional keys from node to branch
  166. # only doable when using dict states
  167. if (
  168. isinstance(value, dict)
  169. and isinstance(input, dict)
  170. and self.input_schema is None
  171. ):
  172. value = {**input, **value}
  173. else:
  174. value = input
  175. result = await self.path.ainvoke(value, config)
  176. return self._finish(writer, input, result, config)
  177. def _finish(
  178. self,
  179. writer: _Writer,
  180. input: Any,
  181. result: Any,
  182. config: RunnableConfig,
  183. ) -> Runnable | Any:
  184. if not isinstance(result, (list, tuple)):
  185. result = [result]
  186. if self.ends:
  187. destinations: Sequence[Send | str] = [
  188. r if isinstance(r, Send) else self.ends[r] for r in result
  189. ]
  190. else:
  191. destinations = cast(Sequence[Send | str], result)
  192. if any(dest is None or dest == START for dest in destinations):
  193. raise ValueError("Branch did not return a valid destination")
  194. if any(p.node == END for p in destinations if isinstance(p, Send)):
  195. raise InvalidUpdateError("Cannot send a packet to the END node")
  196. entries = writer(destinations, False)
  197. if not entries:
  198. return input
  199. else:
  200. need_passthrough = False
  201. for e in entries:
  202. if isinstance(e, ChannelWriteEntry):
  203. if e.value is PASSTHROUGH:
  204. need_passthrough = True
  205. break
  206. if need_passthrough:
  207. return ChannelWrite(entries)
  208. else:
  209. ChannelWrite.do_write(config, entries)
  210. return input