| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- import collections.abc
- from collections.abc import Callable, Sequence
- from typing import Any, Generic
- from typing_extensions import NotRequired, Required, Self
- from langgraph._internal._constants import OVERWRITE
- from langgraph._internal._typing import MISSING
- from langgraph.channels.base import BaseChannel, Value
- from langgraph.errors import (
- EmptyChannelError,
- ErrorCode,
- InvalidUpdateError,
- create_error_message,
- )
- from langgraph.types import Overwrite
# Public API of this module: only the channel class is exported.
__all__ = (
    "BinaryOperatorAggregate",
)
- # Adapted from typing_extensions
- def _strip_extras(t): # type: ignore[no-untyped-def]
- """Strips Annotated, Required and NotRequired from a given type."""
- if hasattr(t, "__origin__"):
- return _strip_extras(t.__origin__)
- if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
- return _strip_extras(t.__args__[0])
- return t
def _get_overwrite(value: Any) -> tuple[bool, Any]:
    """Detect whether ``value`` requests an overwrite of the channel value.

    Two spellings are recognised: an ``Overwrite`` instance, or a plain dict
    whose one and only key is ``OVERWRITE``.

    Returns:
        ``(True, payload)`` for an overwrite request, where ``payload`` is the
        replacement value; ``(False, None)`` otherwise.
    """
    if isinstance(value, Overwrite):
        payload = value.value
    elif isinstance(value, dict) and value.keys() == {OVERWRITE}:
        payload = value[OVERWRITE]
    else:
        return False, None
    return True, payload
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the result of applying a binary operator to the current value and each new value.

    Each update received in a batch is folded into the stored value with
    ``operator(current, update)``. An update that is an ``Overwrite`` (or a
    dict whose only key is ``OVERWRITE``) replaces the stored value instead,
    and non-overwrite updates that follow it in the same batch are skipped.
    At most one overwrite is accepted per batch.

    ```python
    import operator

    total = BinaryOperatorAggregate(int, operator.add)
    ```
    """

    __slots__ = ("value", "operator")

    def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
        """Create the channel.

        Args:
            typ: Declared value type. If calling it (after unwrapping typing
                wrappers and abstract collections) with no arguments succeeds,
                the result seeds the channel; otherwise the channel starts empty.
            operator: Binary function combining the current value with an update.
        """
        super().__init__(typ)
        self.operator = operator
        # special forms from typing or collections.abc are not instantiable
        # so we need to replace them with their concrete counterparts
        typ = _strip_extras(typ)
        if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
            typ = list
        if typ in (collections.abc.Set, collections.abc.MutableSet):
            typ = set
        if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
            typ = dict
        try:
            self.value = typ()
        except Exception:
            # Not default-constructible (e.g. a class with required constructor
            # arguments): start empty and let the first update seed the value.
            self.value = MISSING

    def __eq__(self, value: object) -> bool:
        """Compare channels by operator identity.

        Lambdas are treated as equal to anything, since every lambda
        expression creates a distinct function object. Callables without a
        ``__name__`` (e.g. ``functools.partial`` objects) are compared by
        identity rather than raising ``AttributeError``.
        """
        if not isinstance(value, BinaryOperatorAggregate):
            return False
        if (
            getattr(value.operator, "__name__", None) == "<lambda>"
            or getattr(self.operator, "__name__", None) == "<lambda>"
        ):
            return True
        return value.operator is self.operator

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The stored value object is shared with the copy, not duplicated.
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        empty.value = self.value
        return empty

    def from_checkpoint(self, checkpoint: Value) -> Self:
        """Return a new channel restored from a ``checkpoint()`` value.

        A ``MISSING`` checkpoint keeps the freshly initialized default value.
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        if checkpoint is not MISSING:
            empty.value = checkpoint
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Fold a batch of updates into the stored value.

        Args:
            values: Updates received in this batch, applied in order.

        Returns:
            ``False`` if ``values`` is empty, ``True`` otherwise.

        Raises:
            InvalidUpdateError: If the batch contains more than one overwrite.
        """
        if not values:
            return False
        seen_overwrite: bool = False
        if self.value is MISSING:
            # Nothing to combine with yet: the first update seeds the channel.
            # It must still be unwrapped if it is an overwrite request, or the
            # wrapper object itself would be stored and later combined.
            first, values = values[0], values[1:]
            is_overwrite, overwrite_value = _get_overwrite(first)
            self.value = overwrite_value if is_overwrite else first
            seen_overwrite = is_overwrite
        for value in values:
            is_overwrite, overwrite_value = _get_overwrite(value)
            if is_overwrite:
                if seen_overwrite:
                    msg = create_error_message(
                        message="Can receive only one Overwrite value per super-step.",
                        error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
                    )
                    raise InvalidUpdateError(msg)
                self.value = overwrite_value
                seen_overwrite = True
                continue
            # Once an overwrite has been applied, regular updates are ignored.
            if not seen_overwrite:
                self.value = self.operator(self.value, value)
        return True

    def get(self) -> Value:
        """Return the stored value.

        Raises:
            EmptyChannelError: If the channel holds no value yet.
        """
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value

    def is_available(self) -> bool:
        """Return whether the channel currently holds a value."""
        return self.value is not MISSING

    def checkpoint(self) -> Value:
        """Return the stored value for checkpointing (may be ``MISSING``)."""
        return self.value
|