import collections.abc
from collections.abc import Callable, Sequence
from typing import Any, Generic

from typing_extensions import NotRequired, Required, Self

from langgraph._internal._constants import OVERWRITE
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import (
    EmptyChannelError,
    ErrorCode,
    InvalidUpdateError,
    create_error_message,
)
from langgraph.types import Overwrite

__all__ = ("BinaryOperatorAggregate",)


# Adapted from typing_extensions
def _strip_extras(t):  # type: ignore[no-untyped-def]
    """Strip `Annotated`, `Required` and `NotRequired` wrappers from a type.

    Any object exposing `__origin__` (parameterized generics such as
    `list[int]`, `Annotated[...]` aliases, ...) is reduced recursively to its
    origin. `Required[X]` / `NotRequired[X]` are reduced to `X` instead,
    because their origin is the special form itself, which is not a usable
    type.

    Args:
        t: The type annotation to unwrap.

    Returns:
        The innermost type once all wrappers have been removed, or `t`
        itself when it has no `__origin__`.
    """
    if not hasattr(t, "__origin__"):
        return t
    # This check must run before the generic `__origin__` unwrapping below:
    # `Required[X].__origin__` is `Required`, so unwrapping the origin first
    # would discard `X` and return the bare special form.
    if t.__origin__ in (Required, NotRequired):
        return _strip_extras(t.__args__[0])
    return _strip_extras(t.__origin__)


def _get_overwrite(value: Any) -> tuple[bool, Any]:
    """Inspect `value` and return `(is_overwrite, overwrite_value)`.

    A value counts as an overwrite when it is an `Overwrite` instance or a
    dict whose only key is `OVERWRITE`.

    Args:
        value: A single update written to the channel.

    Returns:
        `(True, payload)` when `value` is an overwrite request, otherwise
        `(False, None)`.
    """
    if isinstance(value, Overwrite):
        return True, value.value
    if isinstance(value, dict) and set(value.keys()) == {OVERWRITE}:
        return True, value[OVERWRITE]
    return False, None


class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the result of applying a binary operator to the current value and each new value.

    ```python
    import operator

    total = Channels.BinaryOperatorAggregate(int, operator.add)
    ```
    """

    __slots__ = ("value", "operator")

    def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
        """Create the channel and try to seed it with an empty `typ()` value.

        Args:
            typ: The declared type of the channel value.
            operator: Binary function folding each update into the current
                value as `operator(current, update)`.
        """
        super().__init__(typ)
        self.operator = operator
        # special forms from typing or collections.abc are not instantiable
        # so we need to replace them with their concrete counterparts
        typ = _strip_extras(typ)
        if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
            typ = list
        if typ in (collections.abc.Set, collections.abc.MutableSet):
            typ = set
        if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
            typ = dict
        try:
            self.value = typ()
        except Exception:
            # Deliberate best-effort: a type that cannot be built without
            # arguments leaves the channel empty until the first update.
            self.value = MISSING

    def __eq__(self, value: object) -> bool:
        """Compare two channels by their operator.

        Channels are equal when they share the same operator object. If
        either operator is a lambda, identity is not checked and the channels
        are treated as equal.
        """
        if not isinstance(value, BinaryOperatorAggregate):
            return False
        # `getattr` keeps callables without a `__name__` (e.g.
        # `functools.partial` objects) from raising AttributeError here; they
        # fall through to the identity comparison.
        operator_names = (
            getattr(value.operator, "__name__", None),
            getattr(self.operator, "__name__", None),
        )
        if "<lambda>" in operator_names:
            return True
        return value.operator is self.operator

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The copy references the same value object as this channel; the value
        itself is not duplicated.
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        empty.value = self.value
        return empty

    def from_checkpoint(self, checkpoint: Value) -> Self:
        """Return a new channel restored from `checkpoint`.

        When `checkpoint` is `MISSING`, the new channel keeps whatever initial
        value its constructor produced.
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        if checkpoint is not MISSING:
            empty.value = checkpoint
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Fold `values` into the channel, honoring at most one overwrite.

        Plain values are combined with the current value via the operator, in
        order. An overwrite value replaces the current value outright, and
        every plain value that follows it in `values` is ignored.

        Args:
            values: The updates to apply.

        Returns:
            `False` when `values` is empty, `True` otherwise.

        Raises:
            InvalidUpdateError: If `values` contains more than one overwrite.
        """
        if not values:
            return False
        seen_overwrite: bool = False
        if self.value is MISSING:
            # An empty channel adopts the first write directly. It must still
            # be checked for an overwrite, otherwise the wrapper object itself
            # would be stored and a second overwrite would go undetected.
            is_overwrite, overwrite_value = _get_overwrite(values[0])
            if is_overwrite:
                self.value = overwrite_value
                seen_overwrite = True
            else:
                self.value = values[0]
            values = values[1:]
        for value in values:
            is_overwrite, overwrite_value = _get_overwrite(value)
            if is_overwrite:
                if seen_overwrite:
                    msg = create_error_message(
                        message="Can receive only one Overwrite value per super-step.",
                        error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
                    )
                    raise InvalidUpdateError(msg)
                self.value = overwrite_value
                seen_overwrite = True
                continue
            if not seen_overwrite:
                self.value = self.operator(self.value, value)
        return True

    def get(self) -> Value:
        """Return the current value.

        Raises:
            EmptyChannelError: If the channel holds no value yet.
        """
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value

    def is_available(self) -> bool:
        """Return whether the channel currently holds a value."""
        return self.value is not MISSING

    def checkpoint(self) -> Value:
        """Return the current value, which is `MISSING` when the channel is empty."""
        return self.value