"""Utility to convert a user provided function into a Runnable with a ChannelWrite."""

from __future__ import annotations

import concurrent.futures
import functools
import inspect
import sys
import types
from collections.abc import Awaitable, Callable, Generator, Sequence
from typing import Any, Generic, TypeVar, cast

from langchain_core.runnables import Runnable
from typing_extensions import ParamSpec

from langgraph._internal._constants import CONF, CONFIG_KEY_CALL, RETURN
from langgraph._internal._runnable import (
    RunnableCallable,
    RunnableSeq,
    is_async_callable,
    run_in_executor,
)
from langgraph.config import get_config
from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
from langgraph.types import CachePolicy, RetryPolicy

##
# Utilities borrowed from cloudpickle.
# https://github.com/cloudpipe/cloudpickle/blob/6220b0ce83ffee5e47e06770a1ee38ca9e47c850/cloudpickle/cloudpickle.py#L265


def _getattribute(obj: Any, name: str) -> Any:
    """Resolve a dotted name on obj and return ``(attribute, parent)``.

    Raises AttributeError for local (``<locals>``) names or missing attributes.
    """
    parent = None
    for subpath in name.split("."):
        if subpath == "<locals>":
            raise AttributeError(f"Can't get local attribute {name!r} on {obj!r}")
        try:
            parent = obj
            obj = getattr(obj, subpath)
        except AttributeError:
            raise AttributeError(f"Can't get attribute {name!r} on {obj!r}") from None
    return obj, parent
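The dotted-name walk above can be tried in isolation. The following is a minimal sketch using only the standard library; `walk_dotted` and `outer` are hypothetical names introduced for illustration and are not part of this module.

```python
import collections
from typing import Any


def walk_dotted(obj: Any, name: str) -> tuple[Any, Any]:
    """Resolve a dotted name against obj, returning (value, parent)."""
    parent = None
    for part in name.split("."):
        # Names defined inside a function body carry a "<locals>" segment
        # in their qualified name and cannot be reached by attribute access.
        if part == "<locals>":
            raise AttributeError(f"Can't get local attribute {name!r} on {obj!r}")
        parent = obj
        obj = getattr(obj, part)
    return obj, parent


# Walk from a module through a class to one of its methods.
value, parent = walk_dotted(collections, "OrderedDict.fromkeys")


def outer():
    def inner():
        return None

    return inner


# The qualified name of a nested function contains a "<locals>" segment.
local_qualname = outer().__qualname__
```

Here `parent` is the object on which the final attribute was looked up, which is why the helper returns a pair rather than the value alone.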
def _whichmodule(obj: Any, name: str) -> str | None:
    """Find the module an object belongs to.

    This function differs from ``pickle.whichmodule`` in two ways:
    - it does not mangle the cases where obj's module is __main__ and obj was
      not found in any module.
    - Errors arising during module introspection are ignored, as those errors
      are considered unwanted side effects.
    """
    module_name = getattr(obj, "__module__", None)
    if module_name is not None:
        return module_name
    # Protect the iteration by using a copy of sys.modules against dynamic
    # modules that trigger imports of other modules upon calls to getattr or
    # other threads importing at the same time.
    for module_name, module in sys.modules.copy().items():
        # Some modules such as coverage can inject non-module objects inside
        # sys.modules
        if (
            module_name == "__main__"
            or module_name == "__mp_main__"
            or module is None
            or not isinstance(module, types.ModuleType)
        ):
            continue
        try:
            if _getattribute(module, name)[0] is obj:
                return module_name
        except Exception:
            pass
    return None
def identifier(obj: Any, name: str | None = None) -> str | None:
    """Return the module and name of an object."""
    from langgraph._internal._runnable import RunnableCallable, RunnableSeq
    from langgraph.pregel._read import PregelNode

    if isinstance(obj, PregelNode):
        obj = obj.bound
    if isinstance(obj, RunnableSeq):
        obj = obj.steps[0]
    if isinstance(obj, RunnableCallable):
        obj = obj.func
    if name is None:
        name = getattr(obj, "__qualname__", None)
    if name is None:  # pragma: no cover
        # This used to be needed for Python 2.7 support but is probably not
        # needed anymore. However we keep the __name__ introspection in case
        # users of cloudpickle rely on this old behavior for unknown reasons.
        name = getattr(obj, "__name__", None)
    if name is None:
        return None
    module_name = getattr(obj, "__module__", None)
    if module_name is None:
        # In this case, obj.__module__ is None. obj is thus treated as dynamic.
        return None
    return f"{module_name}.{name}"
def _lookup_module_and_qualname(
    obj: Any, name: str | None = None
) -> tuple[types.ModuleType, str] | None:
    if name is None:
        name = getattr(obj, "__qualname__", None)
    if name is None:  # pragma: no cover
        # This used to be needed for Python 2.7 support but is probably not
        # needed anymore. However we keep the __name__ introspection in case
        # users of cloudpickle rely on this old behavior for unknown reasons.
        name = getattr(obj, "__name__", None)
    if name is None:
        return None
    module_name = _whichmodule(obj, name)
    if module_name is None:
        # In this case, obj.__module__ is None AND obj was not found in any
        # imported module. obj is thus treated as dynamic.
        return None
    if module_name == "__main__":
        return None
    # Note: if module_name is in sys.modules, the corresponding module is
    # assumed importable at unpickling time. See #357
    module = sys.modules.get(module_name, None)
    if module is None:
        # The main reason why obj's module would not be imported is that this
        # module has been dynamically created, using for example
        # types.ModuleType. The other possibility is that module was removed
        # from sys.modules after obj was created/imported. But this case is not
        # supported, as the standard pickle does not support it either.
        return None
    try:
        obj2, parent = _getattribute(module, name)
    except AttributeError:
        # obj was not found inside the module it points to
        return None
    if obj2 is not obj:
        return None
    return module, name
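To make the lookup rules above concrete, here is a simplified sketch of the same idea: an object counts as importable only if walking its qualified name from its module yields the identical object. `is_importable` and `make_closure` are hypothetical names, and the sketch deliberately omits the `sys.modules` scan that `_whichmodule` falls back to when `__module__` is missing.

```python
import json
import sys
from typing import Any


def is_importable(obj: Any) -> bool:
    """Return True if obj can be found again via its module and qualname."""
    name = getattr(obj, "__qualname__", None)
    module_name = getattr(obj, "__module__", None)
    if name is None or module_name is None or module_name == "__main__":
        return False
    module = sys.modules.get(module_name)
    if module is None:
        # Dynamically created or since-removed module.
        return False
    found: Any = module
    for part in name.split("."):
        if part == "<locals>":
            # Nested functions are not reachable by attribute access.
            return False
        found = getattr(found, part, None)
        if found is None:
            return False
    # Identity check: the name must resolve to this exact object.
    return found is obj


def make_closure():
    def inner():
        return None

    return inner


stdlib_ok = is_importable(json.dumps)
closure_ok = is_importable(make_closure())
lambda_ok = is_importable(lambda: None)
```

Under these rules a module-level function such as `json.dumps` resolves, while closures and lambdas do not.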
def _explode_args_trace_inputs(
    sig: inspect.Signature, input: tuple[tuple[Any, ...], dict[str, Any]]
) -> dict[str, Any]:
    """Bind ``(args, kwargs)`` to ``sig`` and return a flat dict of arguments.

    ``self`` and ``cls`` are dropped, and the entries captured by a ``**kwargs``
    parameter are merged into the top level of the result.
    """
    args, kwargs = input
    bound = sig.bind_partial(*args, **kwargs)
    bound.apply_defaults()
    arguments = dict(bound.arguments)
    arguments.pop("self", None)
    arguments.pop("cls", None)
    for param_name, param in sig.parameters.items():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            # Update with the **kwargs, and remove the original entry
            # This is to help flatten out keyword arguments
            if param_name in arguments:
                arguments.update(arguments.pop(param_name))
    return arguments
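The flattening step is easiest to see on a concrete call. The sketch below reproduces the same logic with only the standard library; `explode_args` and `fetch` are hypothetical stand-ins used for illustration.

```python
import inspect
from typing import Any


def explode_args(
    sig: inspect.Signature, call_input: tuple[tuple[Any, ...], dict[str, Any]]
) -> dict[str, Any]:
    """Map positional and keyword inputs to a flat {name: value} dict."""
    args, kwargs = call_input
    bound = sig.bind_partial(*args, **kwargs)
    bound.apply_defaults()
    arguments = dict(bound.arguments)
    arguments.pop("self", None)
    arguments.pop("cls", None)
    for param_name, param in sig.parameters.items():
        if param.kind == inspect.Parameter.VAR_KEYWORD and param_name in arguments:
            # Lift the captured **kwargs entries to the top level.
            arguments.update(arguments.pop(param_name))
    return arguments


def fetch(url: str, retries: int = 3, **options: Any) -> None:
    """Hypothetical task function used only to provide a signature."""


sig = inspect.signature(fetch)
flat = explode_args(sig, (("https://example.com",), {"timeout": 5}))
no_extras = explode_args(sig, (("https://example.com",), {}))
```

In this sketch `flat` contains `url`, the default `retries`, and `timeout` as sibling keys, with no nested `options` entry.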
def get_runnable_for_entrypoint(func: Callable[..., Any]) -> Runnable:
    """Wrap ``func`` in a RunnableCallable.

    A sync ``func`` gets an async counterpart that runs it via ``run_in_executor``.
    The result is cached only when ``func`` can be looked up by module and
    qualified name.
    """
    key = (func, False)
    if key in CACHE:
        return CACHE[key]
    else:
        if is_async_callable(func):
            run = RunnableCallable(
                None, func, name=func.__name__, trace=False, recurse=False
            )
        else:
            afunc = functools.update_wrapper(
                functools.partial(run_in_executor, None, func), func
            )
            run = RunnableCallable(
                func,
                afunc,
                name=func.__name__,
                trace=False,
                recurse=False,
            )
        if not _lookup_module_and_qualname(func):
            return run
        return CACHE.setdefault(key, run)
def get_runnable_for_task(func: Callable[..., Any]) -> Runnable:
    """Wrap ``func`` in a RunnableSeq followed by a ChannelWrite to ``RETURN``.

    The result is cached only when ``func`` can be looked up by module and
    qualified name.
    """
    key = (func, True)
    if key in CACHE:
        return CACHE[key]
    else:
        if hasattr(func, "__name__"):
            name = func.__name__
        elif hasattr(func, "func"):
            name = func.func.__name__
        elif hasattr(func, "__class__"):
            name = func.__class__.__name__
        else:
            name = str(func)

        if is_async_callable(func):
            run = RunnableCallable(
                None,
                func,
                explode_args=True,
                name=name,
                trace=False,
                recurse=False,
            )
        else:
            run = RunnableCallable(
                func,
                functools.wraps(func)(functools.partial(run_in_executor, None, func)),
                explode_args=True,
                name=name,
                trace=False,
                recurse=False,
            )
        seq = RunnableSeq(
            run,
            ChannelWrite([ChannelWriteEntry(RETURN)]),
            name=name,
            trace_inputs=functools.partial(
                _explode_args_trace_inputs, inspect.signature(func)
            ),
        )
        if not _lookup_module_and_qualname(func):
            return seq
        return CACHE.setdefault(key, seq)
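Both factory functions share one caching pattern: the key pairs the function with a task/entrypoint flag, and a wrapper is stored only for functions that can be looked up again by name. The sketch below illustrates that pattern with plain dicts standing in for the runnables; `WRAPPERS`, `cacheable`, and `get_wrapper` are hypothetical names, and `cacheable` is a cruder check than `_lookup_module_and_qualname`.

```python
import json
from typing import Any, Callable

# Keyed by (function, is_task), mirroring the (func, True/False) keys above.
WRAPPERS: dict[tuple[Callable[..., Any], bool], dict[str, Any]] = {}


def cacheable(func: Callable[..., Any]) -> bool:
    """Crude stand-in: reject lambdas and functions defined inside functions."""
    qualname = getattr(func, "__qualname__", "")
    return bool(qualname) and "<locals>" not in qualname and "<lambda>" not in qualname


def get_wrapper(func: Callable[..., Any], is_task: bool) -> dict[str, Any]:
    key = (func, is_task)
    if key in WRAPPERS:
        return WRAPPERS[key]
    wrapper = {"func": func, "is_task": is_task}
    if not cacheable(func):
        # Dynamic functions get a fresh wrapper on every call and are never
        # stored, so the cache does not keep them alive.
        return wrapper
    # setdefault keeps the first wrapper if another caller stored one meanwhile.
    return WRAPPERS.setdefault(key, wrapper)


first = get_wrapper(json.dumps, True)
second = get_wrapper(json.dumps, True)
as_entrypoint = get_wrapper(json.dumps, False)

anonymous = lambda x: x  # noqa: E731
fresh_a = get_wrapper(anonymous, True)
fresh_b = get_wrapper(anonymous, True)
```

One plausible reading of this design is that skipping the cache for dynamic functions avoids accumulating entries for closures created on every call.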
CACHE: dict[tuple[Callable[..., Any], bool], Runnable] = {}

P = ParamSpec("P")
P1 = TypeVar("P1")
T = TypeVar("T")


class SyncAsyncFuture(Generic[T], concurrent.futures.Future[T]):
    def __await__(self) -> Generator[T, None, T]:
        yield cast(T, ...)
def call(
    func: Callable[P, Awaitable[T]] | Callable[P, T],
    *args: Any,
    retry_policy: Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy | None = None,
    **kwargs: Any,
) -> SyncAsyncFuture[T]:
    """Schedule ``func`` with ``(args, kwargs)`` and return the resulting future.

    The scheduling implementation is read from the current config under
    ``CONFIG_KEY_CALL``.
    """
    config = get_config()
    impl = config[CONF][CONFIG_KEY_CALL]
    fut = impl(
        func,
        (args, kwargs),
        retry_policy=retry_policy,
        cache_policy=cache_policy,
        callbacks=config["callbacks"],
    )
    return fut