anonymizer.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import re # noqa
  2. import inspect
  3. from abc import abstractmethod
  4. from collections import defaultdict
  5. from typing import Any, Callable, Optional, TypedDict, Union
  6. class _ExtractOptions(TypedDict):
  7. max_depth: Optional[int]
  8. """
  9. Maximum depth to traverse to to extract string nodes
  10. """
  11. class StringNode(TypedDict):
  12. """String node extracted from the data."""
  13. value: str
  14. """String value."""
  15. path: list[Union[str, int]]
  16. """Path to the string node in the data."""
  17. def _extract_string_nodes(data: Any, options: _ExtractOptions) -> list[StringNode]:
  18. max_depth = options.get("max_depth") or 10
  19. queue: list[tuple[Any, int, list[Union[str, int]]]] = [(data, 0, [])]
  20. result: list[StringNode] = []
  21. while queue:
  22. task = queue.pop(0)
  23. if task is None:
  24. continue
  25. value, depth, path = task
  26. if isinstance(value, (dict, defaultdict)):
  27. if depth >= max_depth:
  28. continue
  29. for key, nested_value in value.items():
  30. queue.append((nested_value, depth + 1, path + [key]))
  31. elif isinstance(value, list):
  32. if depth >= max_depth:
  33. continue
  34. for i, item in enumerate(value):
  35. queue.append((item, depth + 1, path + [i]))
  36. elif isinstance(value, str):
  37. result.append(StringNode(value=value, path=path))
  38. return result
  39. class StringNodeProcessor:
  40. """Processes a list of string nodes for masking."""
  41. @abstractmethod
  42. def mask_nodes(self, nodes: list[StringNode]) -> list[StringNode]:
  43. """Accept and return a list of string nodes to be masked."""
  44. class ReplacerOptions(TypedDict):
  45. """Configuration options for replacing sensitive data."""
  46. max_depth: Optional[int]
  47. """Maximum depth to traverse to to extract string nodes."""
  48. deep_clone: Optional[bool]
  49. """Deep clone the data before replacing."""
  50. class StringNodeRule(TypedDict):
  51. """Declarative rule used for replacing sensitive data."""
  52. pattern: re.Pattern
  53. """Regex pattern to match."""
  54. replace: Optional[str]
  55. """Replacement value. Defaults to `[redacted]` if not specified."""
  56. class RuleNodeProcessor(StringNodeProcessor):
  57. """String node processor that uses a list of rules to replace sensitive data."""
  58. rules: list[StringNodeRule]
  59. """List of rules to apply for replacing sensitive data.
  60. Each rule is a StringNodeRule, which contains a regex pattern to match
  61. and an optional replacement string.
  62. """
  63. def __init__(self, rules: list[StringNodeRule]):
  64. """Initialize the processor with a list of rules."""
  65. self.rules = [
  66. {
  67. "pattern": (
  68. rule["pattern"]
  69. if isinstance(rule["pattern"], re.Pattern)
  70. else re.compile(rule["pattern"])
  71. ),
  72. "replace": (
  73. rule["replace"]
  74. if isinstance(rule.get("replace"), str)
  75. else "[redacted]"
  76. ),
  77. }
  78. for rule in rules
  79. ]
  80. def mask_nodes(self, nodes: list[StringNode]) -> list[StringNode]:
  81. """Mask nodes using the rules."""
  82. result = []
  83. for item in nodes:
  84. new_value = item["value"]
  85. for rule in self.rules:
  86. new_value = rule["pattern"].sub(rule["replace"], new_value)
  87. if new_value != item["value"]:
  88. result.append(StringNode(value=new_value, path=item["path"]))
  89. return result
  90. class CallableNodeProcessor(StringNodeProcessor):
  91. """String node processor that uses a callable function to replace sensitive data."""
  92. func: Union[Callable[[str], str], Callable[[str, list[Union[str, int]]], str]]
  93. """The callable function used to replace sensitive data.
  94. It can be either a function that takes a single string argument and returns a string,
  95. or a function that takes a string and a list of path elements (strings or integers)
  96. and returns a string."""
  97. accepts_path: bool
  98. """Indicates whether the callable function accepts a path argument.
  99. If True, the function expects two arguments: the string to be processed and the path to that string.
  100. If False, the function expects only the string to be processed."""
  101. def __init__(
  102. self,
  103. func: Union[Callable[[str], str], Callable[[str, list[Union[str, int]]], str]],
  104. ):
  105. """Initialize the processor with a callable function."""
  106. self.func = func
  107. self.accepts_path = len(inspect.signature(func).parameters) == 2
  108. def mask_nodes(self, nodes: list[StringNode]) -> list[StringNode]:
  109. """Mask nodes using the callable function."""
  110. retval: list[StringNode] = []
  111. for node in nodes:
  112. candidate = (
  113. self.func(node["value"], node["path"]) # type: ignore[call-arg]
  114. if self.accepts_path
  115. else self.func(node["value"]) # type: ignore[call-arg]
  116. )
  117. if candidate != node["value"]:
  118. retval.append(StringNode(value=candidate, path=node["path"]))
  119. return retval
  120. ReplacerType = Union[
  121. Callable[[str, list[Union[str, int]]], str],
  122. list[StringNodeRule],
  123. StringNodeProcessor,
  124. ]
  125. def _get_node_processor(replacer: ReplacerType) -> StringNodeProcessor:
  126. if isinstance(replacer, list):
  127. return RuleNodeProcessor(rules=replacer)
  128. elif callable(replacer):
  129. return CallableNodeProcessor(func=replacer)
  130. else:
  131. return replacer
  132. def create_anonymizer(
  133. replacer: ReplacerType,
  134. *,
  135. max_depth: Optional[int] = None,
  136. ) -> Callable[[Any], Any]:
  137. """Create an anonymizer function."""
  138. processor = _get_node_processor(replacer)
  139. def anonymizer(data: Any) -> Any:
  140. nodes = _extract_string_nodes(data, {"max_depth": max_depth or 10})
  141. mutate_value = data
  142. to_update = processor.mask_nodes(nodes)
  143. for node in to_update:
  144. if not node["path"]:
  145. mutate_value = node["value"]
  146. else:
  147. temp = mutate_value
  148. for part in node["path"][:-1]:
  149. temp = temp[part]
  150. last_part = node["path"][-1]
  151. temp[last_part] = node["value"]
  152. return mutate_value
  153. return anonymizer