| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- from __future__ import annotations
- import asyncio
- import logging
- import random
- import sys
- import time
- from collections.abc import Awaitable, Callable, Sequence
- from dataclasses import replace
- from typing import Any
- from langgraph._internal._config import patch_configurable
- from langgraph._internal._constants import (
- CONF,
- CONFIG_KEY_CHECKPOINT_NS,
- CONFIG_KEY_RESUMING,
- NS_SEP,
- )
- from langgraph.errors import GraphBubbleUp, ParentCommand
- from langgraph.types import Command, PregelExecutableTask, RetryPolicy
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# True when running on Python 3.11 or newer; used in this module to guard
# calls to ``exc.add_note(...)``.
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
def run_with_retry(
    task: PregelExecutableTask,
    retry_policy: Sequence[RetryPolicy] | None,
    configurable: dict[str, Any] | None = None,
) -> None:
    """Invoke ``task`` synchronously, retrying failures per the retry policies.

    Args:
        task: The task to run. Its own ``retry_policy``, when truthy, is used
            instead of the ``retry_policy`` argument.
        retry_policy: Fallback policies used when the task defines none.
        configurable: Optional values patched into the task's config before
            the first attempt.

    Returns:
        Whatever ``task.proc.invoke`` returns on success, or ``None`` when a
        ``ParentCommand`` addressed to this graph was handed to the task's
        writers.

    Raises:
        ParentCommand: When the command is not addressed to this graph. A
            command addressed to ``Command.PARENT`` is first rewritten to
            carry the parent namespace.
        GraphBubbleUp: Re-raised untouched, never retried.
        Exception: Any other error when there are no policies, no policy
            matches it, or the matching policy's ``max_attempts`` is reached.
    """
    policies = task.retry_policy or retry_policy
    attempts = 0
    run_config = (
        task.config
        if configurable is None
        else patch_configurable(task.config, configurable)
    )

    while True:
        try:
            # start every attempt from a clean slate of writes
            task.writes.clear()
            return task.proc.invoke(task.input, run_config)
        except ParentCommand as exc:
            checkpoint_ns: str = run_config[CONF][CONFIG_KEY_CHECKPOINT_NS]
            command = exc.args[0]
            if command.graph in (checkpoint_ns, task.name):
                # addressed to this graph: hand it to the writers and finish
                for writer in task.writers:
                    writer.invoke(command, run_config)
                return None
            if command.graph == Command.PARENT:
                # addressed to the parent: rewrite the target to the parent
                # namespace (dropping a trailing numeric segment first)
                segments = checkpoint_ns.split(NS_SEP)
                if segments[-1].isdigit():
                    del segments[-1]
                target_ns = NS_SEP.join(segments[:-1])
                exc.args = (replace(command, graph=target_ns),)
            # let the command propagate upwards
            raise
        except GraphBubbleUp:
            # interruptions are never retried
            raise
        except Exception as exc:
            if SUPPORTS_EXC_NOTES:
                exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
            if not policies:
                raise
            # the first policy that accepts this exception wins
            chosen = next(
                (
                    candidate
                    for candidate in policies
                    if _should_retry_on(candidate, exc)
                ),
                None,
            )
            if not chosen:
                raise
            attempts += 1
            # out of attempts for this policy: give up
            if attempts >= chosen.max_attempts:
                raise
            # exponential backoff, capped at the policy's max_interval
            scaled = chosen.initial_interval * (
                chosen.backoff_factor ** (attempts - 1)
            )
            sleep_time = min(chosen.max_interval, scaled)
            # jitter adds up to one extra second on top of the capped interval
            if chosen.jitter:
                sleep_time = sleep_time + random.uniform(0, 1)
            time.sleep(sleep_time)
            logger.info(
                f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
                exc_info=exc,
            )
            # flag the next attempt as a resume (picked up by subgraphs, if any)
            run_config = patch_configurable(run_config, {CONFIG_KEY_RESUMING: True})
