| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- from __future__ import annotations
- from collections.abc import Awaitable, Callable, Hashable, Sequence
- from inspect import (
- isfunction,
- ismethod,
- signature,
- )
- from itertools import zip_longest
- from types import FunctionType
- from typing import (
- Any,
- Literal,
- NamedTuple,
- cast,
- get_args,
- get_origin,
- get_type_hints,
- )
- from langchain_core.runnables import (
- Runnable,
- RunnableConfig,
- RunnableLambda,
- )
- from langgraph._internal._runnable import (
- RunnableCallable,
- )
- from langgraph.constants import END, START
- from langgraph.errors import InvalidUpdateError
- from langgraph.pregel._write import PASSTHROUGH, ChannelWrite, ChannelWriteEntry
- from langgraph.types import Send
# Signature of the writer callback a branch is given. It is called with the
# destinations (node names and/or `Send` packets) and a boolean flag -- `True`
# when `BranchSpec.run` passes every mapped end up front, `False` when
# `BranchSpec._finish` passes the destinations actually chosen -- and returns
# the channel write entries / `Send` packets to emit.
_Writer = Callable[[Sequence[str | Send], bool], Sequence[ChannelWriteEntry | Send]]
def _get_branch_path_input_schema(
    path: Callable[..., Hashable | Sequence[Hashable]]
    | Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
    | Runnable[Any, Hashable | Sequence[Hashable]],
) -> type[Any] | None:
    """Infer, best-effort, the input schema a branch callable is annotated with.

    The underlying callable is located first: for a `RunnableCallable` or
    `RunnableLambda` the wrapped `func` is tried before `afunc`, and for each
    of them a plain function/method wins over the bound `__call__` of a
    callable object; any other callable `path` is used directly. The type hint
    of that callable's first parameter is returned when it is a class that
    itself carries type hints.

    Args:
        path: The branch callable or runnable to inspect.

    Returns:
        The annotated class of the first parameter, or `None` when nothing
        suitable can be determined (no callable found, no annotations, no
        parameters, a non-class hint, a class without annotated fields, or
        hints that cannot be resolved).
    """
    schema: type[Any] | None = None
    try:
        target: Callable[..., Any] | None = None
        if isinstance(path, (RunnableCallable, RunnableLambda)):
            # `getattr` with a default so a wrapper that only defines one of
            # `func` / `afunc` is skipped instead of raising AttributeError.
            for attr in ("func", "afunc"):
                candidate = getattr(path, attr, None)
                if isfunction(candidate) or ismethod(candidate):
                    target = candidate
                    break
                # callable object: inspect its bound `__call__` instead
                bound_call = getattr(candidate, "__call__", None)
                if ismethod(bound_call):
                    target = bound_call
                    break
        elif callable(path):
            target = path

        if target is not None and (hints := get_type_hints(target)):
            first_parameter_name = next(
                iter(signature(cast(FunctionType, target)).parameters)
            )
            if input_hint := hints.get(first_parameter_name):
                # only classes that declare annotated fields count as a schema
                if isinstance(input_hint, type) and get_type_hints(input_hint):
                    schema = input_hint
    except (TypeError, NameError, StopIteration):
        # Inference is optional: uninspectable callables (TypeError),
        # unresolvable forward references in annotations (NameError) and
        # parameterless callables (StopIteration) all mean "no schema".
        pass
    return schema
- class BranchSpec(NamedTuple):
- path: Runnable[Any, Hashable | list[Hashable]]
- ends: dict[Hashable, str] | None
- input_schema: type[Any] | None = None
- @classmethod
- def from_path(
- cls,
- path: Runnable[Any, Hashable | list[Hashable]],
- path_map: dict[Hashable, str] | list[str] | None,
- infer_schema: bool = False,
- ) -> BranchSpec:
- # coerce path_map to a dictionary
- path_map_: dict[Hashable, str] | None = None
- try:
- if isinstance(path_map, dict):
- path_map_ = path_map.copy()
- elif isinstance(path_map, list):
- path_map_ = {name: name for name in path_map}
- else:
- # find func
- func: Callable | None = None
- if isinstance(path, (RunnableCallable, RunnableLambda)):
- func = path.func or path.afunc
- if func is not None:
- # find callable method
- if (cal := getattr(path, "__call__", None)) and ismethod(cal):
- func = cal
- # get the return type
- if rtn_type := get_type_hints(func).get("return"):
- if get_origin(rtn_type) is Literal:
- path_map_ = {name: name for name in get_args(rtn_type)}
- except Exception:
- pass
- # infer input schema
- input_schema = _get_branch_path_input_schema(path) if infer_schema else None
- # create branch
- return cls(path=path, ends=path_map_, input_schema=input_schema)
- def run(
- self,
- writer: _Writer,
- reader: Callable[[RunnableConfig], Any] | None = None,
- ) -> RunnableCallable:
- return ChannelWrite.register_writer(
- RunnableCallable(
- func=self._route,
- afunc=self._aroute,
- writer=writer,
- reader=reader,
- name=None,
- trace=False,
- ),
- list(
- zip_longest(
- writer([e for e in self.ends.values()], True),
- [str(la) for la, e in self.ends.items()],
- )
- )
- if self.ends
- else None,
- )
- def _route(
- self,
- input: Any,
- config: RunnableConfig,
- *,
- reader: Callable[[RunnableConfig], Any] | None,
- writer: _Writer,
- ) -> Runnable:
- if reader:
- value = reader(config)
- # passthrough additional keys from node to branch
- # only doable when using dict states
- if (
- isinstance(value, dict)
- and isinstance(input, dict)
- and self.input_schema is None
- ):
- value = {**input, **value}
- else:
- value = input
- result = self.path.invoke(value, config)
- return self._finish(writer, input, result, config)
- async def _aroute(
- self,
- input: Any,
- config: RunnableConfig,
- *,
- reader: Callable[[RunnableConfig], Any] | None,
- writer: _Writer,
- ) -> Runnable:
- if reader:
- value = reader(config)
- # passthrough additional keys from node to branch
- # only doable when using dict states
- if (
- isinstance(value, dict)
- and isinstance(input, dict)
- and self.input_schema is None
- ):
- value = {**input, **value}
- else:
- value = input
- result = await self.path.ainvoke(value, config)
- return self._finish(writer, input, result, config)
- def _finish(
- self,
- writer: _Writer,
- input: Any,
- result: Any,
- config: RunnableConfig,
- ) -> Runnable | Any:
- if not isinstance(result, (list, tuple)):
- result = [result]
- if self.ends:
- destinations: Sequence[Send | str] = [
- r if isinstance(r, Send) else self.ends[r] for r in result
- ]
- else:
- destinations = cast(Sequence[Send | str], result)
- if any(dest is None or dest == START for dest in destinations):
- raise ValueError("Branch did not return a valid destination")
- if any(p.node == END for p in destinations if isinstance(p, Send)):
- raise InvalidUpdateError("Cannot send a packet to the END node")
- entries = writer(destinations, False)
- if not entries:
- return input
- else:
- need_passthrough = False
- for e in entries:
- if isinstance(e, ChannelWriteEntry):
- if e.value is PASSTHROUGH:
- need_passthrough = True
- break
- if need_passthrough:
- return ChannelWrite(entries)
- else:
- ChannelWrite.do_write(config, entries)
- return input
|