pii.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. """PII detection and handling middleware for agents."""
  2. from __future__ import annotations
  3. from typing import TYPE_CHECKING, Any, Literal
  4. from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
  5. from langchain.agents.middleware._redaction import (
  6. PIIDetectionError,
  7. PIIMatch,
  8. RedactionRule,
  9. ResolvedRedactionRule,
  10. apply_strategy,
  11. detect_credit_card,
  12. detect_email,
  13. detect_ip,
  14. detect_mac_address,
  15. detect_url,
  16. )
  17. from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
  18. if TYPE_CHECKING:
  19. from collections.abc import Callable
  20. from langgraph.runtime import Runtime
  class PIIMiddleware(AgentMiddleware):
      """Detect and handle Personally Identifiable Information (PII) in conversations.

      This middleware detects common PII types and applies configurable strategies
      to handle them. It can detect emails, credit cards, IP addresses, MAC addresses,
      and URLs in user input, agent output, and tool results.

      Built-in PII types:

      - `email`: Email addresses
      - `credit_card`: Credit card numbers (validated with Luhn algorithm)
      - `ip`: IP addresses (validated with stdlib)
      - `mac_address`: MAC addresses
      - `url`: URLs (both `http`/`https` and bare URLs)

      Strategies:

      - `block`: Raise an exception when PII is detected
      - `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
      - `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
      - `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)

      Strategy Selection Guide:

      | Strategy | Preserves Identity? | Best For                                |
      | -------- | ------------------- | --------------------------------------- |
      | `block`  | N/A                 | Avoid PII completely                    |
      | `redact` | No                  | General compliance, log sanitization    |
      | `mask`   | No                  | Human readability, customer service UIs |
      | `hash`   | Yes (pseudonymous)  | Analytics, debugging                    |

      Only the textual parts of a message are scanned and rewritten:
      - plain-string content;
      - string items of list content;
      - `{"type": "text", "text": ...}` content blocks.

      Other content blocks are passed through unchanged. All other message fields
      (ids, metadata, tool calls, tool-message status, ...) are preserved.

      Example:
          ```python
          from langchain.agents.middleware import PIIMiddleware
          from langchain.agents import create_agent

          # Redact all emails in user input
          agent = create_agent(
              "openai:gpt-5",
              middleware=[
                  PIIMiddleware("email", strategy="redact"),
              ],
          )

          # Use different strategies for different PII types
          agent = create_agent(
              "openai:gpt-4o",
              middleware=[
                  PIIMiddleware("credit_card", strategy="mask"),
                  PIIMiddleware("url", strategy="redact"),
                  PIIMiddleware("ip", strategy="hash"),
              ],
          )

          # Custom PII type with regex
          agent = create_agent(
              "openai:gpt-5",
              middleware=[
                  PIIMiddleware("api_key", detector=r"sk-[a-zA-Z0-9]{32}", strategy="block"),
              ],
          )
          ```
      """

      def __init__(
          self,
          pii_type: Literal["email", "credit_card", "ip", "mac_address", "url"] | str,  # noqa: PYI051
          *,
          strategy: Literal["block", "redact", "mask", "hash"] = "redact",
          detector: Callable[[str], list[PIIMatch]] | str | None = None,
          apply_to_input: bool = True,
          apply_to_output: bool = False,
          apply_to_tool_results: bool = False,
      ) -> None:
          """Initialize the PII detection middleware.

          Args:
              pii_type: Type of PII to detect.

                  Can be a built-in type (`email`, `credit_card`, `ip`, `mac_address`,
                  `url`) or a custom type name.
              strategy: How to handle detected PII.

                  Options:

                  * `block`: Raise `PIIDetectionError` when PII is detected
                  * `redact`: Replace with `[REDACTED_TYPE]` placeholders
                  * `mask`: Partially mask PII (show last few characters)
                  * `hash`: Replace with deterministic hash (format: `<type_hash:digest>`)
              detector: Custom detector function or regex pattern.

                  * If `Callable`: Function that takes content string and returns
                      list of `PIIMatch` objects
                  * If `str`: Regex pattern to match PII
                  * If `None`: Uses built-in detector for the `pii_type`
              apply_to_input: Whether to check the most recent user message before
                  the model call.
              apply_to_output: Whether to check the most recent AI message after the
                  model call.
              apply_to_tool_results: Whether to check the tool result messages that
                  follow the most recent AI message, before the next model call.

          Raises:
              ValueError: If `pii_type` is not built-in and no detector is provided.
          """
          super().__init__()

          self.apply_to_input = apply_to_input
          self.apply_to_output = apply_to_output
          self.apply_to_tool_results = apply_to_tool_results

          # Validation and detector lookup live in `RedactionRule.resolve()`.
          self._resolved_rule: ResolvedRedactionRule = RedactionRule(
              pii_type=pii_type,
              strategy=strategy,
              detector=detector,
          ).resolve()
          self.pii_type = self._resolved_rule.pii_type
          self.strategy = self._resolved_rule.strategy
          self.detector = self._resolved_rule.detector

      @property
      def name(self) -> str:
          """Name of the middleware, qualified by the PII type it handles."""
          return f"{self.__class__.__name__}[{self.pii_type}]"

      def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
          """Apply the configured redaction rule to a plain string.

          Returns:
              The sanitized string and the detected matches. When nothing is
              detected, the input string is returned with an empty match list.
          """
          matches = self.detector(content)
          if not matches:
              return content, []

          sanitized = apply_strategy(content, matches, self.strategy)
          return sanitized, matches

      def _process_message_content(self, content: Any) -> tuple[Any, list[PIIMatch]]:
          """Apply the redaction rule to message content, preserving its structure.

          Plain-string content is processed directly. List content is processed item
          by item:
          - bare strings and `{"type": "text"}` blocks are sanitized;
          - every other item is kept as-is.

          Stringifying the whole list instead would replace structured content with
          its Python repr whenever PII is found.

          Returns:
              The (possibly) sanitized content and all detected matches. When nothing
              is detected, the original `content` object is returned unchanged.
          """
          if isinstance(content, str):
              return self._process_content(content)

          if not isinstance(content, list):
              # Unknown content shape: fall back to scanning its string form.
              return self._process_content(str(content))

          new_blocks: list[Any] = []
          all_matches: list[PIIMatch] = []
          for block in content:
              if isinstance(block, str):
                  new_text, matches = self._process_content(block)
                  new_blocks.append(new_text)
              elif (
                  isinstance(block, dict)
                  and block.get("type") == "text"
                  and isinstance(block.get("text"), str)
              ):
                  new_text, matches = self._process_content(block["text"])
                  # Copy rather than mutate: the original block is part of the state.
                  new_blocks.append({**block, "text": new_text} if matches else block)
              else:
                  # Non-text blocks are passed through untouched.
                  new_blocks.append(block)
                  continue
              all_matches.extend(matches)

          if not all_matches:
              return content, []
          return new_blocks, all_matches

      def _sanitize_message(self, message: AnyMessage) -> AnyMessage | None:
          """Return a sanitized copy of `message`, or `None` if it contains no PII.

          `model_copy` keeps every other field of the message and its concrete class.
          Rebuilding the message from a handful of constructor arguments would
          silently drop fields such as:
          - `additional_kwargs` and `response_metadata`;
          - AI usage metadata;
          - a `ToolMessage`'s `status` and `artifact`.

          Raises:
              PIIDetectionError: If PII is detected and strategy is `'block'`.
          """
          if not message.content:
              return None

          new_content, matches = self._process_message_content(message.content)
          if not matches:
              return None

          return message.model_copy(update={"content": new_content})

      @staticmethod
      def _last_index_of(messages: list[AnyMessage], message_type: type) -> int | None:
          """Return the index of the last message that is a `message_type`, or `None`."""
          for idx in range(len(messages) - 1, -1, -1):
              if isinstance(messages[idx], message_type):
                  return idx
          return None

      @hook_config(can_jump_to=["end"])
      def before_model(
          self,
          state: AgentState,
          runtime: Runtime,  # noqa: ARG002
      ) -> dict[str, Any] | None:
          """Check user input and tool results for PII before model invocation.

          Which messages are checked:
          - With `apply_to_input`, the most recent `HumanMessage`.
          - With `apply_to_tool_results`, every `ToolMessage` that follows the most
            recent `AIMessage`, i.e. presumably the latest round of tool results.

          Args:
              state: The current agent state.
              runtime: The langgraph runtime.

          Returns:
              Updated state with PII handled according to strategy, or `None` if no PII
              detected.

          Raises:
              PIIDetectionError: If PII is detected and strategy is `'block'`.
          """
          if not self.apply_to_input and not self.apply_to_tool_results:
              return None

          messages = state["messages"]
          if not messages:
              return None

          new_messages = list(messages)
          any_modified = False

          if self.apply_to_input:
              last_user_idx = self._last_index_of(messages, HumanMessage)
              if last_user_idx is not None:
                  updated = self._sanitize_message(messages[last_user_idx])
                  if updated is not None:
                      new_messages[last_user_idx] = updated
                      any_modified = True

          if self.apply_to_tool_results:
              last_ai_idx = self._last_index_of(messages, AIMessage)
              if last_ai_idx is not None:
                  for idx in range(last_ai_idx + 1, len(messages)):
                      msg = messages[idx]
                      if not isinstance(msg, ToolMessage):
                          continue
                      updated = self._sanitize_message(msg)
                      if updated is not None:
                          new_messages[idx] = updated
                          any_modified = True

          return {"messages": new_messages} if any_modified else None

      @hook_config(can_jump_to=["end"])
      async def abefore_model(
          self,
          state: AgentState,
          runtime: Runtime,
      ) -> dict[str, Any] | None:
          """Async check of user input and tool results for PII before model invocation.

          Delegates to the synchronous `before_model`.

          Args:
              state: The current agent state.
              runtime: The langgraph runtime.

          Returns:
              Updated state with PII handled according to strategy, or `None` if no PII
              detected.

          Raises:
              PIIDetectionError: If PII is detected and strategy is `'block'`.
          """
          return self.before_model(state, runtime)

      def after_model(
          self,
          state: AgentState,
          runtime: Runtime,  # noqa: ARG002
      ) -> dict[str, Any] | None:
          """Check the most recent AI message for PII after model invocation.

          Tool calls and all other fields of the AI message are preserved. Only its
          textual content is rewritten.

          Args:
              state: The current agent state.
              runtime: The langgraph runtime.

          Returns:
              Updated state with PII handled according to strategy, or `None` if no PII
              detected.

          Raises:
              PIIDetectionError: If PII is detected and strategy is `'block'`.
          """
          if not self.apply_to_output:
              return None

          messages = state["messages"]
          if not messages:
              return None

          last_ai_idx = self._last_index_of(messages, AIMessage)
          if last_ai_idx is None:
              return None

          updated = self._sanitize_message(messages[last_ai_idx])
          if updated is None:
              return None

          new_messages = list(messages)
          new_messages[last_ai_idx] = updated
          return {"messages": new_messages}

      async def aafter_model(
          self,
          state: AgentState,
          runtime: Runtime,
      ) -> dict[str, Any] | None:
          """Async check of the most recent AI message for PII after model invocation.

          Delegates to the synchronous `after_model`.

          Args:
              state: The current agent state.
              runtime: The langgraph runtime.

          Returns:
              Updated state with PII handled according to strategy, or `None` if no PII
              detected.

          Raises:
              PIIDetectionError: If PII is detected and strategy is `'block'`.
          """
          return self.after_model(state, runtime)
  # Public API of this module. The middleware is defined here; the exception and
  # the built-in detectors are re-exported from `_redaction` so that callers
  # can import them from this module.
  __all__ = [
      # Middleware and the error it raises under the `block` strategy.
      "PIIDetectionError",
      "PIIMiddleware",
      # Built-in detectors.
      "detect_credit_card",
      "detect_email",
      "detect_ip",
      "detect_mac_address",
      "detect_url",
  ]