_retry.py

from __future__ import annotations

import asyncio
import logging
import random
import sys
import time
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import replace
from typing import Any

from langgraph._internal._config import patch_configurable
from langgraph._internal._constants import (
    CONF,
    CONFIG_KEY_CHECKPOINT_NS,
    CONFIG_KEY_RESUMING,
    NS_SEP,
)
from langgraph.errors import GraphBubbleUp, ParentCommand
from langgraph.types import Command, PregelExecutableTask, RetryPolicy

logger = logging.getLogger(__name__)

SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
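
# BaseException.add_note (PEP 678) is only available on Python 3.11+, so the
# retry helpers below attach task name/id notes to exceptions conditionally.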


def run_with_retry(
    task: PregelExecutableTask,
    retry_policy: Sequence[RetryPolicy] | None,
    configurable: dict[str, Any] | None = None,
) -> None:
    """Run a task with retries."""
    retry_policy = task.retry_policy or retry_policy
    attempts = 0
    config = task.config
    if configurable is not None:
        config = patch_configurable(config, configurable)
    while True:
        try:
            # clear any writes from previous attempts
            task.writes.clear()
            # run the task
            return task.proc.invoke(task.input, config)
        except ParentCommand as exc:
            ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
            cmd = exc.args[0]
            if cmd.graph in (ns, task.name):
                # this command is for the current graph, handle it
                for w in task.writers:
                    w.invoke(cmd, config)
                break
            elif cmd.graph == Command.PARENT:
                # this command is for the parent graph, assign it to the parent
                parts = ns.split(NS_SEP)
                if parts[-1].isdigit():
                    parts.pop()
                parent_ns = NS_SEP.join(parts[:-1])
                exc.args = (replace(cmd, graph=parent_ns),)
            # bubble up
            raise
        except GraphBubbleUp:
            # if interrupted, end
            raise
        except Exception as exc:
            if SUPPORTS_EXC_NOTES:
                exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
            if not retry_policy:
                raise
            # Check which retry policy applies to this exception
            matching_policy = None
            for policy in retry_policy:
                if _should_retry_on(policy, exc):
                    matching_policy = policy
                    break
            if not matching_policy:
                raise
            # increment attempts
            attempts += 1
            # check if we should give up
            if attempts >= matching_policy.max_attempts:
                raise
            # sleep before retrying
            interval = matching_policy.initial_interval
            # Apply backoff factor based on attempt count
            interval = min(
                matching_policy.max_interval,
                interval * (matching_policy.backoff_factor ** (attempts - 1)),
            )
            # Apply jitter if configured
            sleep_time = (
                interval + random.uniform(0, 1) if matching_policy.jitter else interval
            )
            time.sleep(sleep_time)
            # log the retry
            logger.info(
                f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
                exc_info=exc,
            )
            # signal subgraphs to resume (if available)
            config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
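
# Illustrative note (not from this module): with a policy whose values are
# assumed here to be initial_interval=0.5, backoff_factor=2.0,
# max_interval=128.0 and jitter=True, the sleep before retry N is
# min(128.0, 0.5 * 2.0 ** (N - 1)) plus up to 1 second of jitter, i.e. roughly
# 0.5s, 1s, 2s, 4s, ... See langgraph.types.RetryPolicy for the actual defaults.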


async def arun_with_retry(
    task: PregelExecutableTask,
    retry_policy: Sequence[RetryPolicy] | None,
    stream: bool = False,
    match_cached_writes: Callable[[], Awaitable[Sequence[PregelExecutableTask]]]
    | None = None,
    configurable: dict[str, Any] | None = None,
) -> None:
    """Run a task asynchronously with retries."""
    retry_policy = task.retry_policy or retry_policy
    attempts = 0
    config = task.config
    if configurable is not None:
        config = patch_configurable(config, configurable)
    if match_cached_writes is not None and task.cache_key is not None:
        for t in await match_cached_writes():
            if t is task:
                # if the task is already cached, return
                return
    while True:
        try:
            # clear any writes from previous attempts
            task.writes.clear()
            # run the task
            if stream:
                async for _ in task.proc.astream(task.input, config):
                    pass
                # if successful, end
                break
            else:
                return await task.proc.ainvoke(task.input, config)
        except ParentCommand as exc:
            ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
            cmd = exc.args[0]
            if cmd.graph in (ns, task.name):
                # this command is for the current graph, handle it
                for w in task.writers:
                    w.invoke(cmd, config)
                break
            elif cmd.graph == Command.PARENT:
                # this command is for the parent graph, assign it to the parent
                parts = ns.split(NS_SEP)
                if parts[-1].isdigit():
                    parts.pop()
                parent_ns = NS_SEP.join(parts[:-1])
                exc.args = (replace(cmd, graph=parent_ns),)
            # bubble up
            raise
        except GraphBubbleUp:
            # if interrupted, end
            raise
        except Exception as exc:
            if SUPPORTS_EXC_NOTES:
                exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
            if not retry_policy:
                raise
            # Check which retry policy applies to this exception
            matching_policy = None
            for policy in retry_policy:
                if _should_retry_on(policy, exc):
                    matching_policy = policy
                    break
            if not matching_policy:
                raise
            # increment attempts
            attempts += 1
            # check if we should give up
            if attempts >= matching_policy.max_attempts:
                raise
            # sleep before retrying
            interval = matching_policy.initial_interval
            # Apply backoff factor based on attempt count
            interval = min(
                matching_policy.max_interval,
                interval * (matching_policy.backoff_factor ** (attempts - 1)),
            )
            # Apply jitter if configured
            sleep_time = (
                interval + random.uniform(0, 1) if matching_policy.jitter else interval
            )
            await asyncio.sleep(sleep_time)
            # log the retry
            logger.info(
                f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
                exc_info=exc,
            )
            # signal subgraphs to resume (if available)
            config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})


def _should_retry_on(retry_policy: RetryPolicy, exc: Exception) -> bool:
    """Check if the given exception should be retried based on the retry policy."""
    if isinstance(retry_policy.retry_on, Sequence):
        return isinstance(exc, tuple(retry_policy.retry_on))
    elif isinstance(retry_policy.retry_on, type) and issubclass(
        retry_policy.retry_on, Exception
    ):
        return isinstance(exc, retry_policy.retry_on)
    elif callable(retry_policy.retry_on):
        return retry_policy.retry_on(exc)  # type: ignore[call-arg]
    else:
        raise TypeError(
            "retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
        )
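
# Usage sketch (illustrative, not part of this module): the three retry_on
# shapes accepted by _should_retry_on correspond to the ways a RetryPolicy can
# be configured; the argument values below are assumptions for illustration.
#
#   from langgraph.types import RetryPolicy
#
#   RetryPolicy(retry_on=TimeoutError)                           # single class
#   RetryPolicy(retry_on=(TimeoutError, ConnectionError))        # sequence of classes
#   RetryPolicy(retry_on=lambda exc: isinstance(exc, OSError))   # predicate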