binop.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import collections.abc
  2. from collections.abc import Callable, Sequence
  3. from typing import Any, Generic
  4. from typing_extensions import NotRequired, Required, Self
  5. from langgraph._internal._constants import OVERWRITE
  6. from langgraph._internal._typing import MISSING
  7. from langgraph.channels.base import BaseChannel, Value
  8. from langgraph.errors import (
  9. EmptyChannelError,
  10. ErrorCode,
  11. InvalidUpdateError,
  12. create_error_message,
  13. )
  14. from langgraph.types import Overwrite
  15. __all__ = ("BinaryOperatorAggregate",)
  16. # Adapted from typing_extensions
  17. def _strip_extras(t): # type: ignore[no-untyped-def]
  18. """Strips Annotated, Required and NotRequired from a given type."""
  19. if hasattr(t, "__origin__"):
  20. return _strip_extras(t.__origin__)
  21. if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
  22. return _strip_extras(t.__args__[0])
  23. return t
  24. def _get_overwrite(value: Any) -> tuple[bool, Any]:
  25. """Inspects the given value and returns (is_overwrite, overwrite_value)."""
  26. if isinstance(value, Overwrite):
  27. return True, value.value
  28. if isinstance(value, dict) and set(value.keys()) == {OVERWRITE}:
  29. return True, value[OVERWRITE]
  30. return False, None
  31. class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
  32. """Stores the result of applying a binary operator to the current value and each new value.
  33. ```python
  34. import operator
  35. total = Channels.BinaryOperatorAggregate(int, operator.add)
  36. ```
  37. """
  38. __slots__ = ("value", "operator")
  39. def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
  40. super().__init__(typ)
  41. self.operator = operator
  42. # special forms from typing or collections.abc are not instantiable
  43. # so we need to replace them with their concrete counterparts
  44. typ = _strip_extras(typ)
  45. if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
  46. typ = list
  47. if typ in (collections.abc.Set, collections.abc.MutableSet):
  48. typ = set
  49. if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
  50. typ = dict
  51. try:
  52. self.value = typ()
  53. except Exception:
  54. self.value = MISSING
  55. def __eq__(self, value: object) -> bool:
  56. return isinstance(value, BinaryOperatorAggregate) and (
  57. value.operator is self.operator
  58. if value.operator.__name__ != "<lambda>"
  59. and self.operator.__name__ != "<lambda>"
  60. else True
  61. )
  62. @property
  63. def ValueType(self) -> type[Value]:
  64. """The type of the value stored in the channel."""
  65. return self.typ
  66. @property
  67. def UpdateType(self) -> type[Value]:
  68. """The type of the update received by the channel."""
  69. return self.typ
  70. def copy(self) -> Self:
  71. """Return a copy of the channel."""
  72. empty = self.__class__(self.typ, self.operator)
  73. empty.key = self.key
  74. empty.value = self.value
  75. return empty
  76. def from_checkpoint(self, checkpoint: Value) -> Self:
  77. empty = self.__class__(self.typ, self.operator)
  78. empty.key = self.key
  79. if checkpoint is not MISSING:
  80. empty.value = checkpoint
  81. return empty
  82. def update(self, values: Sequence[Value]) -> bool:
  83. if not values:
  84. return False
  85. if self.value is MISSING:
  86. self.value = values[0]
  87. values = values[1:]
  88. seen_overwrite: bool = False
  89. for value in values:
  90. is_overwrite, overwrite_value = _get_overwrite(value)
  91. if is_overwrite:
  92. if seen_overwrite:
  93. msg = create_error_message(
  94. message="Can receive only one Overwrite value per super-step.",
  95. error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
  96. )
  97. raise InvalidUpdateError(msg)
  98. self.value = overwrite_value
  99. seen_overwrite = True
  100. continue
  101. if not seen_overwrite:
  102. self.value = self.operator(self.value, value)
  103. return True
  104. def get(self) -> Value:
  105. if self.value is MISSING:
  106. raise EmptyChannelError()
  107. return self.value
  108. def is_available(self) -> bool:
  109. return self.value is not MISSING
  110. def checkpoint(self) -> Value:
  111. return self.value