async def arun_with_retry(
    task: PregelExecutableTask,
    retry_policy: Sequence[RetryPolicy] | None,
    stream: bool = False,
    match_cached_writes: Callable[[], Awaitable[Sequence[PregelExecutableTask]]]
    | None = None,
    configurable: dict[str, Any] | None = None,
) -> None:
    """Run ``task`` asynchronously, retrying failures per the retry policies.

    Args:
        task: The task to run. Its own ``retry_policy``, when truthy, is used
            instead of the ``retry_policy`` argument.
        retry_policy: Fallback policies used when the task defines none.
        stream: When true, the task is driven through ``task.proc.astream``
            and the yielded chunks are discarded; otherwise
            ``task.proc.ainvoke`` is awaited.
        match_cached_writes: Optional coroutine function returning the tasks
            that matched cached writes. It is only consulted when the task
            has a ``cache_key``; if ``task`` itself is among the returned
            tasks, nothing is executed.
        configurable: Optional values patched into the task's config before
            the first attempt.

    Returns:
        The result of ``task.proc.ainvoke`` in non-streaming mode; ``None``
        after a completed stream, a cache hit, or a ``ParentCommand``
        addressed to this graph.

    Raises:
        ParentCommand: When the command is not addressed to this graph. A
            command addressed to ``Command.PARENT`` is first rewritten to
            carry the parent namespace.
        GraphBubbleUp: Re-raised untouched, never retried.
        Exception: Any other error when there are no policies, no policy
            matches it, or the matching policy's ``max_attempts`` is reached.
    """
    policies = task.retry_policy or retry_policy
    attempts = 0
    run_config = (
        task.config
        if configurable is None
        else patch_configurable(task.config, configurable)
    )

    if match_cached_writes is not None and task.cache_key is not None:
        cached_tasks = await match_cached_writes()
        # the task is already cached: nothing to execute
        if any(cached is task for cached in cached_tasks):
            return

    while True:
        try:
            # start every attempt from a clean slate of writes
            task.writes.clear()
            if not stream:
                return await task.proc.ainvoke(task.input, run_config)
            # streaming mode: drain the stream, ignoring the chunks
            async for _chunk in task.proc.astream(task.input, run_config):
                pass
            return None
        except ParentCommand as exc:
            checkpoint_ns: str = run_config[CONF][CONFIG_KEY_CHECKPOINT_NS]
            command = exc.args[0]
            if command.graph in (checkpoint_ns, task.name):
                # addressed to this graph: hand it to the writers and finish
                for writer in task.writers:
                    writer.invoke(command, run_config)
                return None
            if command.graph == Command.PARENT:
                # addressed to the parent: rewrite the target to the parent
                # namespace (dropping a trailing numeric segment first)
                segments = checkpoint_ns.split(NS_SEP)
                if segments[-1].isdigit():
                    del segments[-1]
                target_ns = NS_SEP.join(segments[:-1])
                exc.args = (replace(command, graph=target_ns),)
            # let the command propagate upwards
            raise
        except GraphBubbleUp:
            # interruptions are never retried
            raise
        except Exception as exc:
            if SUPPORTS_EXC_NOTES:
                exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
            if not policies:
                raise
            # the first policy that accepts this exception wins
            chosen = next(
                (
                    candidate
                    for candidate in policies
                    if _should_retry_on(candidate, exc)
                ),
                None,
            )
            if not chosen:
                raise
            attempts += 1
            # out of attempts for this policy: give up
            if attempts >= chosen.max_attempts:
                raise
            # exponential backoff, capped at the policy's max_interval
            scaled = chosen.initial_interval * (
                chosen.backoff_factor ** (attempts - 1)
            )
            sleep_time = min(chosen.max_interval, scaled)
            # jitter adds up to one extra second on top of the capped interval
            if chosen.jitter:
                sleep_time = sleep_time + random.uniform(0, 1)
            await asyncio.sleep(sleep_time)
            logger.info(
                f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
                exc_info=exc,
            )
            # flag the next attempt as a resume (picked up by subgraphs, if any)
            run_config = patch_configurable(run_config, {CONFIG_KEY_RESUMING: True})
def _should_retry_on(retry_policy: RetryPolicy, exc: Exception) -> bool:
    """Check if the given exception should be retried based on the retry policy.

    ``retry_policy.retry_on`` is interpreted as one of:

    * a sequence of exception classes: retry when ``exc`` is an instance of
      any of them (an empty sequence never matches);
    * a single exception class: retry when ``exc`` is an instance of it;
    * any other callable: it is called with ``exc`` and the truthiness of
      its result decides.

    Args:
        retry_policy: Policy whose ``retry_on`` attribute is consulted.
        exc: The exception raised by the task.

    Returns:
        ``True`` if the policy says ``exc`` should be retried, else ``False``.

    Raises:
        TypeError: If ``retry_on`` is none of the supported forms.
    """
    retry_on = retry_policy.retry_on
    if isinstance(retry_on, Sequence):
        return isinstance(exc, tuple(retry_on))
    # Compare against BaseException, not Exception: a class such as
    # KeyboardInterrupt would otherwise fall through to the callable branch,
    # where *instantiating* it with ``exc`` yields a truthy object and every
    # error would be treated as retryable.
    if isinstance(retry_on, type) and issubclass(retry_on, BaseException):
        return isinstance(exc, retry_on)
    if callable(retry_on):
        # coerce so the declared ``bool`` return type actually holds
        return bool(retry_on(exc))  # type: ignore[call-arg]
    raise TypeError(
        "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
    )
|