| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769 |
- from __future__ import annotations
- import asyncio
- import concurrent.futures
- import inspect
- import threading
- import time
- import weakref
- from collections.abc import (
- AsyncIterator,
- Awaitable,
- Callable,
- Iterable,
- Iterator,
- Sequence,
- )
- from functools import partial
- from typing import (
- Any,
- Generic,
- TypeVar,
- cast,
- )
- from langchain_core.callbacks import Callbacks
- from langgraph._internal._constants import (
- CONF,
- CONFIG_KEY_CALL,
- CONFIG_KEY_SCRATCHPAD,
- ERROR,
- INTERRUPT,
- NO_WRITES,
- RESUME,
- RETURN,
- )
- from langgraph._internal._future import chain_future, run_coroutine_threadsafe
- from langgraph._internal._scratchpad import PregelScratchpad
- from langgraph._internal._typing import MISSING
- from langgraph.constants import TAG_HIDDEN
- from langgraph.errors import GraphBubbleUp, GraphInterrupt
- from langgraph.pregel._algo import Call
- from langgraph.pregel._executor import Submit
- from langgraph.pregel._retry import arun_with_retry, run_with_retry
- from langgraph.types import (
- CachePolicy,
- PregelExecutableTask,
- RetryPolicy,
- )
F = TypeVar("F", concurrent.futures.Future, asyncio.Future)
E = TypeVar("E", threading.Event, asyncio.Event)
# Source files whose frames are trimmed from tracebacks that PregelRunner
# re-raises. Frames are dropped from the outermost end of the traceback one at
# a time while their file name ends with one of these paths; trimming stops at
# the first frame from any other file, and the innermost frame is always kept.
# The underscore-prefixed module paths match the modules this file imports
# (langgraph.pregel._retry, langgraph.pregel._executor, ...); the older paths
# without the underscore are kept as well.
# NOTE(review): "langgraph/_internal/_runnable.py" is assumed to be the new home
# of "langgraph/utils/runnable.py" - confirm against the package layout.
EXCLUDED_FRAME_FNAMES = (
    "langgraph/pregel/_retry.py",
    "langgraph/pregel/_runner.py",
    "langgraph/pregel/_executor.py",
    "langgraph/_internal/_runnable.py",
    "langgraph/pregel/retry.py",
    "langgraph/pregel/runner.py",
    "langgraph/pregel/executor.py",
    "langgraph/utils/runnable.py",
    "langchain_core/runnables/config.py",
    "concurrent/futures/thread.py",
    "concurrent/futures/_base.py",
)
# Futures created for call() child tasks. Their exceptions are raised into the
# parent task, so they must not be re-raised again at the end of a tick.
# Held weakly: entries disappear once the future is garbage collected.
SKIP_RERAISE_SET: weakref.WeakSet[concurrent.futures.Future | asyncio.Future] = (
    weakref.WeakSet()
)
class FuturesDict(Generic[F, E], dict[F, PregelExecutableTask | None]):
    """Dict of futures to the Pregel task they run (``None`` marks a waiter).

    Storing a future with a task value counts it as outstanding, clears
    ``event`` and attaches a done callback. When such a future completes, the
    ``callback`` weak reference is dereferenced and, if still alive, called
    with the task and the future's exception (or ``None``). Afterwards the
    future is recorded in ``done`` and ``event`` is set once no task futures
    remain outstanding or ``_should_stop_others`` reports a failure.
    """

    # set when all tracked task futures finished, or one of them failed
    event: E
    # weak reference to the commit callback: callback()(task, exception)
    callback: weakref.ref[Callable[[PregelExecutableTask, BaseException | None], None]]
    # number of task futures stored but not yet completed
    counter: int
    # futures whose done callback has already run
    done: set[F]
    # guards counter/done updates and the event transitions tied to them
    lock: threading.Lock

    def __init__(
        self,
        event: E,
        callback: weakref.ref[
            Callable[[PregelExecutableTask, BaseException | None], None]
        ],
        future_type: type[F],
        # used for generic typing, newer py supports FutureDict[...](...)
    ) -> None:
        super().__init__()
        self.event = event
        self.callback = callback
        self.lock = threading.Lock()
        self.done: set[F] = set()
        self.counter = 0

    def __setitem__(
        self,
        key: F,
        value: PregelExecutableTask | None,
    ) -> None:
        super().__setitem__(key, value)  # type: ignore[index]
        if value is None:
            # waiter futures are stored but neither counted nor watched
            return
        with self.lock:
            self.event.clear()
            self.counter += 1
        # Registered outside the lock: a future that is already finished runs
        # the callback immediately, and on_done takes the same non-reentrant lock.
        key.add_done_callback(partial(self.on_done, value))

    def on_done(
        self,
        task: PregelExecutableTask,
        fut: F,
    ) -> None:
        """Done callback: report the outcome, then update bookkeeping."""
        try:
            commit = self.callback()
            if commit:
                commit(task, _exception(fut))
        finally:
            # bookkeeping must happen even if the commit callback raised
            with self.lock:
                self.done.add(fut)
                self.counter -= 1
                nothing_left = self.counter == 0
                if nothing_left or _should_stop_others(self.done):
                    self.event.set()
