  1. """Summarization middleware."""
  2. import uuid
  3. import warnings
  4. from collections.abc import Callable, Iterable, Mapping
  5. from functools import partial
  6. from typing import Any, Literal, cast
  7. from langchain_core.messages import (
  8. AnyMessage,
  9. MessageLikeRepresentation,
  10. RemoveMessage,
  11. ToolMessage,
  12. )
  13. from langchain_core.messages.human import HumanMessage
  14. from langchain_core.messages.utils import count_tokens_approximately, trim_messages
  15. from langgraph.graph.message import (
  16. REMOVE_ALL_MESSAGES,
  17. )
  18. from langgraph.runtime import Runtime
  19. from langchain.agents.middleware.types import AgentMiddleware, AgentState
  20. from langchain.chat_models import BaseChatModel, init_chat_model
  21. TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
DEFAULT_SUMMARY_PROMPT = """<role>
Context Extraction Assistant
</role>

<primary_objective>
Your sole objective in this task is to extract the highest quality/most relevant context from the conversation history below.
</primary_objective>

<objective_information>
You're nearing the total number of input tokens you can accept, so you must extract the highest quality/most relevant pieces of information from your conversation history.
This context will then overwrite the conversation history presented below. Because of this, ensure the context you extract is only the most important information to your overall goal.
</objective_information>

<instructions>
The conversation history below will be replaced with the context you extract in this step. Because of this, you must do your very best to extract and record all of the most important context from the conversation history.
You want to ensure that you don't repeat any actions you've already completed, so the context you extract from the conversation history should be focused on the most important information to your overall goal.
</instructions>

The user will message you with the full message history you'll be extracting context from, to then replace. Carefully read over it all, and think deeply about what information is most important to your overall goal that should be saved:

With all of this in mind, please carefully read over the entire conversation history, and extract the most important and relevant context to replace it so that you can free up space in the conversation history.

Respond ONLY with the extracted context. Do not include any additional information, or text before or after the extracted context.

<messages>
Messages to summarize:
{messages}
</messages>"""  # noqa: E501

_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15

ContextFraction = tuple[Literal["fraction"], float]
"""Fraction of model's maximum input tokens.

Example:
    To specify 50% of the model's max input tokens:

    ```python
    ("fraction", 0.5)
    ```
"""

ContextTokens = tuple[Literal["tokens"], int]
"""Absolute number of tokens.

Example:
    To specify 3000 tokens:

    ```python
    ("tokens", 3000)
    ```
"""

ContextMessages = tuple[Literal["messages"], int]
"""Absolute number of messages.

Example:
    To specify 50 messages:

    ```python
    ("messages", 50)
    ```
"""

ContextSize = ContextFraction | ContextTokens | ContextMessages
"""Union type for context size specifications.

Can be either:

- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
    fraction of the model's maximum input tokens.
- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
    number of tokens.
- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
    absolute number of messages.

Depending on use with the `trigger` or `keep` parameters, this type indicates either
when to trigger summarization or how much context to retain.

Example:
    ```python
    # ContextFraction
    context_size: ContextSize = ("fraction", 0.5)

    # ContextTokens
    context_size: ContextSize = ("tokens", 3000)

    # ContextMessages
    context_size: ContextSize = ("messages", 50)
    ```
"""


def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
    """Tune parameters of approximate token counter based on model type."""
    if model._llm_type == "anthropic-chat":
        # 3.3 was estimated in an offline experiment, comparing with Claude's
        # token-counting API:
        # https://platform.claude.com/docs/en/build-with-claude/token-counting
        return partial(count_tokens_approximately, chars_per_token=3.3)
    return count_tokens_approximately
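
# Illustrative arithmetic (hypothetical message sizes): with chars_per_token=3.3,
# a message body of ~330 characters counts as roughly 100 tokens, versus roughly
# 83 tokens under the default of ~4 characters per token, so the middleware
# reaches its thresholds earlier for Anthropic models.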


class SummarizationMiddleware(AgentMiddleware):
    """Summarizes conversation history when token limits are approached.

    This middleware monitors message token counts and automatically summarizes older
    messages when a threshold is reached, preserving recent messages and maintaining
    context continuity by ensuring AI/Tool message pairs remain together.
    """

    def __init__(
        self,
        model: str | BaseChatModel,
        *,
        trigger: ContextSize | list[ContextSize] | None = None,
        keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
        token_counter: TokenCounter = count_tokens_approximately,
        summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
        trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
        **deprecated_kwargs: Any,
    ) -> None:
        """Initialize summarization middleware.

        Args:
            model: The language model to use for generating summaries.
            trigger: One or more thresholds that trigger summarization.
                Provide a single
                [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
                tuple or a list of tuples, in which case summarization runs when any
                threshold is met.

                !!! example

                    ```python
                    # Trigger when the message count reaches 50
                    ("messages", 50)

                    # Trigger when the token count reaches 3000
                    ("tokens", 3000)

                    # Trigger when 80% of the model's max input tokens is used
                    # or when the message count reaches 100 (whichever comes first)
                    [("fraction", 0.8), ("messages", 100)]
                    ```

                See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
                for more details.
            keep: Context retention policy applied after summarization.
                Provide a [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
                tuple to specify how much history to preserve.
                Defaults to keeping the most recent `20` messages.
                Does not support multiple values like `trigger`.

                !!! example

                    ```python
                    # Keep the most recent 20 messages
                    ("messages", 20)

                    # Keep the most recent 3000 tokens
                    ("tokens", 3000)

                    # Keep the most recent 30% of the model's max input tokens
                    ("fraction", 0.3)
                    ```

            token_counter: Function to count tokens in messages.
            summary_prompt: Prompt template for generating summaries.
            trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
                the summarization call. Pass `None` to skip trimming entirely.
        """
        # Handle deprecated parameters
        if "max_tokens_before_summary" in deprecated_kwargs:
            value = deprecated_kwargs["max_tokens_before_summary"]
            warnings.warn(
                "max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if trigger is None and value is not None:
                trigger = ("tokens", value)

        if "messages_to_keep" in deprecated_kwargs:
            value = deprecated_kwargs["messages_to_keep"]
            warnings.warn(
                "messages_to_keep is deprecated. Use keep=('messages', value) instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
                keep = ("messages", value)
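
        # Migration sketch for the deprecated keywords (illustrative values):
        #
        #     SummarizationMiddleware(model, max_tokens_before_summary=3000)
        #     # becomes:
        #     SummarizationMiddleware(model, trigger=("tokens", 3000))
        #
        #     SummarizationMiddleware(model, messages_to_keep=10)
        #     # becomes:
        #     SummarizationMiddleware(model, keep=("messages", 10))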

        super().__init__()

        if isinstance(model, str):
            model = init_chat_model(model)
        self.model = model

        if trigger is None:
            self.trigger: ContextSize | list[ContextSize] | None = None
            trigger_conditions: list[ContextSize] = []
        elif isinstance(trigger, list):
            validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
            self.trigger = validated_list
            trigger_conditions = validated_list
        else:
            validated = self._validate_context_size(trigger, "trigger")
            self.trigger = validated
            trigger_conditions = [validated]
        self._trigger_conditions = trigger_conditions

        self.keep = self._validate_context_size(keep, "keep")
        if token_counter is count_tokens_approximately:
            self.token_counter = _get_approximate_token_counter(self.model)
        else:
            self.token_counter = token_counter
        self.summary_prompt = summary_prompt
        self.trim_tokens_to_summarize = trim_tokens_to_summarize

        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
        if self.keep[0] == "fraction":
            requires_profile = True
        if requires_profile and self._get_profile_limits() is None:
            msg = (
                "Model profile information is required to use fractional token limits, "
                "and is unavailable for the specified model. Please use absolute token "
                "counts instead, or pass\n\n"
                '`ChatModel(..., profile={"max_input_tokens": ...})`\n\n'
                "with a desired integer value of the model's maximum input tokens."
            )
            raise ValueError(msg)

    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Process messages before model invocation, potentially triggering summarization."""
        messages = state["messages"]
        self._ensure_message_ids(messages)

        total_tokens = self.token_counter(messages)
        if not self._should_summarize(messages, total_tokens):
            return None

        cutoff_index = self._determine_cutoff_index(messages)
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)

        summary = self._create_summary(messages_to_summarize)
        new_messages = self._build_new_messages(summary)

        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                *new_messages,
                *preserved_messages,
            ]
        }
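
    # Note on the return shape above: RemoveMessage(id=REMOVE_ALL_MESSAGES) tells
    # LangGraph's add_messages reducer to drop the entire existing history, so the
    # resulting state is [summary HumanMessage, *preserved recent messages].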

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Process messages before model invocation, potentially triggering summarization."""
        messages = state["messages"]
        self._ensure_message_ids(messages)

        total_tokens = self.token_counter(messages)
        if not self._should_summarize(messages, total_tokens):
            return None

        cutoff_index = self._determine_cutoff_index(messages)
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)

        summary = await self._acreate_summary(messages_to_summarize)
        new_messages = self._build_new_messages(summary)

        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                *new_messages,
                *preserved_messages,
            ]
        }

    def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
        """Determine whether summarization should run for the current token usage."""
        if not self._trigger_conditions:
            return False
        for kind, value in self._trigger_conditions:
            if kind == "messages" and len(messages) >= value:
                return True
            if kind == "tokens" and total_tokens >= value:
                return True
            if kind == "fraction":
                max_input_tokens = self._get_profile_limits()
                if max_input_tokens is None:
                    continue
                threshold = int(max_input_tokens * value)
                if threshold <= 0:
                    threshold = 1
                if total_tokens >= threshold:
                    return True
        return False
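
    # Illustrative arithmetic (hypothetical numbers): with trigger=[("fraction", 0.8)]
    # and a profile reporting max_input_tokens=200_000, summarization fires once the
    # running count reaches int(200_000 * 0.8) == 160_000 tokens.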

    def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
        """Choose cutoff index respecting retention configuration."""
        kind, value = self.keep
        if kind in {"tokens", "fraction"}:
            token_based_cutoff = self._find_token_based_cutoff(messages)
            if token_based_cutoff is not None:
                return token_based_cutoff
            # None cutoff -> model profile data not available (caught in __init__ but
            # here for safety), fall back to message count
            return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
        return self._find_safe_cutoff(messages, cast("int", value))

    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
        """Find cutoff index based on target token retention."""
        if not messages:
            return 0
        kind, value = self.keep
        if kind == "fraction":
            max_input_tokens = self._get_profile_limits()
            if max_input_tokens is None:
                return None
            target_token_count = int(max_input_tokens * value)
        elif kind == "tokens":
            target_token_count = int(value)
        else:
            return None

        if target_token_count <= 0:
            target_token_count = 1

        if self.token_counter(messages) <= target_token_count:
            return 0

        # Use binary search to identify the earliest message index that keeps the
        # suffix within the token budget.
        left, right = 0, len(messages)
        cutoff_candidate = len(messages)
        max_iterations = len(messages).bit_length() + 1
        for _ in range(max_iterations):
            if left >= right:
                break
            mid = (left + right) // 2
            if self.token_counter(messages[mid:]) <= target_token_count:
                cutoff_candidate = mid
                right = mid
            else:
                left = mid + 1

        if cutoff_candidate == len(messages):
            cutoff_candidate = left

        if cutoff_candidate >= len(messages):
            if len(messages) == 1:
                return 0
            cutoff_candidate = len(messages) - 1

        # Advance past any ToolMessages to avoid splitting AI/Tool pairs
        return self._find_safe_cutoff_point(messages, cutoff_candidate)
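
    # Worked example (hypothetical token counts): with 8 messages of ~500 tokens
    # each and keep=("tokens", 1000), the search probes mid=4 (suffix ~2000 tokens,
    # over budget), then mid=6 (~1000 tokens, fits), then mid=5 (~1500, over budget),
    # settling on cutoff 6: the last two messages are kept, the first six summarized.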

    def _get_profile_limits(self) -> int | None:
        """Retrieve max input token limit from the model profile."""
        try:
            profile = self.model.profile
        except AttributeError:
            return None
        if not isinstance(profile, Mapping):
            return None
        max_input_tokens = profile.get("max_input_tokens")
        if not isinstance(max_input_tokens, int):
            return None
        return max_input_tokens

    def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
        """Validate context configuration tuples."""
        kind, value = context
        if kind == "fraction":
            if not 0 < value <= 1:
                msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
                raise ValueError(msg)
        elif kind in {"tokens", "messages"}:
            if value <= 0:
                msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
                raise ValueError(msg)
        else:
            msg = f"Unsupported context size type {kind} for {parameter_name}."
            raise ValueError(msg)
        return context
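
    # Sketch of the validation behavior (illustrative values):
    #
    #     _validate_context_size(("fraction", 0.5), "trigger")  # returned unchanged
    #     _validate_context_size(("fraction", 1.5), "trigger")  # raises ValueError
    #     _validate_context_size(("messages", 0), "keep")       # raises ValueError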

    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
        """Wrap the generated summary in a `HumanMessage` for the rebuilt history."""
        return [
            HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
        ]

    def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
        """Ensure all messages have unique IDs for the add_messages reducer."""
        for msg in messages:
            if msg.id is None:
                msg.id = str(uuid.uuid4())

    def _partition_messages(
        self,
        conversation_messages: list[AnyMessage],
        cutoff_index: int,
    ) -> tuple[list[AnyMessage], list[AnyMessage]]:
        """Partition messages into those to summarize and those to preserve."""
        messages_to_summarize = conversation_messages[:cutoff_index]
        preserved_messages = conversation_messages[cutoff_index:]
        return messages_to_summarize, preserved_messages

    def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
        """Find safe cutoff point that preserves AI/Tool message pairs.

        Returns the index where messages can be safely cut without separating
        related AI and Tool messages. Returns `0` if no safe cutoff is found.

        This is aggressive with summarization - if the target cutoff lands in the
        middle of tool messages, we advance past all of them (summarizing more).
        """
        if len(messages) <= messages_to_keep:
            return 0
        target_cutoff = len(messages) - messages_to_keep
        return self._find_safe_cutoff_point(messages, target_cutoff)

    def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
        """Find a safe cutoff point that doesn't split AI/Tool message pairs.

        If the message at cutoff_index is a ToolMessage, advance until we find
        a non-ToolMessage. This ensures we never cut in the middle of parallel
        tool call responses.
        """
        while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
            cutoff_index += 1
        return cutoff_index
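
    # Worked example (hypothetical history): for
    # [Human, AI(tool_calls), Tool, Tool, AI] with a target cutoff of 2, the
    # message at index 2 is a ToolMessage, so the cutoff advances to index 4.
    # The AI message that issued the tool calls is summarized together with its
    # tool results instead of being separated from them.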

    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
        """Generate summary for the given messages."""
        if not messages_to_summarize:
            return "No previous conversation history."

        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
        if not trimmed_messages:
            return "Previous conversation was too long to summarize."

        try:
            response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
            return response.text.strip()
        except Exception as e:  # noqa: BLE001
            return f"Error generating summary: {e!s}"

    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
        """Generate summary for the given messages."""
        if not messages_to_summarize:
            return "No previous conversation history."

        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
        if not trimmed_messages:
            return "Previous conversation was too long to summarize."

        try:
            response = await self.model.ainvoke(
                self.summary_prompt.format(messages=trimmed_messages)
            )
            return response.text.strip()
        except Exception as e:  # noqa: BLE001
            return f"Error generating summary: {e!s}"

    def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
        """Trim messages to fit within summary generation limits."""
        try:
            if self.trim_tokens_to_summarize is None:
                return messages
            return cast(
                "list[AnyMessage]",
                trim_messages(
                    messages,
                    max_tokens=self.trim_tokens_to_summarize,
                    token_counter=self.token_counter,
                    start_on="human",
                    strategy="last",
                    allow_partial=True,
                    include_system=True,
                ),
            )
        except Exception:  # noqa: BLE001
            return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]