_call.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. """Utility to convert a user provided function into a Runnable with a ChannelWrite."""
  2. from __future__ import annotations
  3. import concurrent.futures
  4. import functools
  5. import inspect
  6. import sys
  7. import types
  8. from collections.abc import Awaitable, Callable, Generator, Sequence
  9. from typing import Any, Generic, TypeVar, cast
  10. from langchain_core.runnables import Runnable
  11. from typing_extensions import ParamSpec
  12. from langgraph._internal._constants import CONF, CONFIG_KEY_CALL, RETURN
  13. from langgraph._internal._runnable import (
  14. RunnableCallable,
  15. RunnableSeq,
  16. is_async_callable,
  17. run_in_executor,
  18. )
  19. from langgraph.config import get_config
  20. from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
  21. from langgraph.types import CachePolicy, RetryPolicy
##
# Utilities borrowed from cloudpickle.
# https://github.com/cloudpipe/cloudpickle/blob/6220b0ce83ffee5e47e06770a1ee38ca9e47c850/cloudpickle/cloudpickle.py#L265
  25. def _getattribute(obj: Any, name: str) -> Any:
  26. parent = None
  27. for subpath in name.split("."):
  28. if subpath == "<locals>":
  29. raise AttributeError(f"Can't get local attribute {name!r} on {obj!r}")
  30. try:
  31. parent = obj
  32. obj = getattr(obj, subpath)
  33. except AttributeError:
  34. raise AttributeError(f"Can't get attribute {name!r} on {obj!r}") from None
  35. return obj, parent
  36. def _whichmodule(obj: Any, name: str) -> str | None:
  37. """Find the module an object belongs to.
  38. This function differs from ``pickle.whichmodule`` in two ways:
  39. - it does not mangle the cases where obj's module is __main__ and obj was
  40. not found in any module.
  41. - Errors arising during module introspection are ignored, as those errors
  42. are considered unwanted side effects.
  43. """
  44. module_name = getattr(obj, "__module__", None)
  45. if module_name is not None:
  46. return module_name
  47. # Protect the iteration by using a copy of sys.modules against dynamic
  48. # modules that trigger imports of other modules upon calls to getattr or
  49. # other threads importing at the same time.
  50. for module_name, module in sys.modules.copy().items():
  51. # Some modules such as coverage can inject non-module objects inside
  52. # sys.modules
  53. if (
  54. module_name == "__main__"
  55. or module_name == "__mp_main__"
  56. or module is None
  57. or not isinstance(module, types.ModuleType)
  58. ):
  59. continue
  60. try:
  61. if _getattribute(module, name)[0] is obj:
  62. return module_name
  63. except Exception:
  64. pass
  65. return None
  66. def identifier(obj: Any, name: str | None = None) -> str | None:
  67. """Return the module and name of an object."""
  68. from langgraph._internal._runnable import RunnableCallable, RunnableSeq
  69. from langgraph.pregel._read import PregelNode
  70. if isinstance(obj, PregelNode):
  71. obj = obj.bound
  72. if isinstance(obj, RunnableSeq):
  73. obj = obj.steps[0]
  74. if isinstance(obj, RunnableCallable):
  75. obj = obj.func
  76. if name is None:
  77. name = getattr(obj, "__qualname__", None)
  78. if name is None: # pragma: no cover
  79. # This used to be needed for Python 2.7 support but is probably not
  80. # needed anymore. However we keep the __name__ introspection in case
  81. # users of cloudpickle rely on this old behavior for unknown reasons.
  82. name = getattr(obj, "__name__", None)
  83. if name is None:
  84. return None
  85. module_name = getattr(obj, "__module__", None)
  86. if module_name is None:
  87. # In this case, obj.__module__ is None. obj is thus treated as dynamic.
  88. return None
  89. return f"{module_name}.{name}"
  90. def _lookup_module_and_qualname(
  91. obj: Any, name: str | None = None
  92. ) -> tuple[types.ModuleType, str] | None:
  93. if name is None:
  94. name = getattr(obj, "__qualname__", None)
  95. if name is None: # pragma: no cover
  96. # This used to be needed for Python 2.7 support but is probably not
  97. # needed anymore. However we keep the __name__ introspection in case
  98. # users of cloudpickle rely on this old behavior for unknown reasons.
  99. name = getattr(obj, "__name__", None)
  100. if name is None:
  101. return None
  102. module_name = _whichmodule(obj, name)
  103. if module_name is None:
  104. # In this case, obj.__module__ is None AND obj was not found in any
  105. # imported module. obj is thus treated as dynamic.
  106. return None
  107. if module_name == "__main__":
  108. return None
  109. # Note: if module_name is in sys.modules, the corresponding module is
  110. # assumed importable at unpickling time. See #357
  111. module = sys.modules.get(module_name, None)
  112. if module is None:
  113. # The main reason why obj's module would not be imported is that this
  114. # module has been dynamically created, using for example
  115. # types.ModuleType. The other possibility is that module was removed
  116. # from sys.modules after obj was created/imported. But this case is not
  117. # supported, as the standard pickle does not support it either.
  118. return None
  119. try:
  120. obj2, parent = _getattribute(module, name)
  121. except AttributeError:
  122. # obj was not found inside the module it points to
  123. return None
  124. if obj2 is not obj:
  125. return None
  126. return module, name
  127. def _explode_args_trace_inputs(
  128. sig: inspect.Signature, input: tuple[tuple[Any, ...], dict[str, Any]]
  129. ) -> dict[str, Any]:
  130. args, kwargs = input
  131. bound = sig.bind_partial(*args, **kwargs)
  132. bound.apply_defaults()
  133. arguments = dict(bound.arguments)
  134. arguments.pop("self", None)
  135. arguments.pop("cls", None)
  136. for param_name, param in sig.parameters.items():
  137. if param.kind == inspect.Parameter.VAR_KEYWORD:
  138. # Update with the **kwargs, and remove the original entry
  139. # This is to help flatten out keyword arguments
  140. if param_name in arguments:
  141. arguments.update(arguments.pop(param_name))
  142. return arguments
  143. def get_runnable_for_entrypoint(func: Callable[..., Any]) -> Runnable:
  144. key = (func, False)
  145. if key in CACHE:
  146. return CACHE[key]
  147. else:
  148. if is_async_callable(func):
  149. run = RunnableCallable(
  150. None, func, name=func.__name__, trace=False, recurse=False
  151. )
  152. else:
  153. afunc = functools.update_wrapper(
  154. functools.partial(run_in_executor, None, func), func
  155. )
  156. run = RunnableCallable(
  157. func,
  158. afunc,
  159. name=func.__name__,
  160. trace=False,
  161. recurse=False,
  162. )
  163. if not _lookup_module_and_qualname(func):
  164. return run
  165. return CACHE.setdefault(key, run)
  166. def get_runnable_for_task(func: Callable[..., Any]) -> Runnable:
  167. key = (func, True)
  168. if key in CACHE:
  169. return CACHE[key]
  170. else:
  171. if hasattr(func, "__name__"):
  172. name = func.__name__
  173. elif hasattr(func, "func"):
  174. name = func.func.__name__
  175. elif hasattr(func, "__class__"):
  176. name = func.__class__.__name__
  177. else:
  178. name = str(func)
  179. if is_async_callable(func):
  180. run = RunnableCallable(
  181. None,
  182. func,
  183. explode_args=True,
  184. name=name,
  185. trace=False,
  186. recurse=False,
  187. )
  188. else:
  189. run = RunnableCallable(
  190. func,
  191. functools.wraps(func)(functools.partial(run_in_executor, None, func)),
  192. explode_args=True,
  193. name=name,
  194. trace=False,
  195. recurse=False,
  196. )
  197. seq = RunnableSeq(
  198. run,
  199. ChannelWrite([ChannelWriteEntry(RETURN)]),
  200. name=name,
  201. trace_inputs=functools.partial(
  202. _explode_args_trace_inputs, inspect.signature(func)
  203. ),
  204. )
  205. if not _lookup_module_and_qualname(func):
  206. return seq
  207. return CACHE.setdefault(key, seq)
  208. CACHE: dict[tuple[Callable[..., Any], bool], Runnable] = {}
  209. P = ParamSpec("P")
  210. P1 = TypeVar("P1")
  211. T = TypeVar("T")
  212. class SyncAsyncFuture(Generic[T], concurrent.futures.Future[T]):
  213. def __await__(self) -> Generator[T, None, T]:
  214. yield cast(T, ...)
  215. def call(
  216. func: Callable[P, Awaitable[T]] | Callable[P, T],
  217. *args: Any,
  218. retry_policy: Sequence[RetryPolicy] | None = None,
  219. cache_policy: CachePolicy | None = None,
  220. **kwargs: Any,
  221. ) -> SyncAsyncFuture[T]:
  222. config = get_config()
  223. impl = config[CONF][CONFIG_KEY_CALL]
  224. fut = impl(
  225. func,
  226. (args, kwargs),
  227. retry_policy=retry_policy,
  228. cache_policy=cache_policy,
  229. callbacks=config["callbacks"],
  230. )
  231. return fut