class PregelRunner:
    """Responsible for executing a set of Pregel tasks concurrently, committing
    their writes, yielding control to caller when there is output to emit, and
    interrupting other tasks if appropriate."""

    def __init__(
        self,
        *,
        submit: weakref.ref[Submit],
        put_writes: weakref.ref[Callable[[str, Sequence[tuple[str, Any]]], None]],
        use_astream: bool = False,
        node_finished: Callable[[str], None] | None = None,
    ) -> None:
        """Store the runner's collaborators.

        Args:
            submit: weak reference to the executor's submit function; it is
                dereferenced (``self.submit()``) at every use.
            put_writes: weak reference to a callable taking ``(task_id, writes)``
                that saves a task's writes.
            use_astream: forwarded as ``stream=`` to ``arun_with_retry`` and to
                the async ``call()`` helper.
            node_finished: optional callback invoked with the task name after a
                task succeeds, unless the task config is tagged ``TAG_HIDDEN``.
        """
        self.submit = submit
        self.put_writes = put_writes
        self.use_astream = use_astream
        self.node_finished = node_finished

    @staticmethod
    def _strip_excluded_frames(exc: BaseException) -> None:
        """Drop leading traceback frames that come from EXCLUDED_FRAME_FNAMES.

        Frames are removed from the outermost end of ``exc.__traceback__`` one
        at a time, stopping at the first frame from any other file; the
        innermost frame is always kept. ``exc`` is modified in place.
        """
        tb = exc.__traceback__
        if tb is None:
            return
        while tb.tb_next is not None and any(
            tb.tb_frame.f_code.co_filename.endswith(name)
            for name in EXCLUDED_FRAME_FNAMES
        ):
            tb = tb.tb_next
        exc.__traceback__ = tb

    def tick(
        self,
        tasks: Iterable[PregelExecutableTask],
        *,
        reraise: bool = True,
        timeout: float | None = None,
        retry_policy: Sequence[RetryPolicy] | None = None,
        get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None,
        schedule_task: Callable[
            [PregelExecutableTask, int, Call | None],
            PregelExecutableTask | None,
        ],
    ) -> Iterator[None]:
        """Run ``tasks`` through the sync executor, yielding between steps.

        Yields ``None`` whenever the caller should get a chance to emit output
        (e.g. after each batch of finished tasks).

        Args:
            tasks: tasks to execute in this tick.
            reraise: when True, a task failure (or the combined interrupts) is
                raised once waiting is over.
            timeout: optional deadline in seconds for the whole tick.
            retry_policy: retry policies forwarded to ``run_with_retry``.
            get_waiter: optional factory for an extra future that is waited on
                together with the tasks and re-created each time it completes
                while tasks are still in flight.
            schedule_task: callback used by ``call()`` inside a task to
                schedule a child task.

        Raises:
            TimeoutError: if tasks are still running when ``timeout`` elapses.
            Exception: a task failure, when ``reraise`` is True.
        """
        tasks = tuple(tasks)
        futures = FuturesDict(
            callback=weakref.WeakMethod(self.commit),
            event=threading.Event(),
            future_type=concurrent.futures.Future,
        )
        # give control back to the caller
        yield
        # fast path if single task with no timeout and no waiter
        if len(tasks) == 0:
            return
        elif len(tasks) == 1 and timeout is None and get_waiter is None:
            t = tasks[0]
            try:
                run_with_retry(
                    t,
                    retry_policy,
                    configurable={
                        CONFIG_KEY_CALL: partial(
                            _call,
                            weakref.ref(t),
                            retry_policy=retry_policy,
                            futures=weakref.ref(futures),
                            schedule_task=schedule_task,
                            submit=self.submit,
                        ),
                    },
                )
                self.commit(t, None)
            except Exception as exc:
                self.commit(t, exc)
                if reraise and futures:
                    # will be re-raised after futures are done
                    fut: concurrent.futures.Future = concurrent.futures.Future()
                    fut.set_exception(exc)
                    futures.done.add(fut)
                elif reraise:
                    self._strip_excluded_frames(exc)
                    raise
            if not futures:  # maybe `t` scheduled another task
                return
            else:
                tasks = ()  # don't reschedule this task
        # add waiter task if requested
        if get_waiter is not None:
            futures[get_waiter()] = None
        # schedule tasks
        for t in tasks:
            fut = self.submit()(  # type: ignore[misc]
                run_with_retry,
                t,
                retry_policy,
                configurable={
                    CONFIG_KEY_CALL: partial(
                        _call,
                        weakref.ref(t),
                        retry_policy=retry_policy,
                        futures=weakref.ref(futures),
                        schedule_task=schedule_task,
                        submit=self.submit,
                    ),
                },
                __reraise_on_exit__=reraise,
            )
            futures[fut] = t
        # execute tasks, and wait for one to fail or all to finish.
        # each task is independent from all other concurrent tasks
        # yield updates/debug output as each task finishes
        end_time = timeout + time.monotonic() if timeout else None
        while len(futures) > (1 if get_waiter is not None else 0):
            done, inflight = concurrent.futures.wait(
                futures,
                return_when=concurrent.futures.FIRST_COMPLETED,
                timeout=(max(0, end_time - time.monotonic()) if end_time else None),
            )
            if not done:
                break  # timed out
            for fut in done:
                task = futures.pop(fut)
                if task is None:
                    # waiter task finished, schedule another
                    if inflight and get_waiter is not None:
                        futures[get_waiter()] = None
            else:
                # remove references to loop vars
                del fut, task
            # maybe stop other tasks
            if _should_stop_others(done):
                break
            # give control back to the caller
            yield
        # wait for done callbacks
        futures.event.wait(
            timeout=(max(0, end_time - time.monotonic()) if end_time else None)
        )
        # give control back to the caller
        yield
        # panic on failure or timeout
        try:
            _panic_or_proceed(
                futures.done.union(f for f, t in futures.items() if t is not None),
                panic=reraise,
            )
        except Exception as exc:
            self._strip_excluded_frames(exc)
            raise

    async def atick(
        self,
        tasks: Iterable[PregelExecutableTask],
        *,
        reraise: bool = True,
        timeout: float | None = None,
        retry_policy: Sequence[RetryPolicy] | None = None,
        get_waiter: Callable[[], asyncio.Future[None]] | None = None,
        schedule_task: Callable[
            [PregelExecutableTask, int, Call | None],
            Awaitable[PregelExecutableTask | None],
        ],
    ) -> AsyncIterator[None]:
        """Async counterpart of ``tick``, running tasks via ``arun_with_retry``.

        Arguments mirror ``tick``; ``schedule_task`` is awaited here. Remaining
        futures (including any waiter) are cancelled before the outcome is
        settled.

        Raises:
            asyncio.TimeoutError: if tasks are still running when ``timeout``
                elapses.
            Exception: a task failure, when ``reraise`` is True.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        tasks = tuple(tasks)
        futures = FuturesDict(
            callback=weakref.WeakMethod(self.commit),
            event=asyncio.Event(),
            future_type=asyncio.Future,
        )
        # give control back to the caller
        yield
        # fast path if single task with no waiter and no timeout
        if len(tasks) == 0:
            return
        elif len(tasks) == 1 and get_waiter is None and timeout is None:
            t = tasks[0]
            try:
                await arun_with_retry(
                    t,
                    retry_policy,
                    stream=self.use_astream,
                    configurable={
                        CONFIG_KEY_CALL: partial(
                            _acall,
                            weakref.ref(t),
                            stream=self.use_astream,
                            retry_policy=retry_policy,
                            futures=weakref.ref(futures),
                            schedule_task=schedule_task,
                            submit=self.submit,
                            loop=loop,
                        ),
                    },
                )
                self.commit(t, None)
            except Exception as exc:
                self.commit(t, exc)
                if reraise and futures:
                    # will be re-raised after futures are done
                    fut: asyncio.Future = loop.create_future()
                    fut.set_exception(exc)
                    futures.done.add(fut)
                elif reraise:
                    self._strip_excluded_frames(exc)
                    raise
            if not futures:  # maybe `t` scheduled another task
                return
            else:
                tasks = ()  # don't reschedule this task
        # add waiter task if requested
        if get_waiter is not None:
            futures[get_waiter()] = None
        # schedule tasks
        for t in tasks:
            fut = cast(
                asyncio.Future,
                self.submit()(  # type: ignore[misc]
                    arun_with_retry,
                    t,
                    retry_policy,
                    stream=self.use_astream,
                    configurable={
                        CONFIG_KEY_CALL: partial(
                            _acall,
                            weakref.ref(t),
                            retry_policy=retry_policy,
                            stream=self.use_astream,
                            futures=weakref.ref(futures),
                            schedule_task=schedule_task,
                            submit=self.submit,
                            loop=loop,
                        ),
                    },
                    __name__=t.name,
                    __cancel_on_exit__=True,
                    __reraise_on_exit__=reraise,
                ),
            )
            futures[fut] = t
        # execute tasks, and wait for one to fail or all to finish.
        # each task is independent from all other concurrent tasks
        # yield updates/debug output as each task finishes
        end_time = timeout + loop.time() if timeout else None
        while len(futures) > (1 if get_waiter is not None else 0):
            done, inflight = await asyncio.wait(
                futures,
                return_when=asyncio.FIRST_COMPLETED,
                timeout=(max(0, end_time - loop.time()) if end_time else None),
            )
            if not done:
                break  # timed out
            for fut in done:
                task = futures.pop(fut)
                if task is None:
                    # waiter task finished, schedule another
                    if inflight and get_waiter is not None:
                        futures[get_waiter()] = None
            else:
                # remove references to loop vars
                del fut, task
            # maybe stop other tasks
            if _should_stop_others(done):
                break
            # give control back to the caller
            yield
        # Wait for done callbacks. Skipped when they have all run already: on
        # Python < 3.12, asyncio.wait_for with a zero timeout (deadline reached)
        # raises TimeoutError even if the awaited event is already set.
        if not futures.event.is_set():
            await asyncio.wait_for(
                futures.event.wait(),
                timeout=(max(0, end_time - loop.time()) if end_time else None),
            )
        # give control back to the caller
        yield
        # cancel waiter task
        for fut in futures:
            fut.cancel()
        # panic on failure or timeout
        try:
            _panic_or_proceed(
                futures.done.union(f for f, t in futures.items() if t is not None),
                timeout_exc_cls=asyncio.TimeoutError,
                panic=reraise,
            )
        except Exception as exc:
            self._strip_excluded_frames(exc)
            raise

    def commit(
        self,
        task: PregelExecutableTask,
        exception: BaseException | None,
    ) -> None:
        """Save the outcome of a finished task through ``put_writes``.

        - ``asyncio.CancelledError``: append ``(ERROR, exception)`` to the
          task's writes and save them, so the loop can finish the super-step.
        - ``GraphInterrupt`` with a non-empty payload: save an ``INTERRUPT``
          write plus any ``RESUME`` writes already present on the task.
        - any other ``GraphBubbleUp``: nothing is saved here; it is raised
          later by ``_panic_or_proceed``.
        - any other exception: append ``(ERROR, exception)`` and save.
        - success: call ``node_finished`` (unless the task is tagged
          ``TAG_HIDDEN``), add a ``NO_WRITES`` marker if the task wrote
          nothing, and save the writes.
        """
        if isinstance(exception, asyncio.CancelledError):
            # for cancelled tasks, also save error in task,
            # so loop can finish super-step
            task.writes.append((ERROR, exception))
            self.put_writes()(task.id, task.writes)  # type: ignore[misc]
        elif exception:
            if isinstance(exception, GraphInterrupt):
                # save interrupt to checkpointer
                if exception.args[0]:
                    writes = [(INTERRUPT, exception.args[0])]
                    if resumes := [w for w in task.writes if w[0] == RESUME]:
                        writes.extend(resumes)
                    self.put_writes()(task.id, writes)  # type: ignore[misc]
            elif isinstance(exception, GraphBubbleUp):
                # exception will be raised in _panic_or_proceed
                pass
            else:
                # save error to checkpointer
                task.writes.append((ERROR, exception))
                self.put_writes()(task.id, task.writes)  # type: ignore[misc]
        else:
            if self.node_finished and (
                task.config is None or TAG_HIDDEN not in task.config.get("tags", [])
            ):
                self.node_finished(task.name)
            if not task.writes:
                # add no writes marker
                task.writes.append((NO_WRITES, None))
            # save task writes to checkpointer
            self.put_writes()(task.id, task.writes)  # type: ignore[misc]
- def _should_stop_others(
- done: set[F],
- ) -> bool:
- """Check if any task failed, if so, cancel all other tasks.
- GraphInterrupts are not considered failures."""
- for fut in done:
- if fut.cancelled():
- continue
- elif exc := fut.exception():
- if not isinstance(exc, GraphBubbleUp) and fut not in SKIP_RERAISE_SET:
- return True
- return False
- def _exception(
- fut: concurrent.futures.Future[Any] | asyncio.Future[Any],
- ) -> BaseException | None:
- """Return the exception from a future, without raising CancelledError."""
- if fut.cancelled():
- if isinstance(fut, asyncio.Future):
- return asyncio.CancelledError()
- else:
- return concurrent.futures.CancelledError()
- else:
- return fut.exception()
- def _panic_or_proceed(
- futs: set[concurrent.futures.Future] | set[asyncio.Future],
- *,
- timeout_exc_cls: type[Exception] = TimeoutError,
- panic: bool = True,
- ) -> None:
- """Cancel remaining tasks if any failed, re-raise exception if panic is True."""
- done: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
- inflight: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
- for fut in futs:
- if fut.cancelled():
- continue
- elif fut.done():
- done.add(fut)
- else:
- inflight.add(fut)
- interrupts: list[GraphInterrupt] = []
- while done:
- # if any task failed
- fut = done.pop()
- if exc := _exception(fut):
- # cancel all pending tasks
- while inflight:
- inflight.pop().cancel()
- # raise the exception
- if panic:
- if isinstance(exc, GraphInterrupt):
- # collect interrupts
- interrupts.append(exc)
- elif fut not in SKIP_RERAISE_SET:
- raise exc
- # raise combined interrupts
- if interrupts:
- raise GraphInterrupt(tuple(i for exc in interrupts for i in exc.args[0]))
- if inflight:
- # if we got here means we timed out
- while inflight:
- # cancel all pending tasks
- inflight.pop().cancel()
- # raise timeout error
- raise timeout_exc_cls("Timed out")
def _call(
    task: weakref.ref[PregelExecutableTask],
    func: Callable[[Any], Awaitable[Any] | Any],
    input: Any,
    *,
    retry_policy: Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy | None = None,
    callbacks: Callbacks = None,
    futures: weakref.ref[FuturesDict],
    schedule_task: Callable[
        [PregelExecutableTask, int, Call | None], PregelExecutableTask | None
    ],
    submit: weakref.ref[Submit],
) -> concurrent.futures.Future[Any]:
    """Schedule ``func(input)`` as a child task of ``task`` (sync ``call()``).

    The child task comes from ``schedule_task``. If a future for a task with
    the same id is already tracked in ``futures`` (e.g. the parent was
    retried), that future is reused. If the child already has writes, a
    completed future is built from its RETURN/ERROR write. Otherwise the child
    is submitted on the next tick and registered in ``futures`` and
    ``SKIP_RERAISE_SET``.

    Returns:
        A new future chained from the child's future, so the commit callback
        runs before the caller sees the result. If ``schedule_task`` returns no
        task, a future failed with ``RuntimeError("Task not scheduled")`` is
        returned, as ``_acall_impl`` does for the async path.

    Raises:
        RuntimeError: if ``func`` is a coroutine function.
    """
    if inspect.iscoroutinefunction(func):
        raise RuntimeError("In an sync context async tasks cannot be called")
    fut: concurrent.futures.Future | None = None
    parent = task()
    # schedule PUSH tasks, collect futures
    scratchpad: PregelScratchpad = parent.config[CONF][CONFIG_KEY_SCRATCHPAD]  # type: ignore[union-attr]
    # schedule the next task, if the callback returns one
    if next_task := schedule_task(
        parent,  # type: ignore[arg-type]
        scratchpad.call_counter(),
        Call(
            func,
            input,
            retry_policy=retry_policy,
            cache_policy=cache_policy,
            callbacks=callbacks,
        ),
    ):
        # futures maps future -> task, so compare task ids (not task vs id)
        if fut := next(
            (
                f
                for f, t in futures().items()  # type: ignore[union-attr]
                if t is not None and t.id == next_task.id
            ),
            None,
        ):
            # if the parent task was retried,
            # the next task might already be running
            pass
        elif next_task.writes:
            # if it already ran, return the result
            fut = concurrent.futures.Future()
            ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
            if ret is not MISSING:
                fut.set_result(ret)
            elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
                fut.set_exception(
                    exc if isinstance(exc, BaseException) else Exception(exc)
                )
            else:
                fut.set_result(None)
        else:
            # schedule the next task
            fut = submit()(  # type: ignore[misc]
                run_with_retry,
                next_task,
                retry_policy,
                configurable={
                    CONFIG_KEY_CALL: partial(
                        _call,
                        weakref.ref(next_task),
                        futures=futures,
                        retry_policy=retry_policy,
                        callbacks=callbacks,
                        schedule_task=schedule_task,
                        submit=submit,
                    ),
                },
                __reraise_on_exit__=False,
                # starting a new task in the next tick ensures
                # updates from this tick are committed/streamed first
                __next_tick__=True,
            )
            # exceptions for call() tasks are raised into the parent task
            # so we should not re-raise at the end of the tick
            SKIP_RERAISE_SET.add(fut)
            futures()[fut] = next_task  # type: ignore[index]
    if fut is None:
        # nothing was scheduled; report it through the returned future
        # instead of handing None to chain_future
        failed: concurrent.futures.Future[Any] = concurrent.futures.Future()
        failed.set_exception(RuntimeError("Task not scheduled"))
        return failed
    fut = cast(asyncio.Future | concurrent.futures.Future, fut)
    # return a chained future to ensure commit() callback is called
    # before the returned future is resolved, to ensure stream order etc
    return chain_future(fut, concurrent.futures.Future())
def _acall(
    task: weakref.ref[PregelExecutableTask],
    func: Callable[[Any], Awaitable[Any] | Any],
    input: Any,
    *,
    retry_policy: Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy | None = None,
    callbacks: Callbacks = None,
    # injected dependencies
    futures: weakref.ref[FuturesDict],
    schedule_task: Callable[
        [PregelExecutableTask, int, Call | None],
        Awaitable[PregelExecutableTask | None],
    ],
    submit: weakref.ref[Submit],
    loop: asyncio.AbstractEventLoop,
    stream: bool = False,
) -> asyncio.Future[Any] | concurrent.futures.Future[Any]:
    """Async-runner counterpart of ``_call``: schedule ``func(input)`` as a child task.

    The returned future is created here and resolved later by ``_acall_impl``,
    which is handed to ``loop`` via ``run_coroutine_threadsafe``. Its flavour
    depends on the caller: an ``asyncio.Future`` bound to ``loop`` when called
    from inside a running asyncio task, a ``concurrent.futures.Future``
    otherwise.
    """
    try:
        caller_task = asyncio.current_task()
    except RuntimeError:
        # no running event loop in this thread
        caller_task = None

    result: asyncio.Future[Any] | concurrent.futures.Future[Any]
    if caller_task is None:
        result = concurrent.futures.Future()
    else:
        result = asyncio.Future(loop=loop)

    # the actual scheduling happens on `loop`; it resolves `result` when done
    scheduling = _acall_impl(
        result,
        task,
        func,
        input,
        retry_policy=retry_policy,
        cache_policy=cache_policy,
        callbacks=callbacks,
        futures=futures,
        schedule_task=schedule_task,
        submit=submit,
        loop=loop,
        stream=stream,
    )
    run_coroutine_threadsafe(scheduling, loop, lazy=False)
    return result
async def _acall_impl(
    destination: asyncio.Future[Any] | concurrent.futures.Future[Any],
    task: weakref.ref[PregelExecutableTask],
    func: Callable[[Any], Awaitable[Any] | Any],
    input: Any,
    *,
    retry_policy: Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy | None = None,
    callbacks: Callbacks = None,
    # injected dependencies
    futures: weakref.ref[FuturesDict[asyncio.Future, asyncio.Event]],
    schedule_task: Callable[
        [PregelExecutableTask, int, Call | None],
        Awaitable[PregelExecutableTask | None],
    ],
    submit: weakref.ref[Submit],
    loop: asyncio.AbstractEventLoop,
    stream: bool = False,
) -> None:
    """Schedule the child task for an async ``call()`` and resolve ``destination``.

    The child task is obtained by awaiting ``schedule_task``. If a future for
    a task with the same id is already tracked in ``futures`` (e.g. the parent
    was retried), that future is reused. If the child already has writes, a
    completed future is built from its RETURN/ERROR write and registered in
    ``futures``. Otherwise the child is submitted on the next tick and
    registered in ``futures`` and ``SKIP_RERAISE_SET``.

    ``destination`` is chained to the child's future. It fails with
    ``RuntimeError("Task not scheduled")`` when ``schedule_task`` returns no
    task, or with any ``Exception`` raised while scheduling.
    """
    try:
        fut: asyncio.Future | None = None
        # schedule PUSH tasks, collect futures
        scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD]  # type: ignore[union-attr]
        # schedule the next task, if the callback returns one
        if next_task := await schedule_task(
            task(),  # type: ignore[arg-type]
            scratchpad.call_counter(),
            Call(
                func,
                input,
                retry_policy=retry_policy,
                cache_policy=cache_policy,
                callbacks=callbacks,
            ),
        ):
            # futures maps future -> task, so compare task ids (not task vs id)
            if fut := next(
                (
                    f
                    for f, t in futures().items()  # type: ignore[union-attr]
                    if t is not None and t.id == next_task.id
                ),
                None,
            ):
                # if the parent task was retried,
                # the next task might already be running
                pass
            elif next_task.writes:
                # if it already ran, return the result
                fut = asyncio.Future(loop=loop)
                ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
                if ret is not MISSING:
                    fut.set_result(ret)
                elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
                    fut.set_exception(
                        exc if isinstance(exc, BaseException) else Exception(exc)
                    )
                else:
                    fut.set_result(None)
                futures()[fut] = next_task  # type: ignore[index]
            else:
                # schedule the next task
                # NOTE(review): unlike _call, the nested partial below does not
                # forward retry_policy/callbacks - confirm this is intended
                fut = cast(
                    asyncio.Future,
                    submit()(  # type: ignore[misc]
                        arun_with_retry,
                        next_task,
                        retry_policy,
                        stream=stream,
                        configurable={
                            CONFIG_KEY_CALL: partial(
                                _acall,
                                weakref.ref(next_task),
                                stream=stream,
                                futures=futures,
                                schedule_task=schedule_task,
                                submit=submit,
                                loop=loop,
                            ),
                        },
                        __name__=next_task.name,
                        __cancel_on_exit__=True,
                        __reraise_on_exit__=False,
                        # starting a new task in the next tick ensures
                        # updates from this tick are committed/streamed first
                        __next_tick__=True,
                    ),
                )
                # exceptions for call() tasks are raised into the parent task
                # so we should not re-raise at the end of the tick
                SKIP_RERAISE_SET.add(fut)
                futures()[fut] = next_task  # type: ignore[index]
        if fut is not None:
            # chain so the commit callback runs before destination resolves
            chain_future(fut, destination)
        else:
            destination.set_exception(RuntimeError("Task not scheduled"))
    except Exception as exc:
        destination.set_exception(exc)
|