| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- from __future__ import annotations
- from collections.abc import Mapping
- from datetime import datetime, timezone
- from langgraph.checkpoint.base import Checkpoint
- from langgraph.checkpoint.base.id import uuid6
- from langgraph._internal._typing import MISSING
- from langgraph.channels.base import BaseChannel
- from langgraph.managed.base import ManagedValueMapping, ManagedValueSpec
- LATEST_VERSION = 4  # written to the `v` field of every checkpoint built by empty_checkpoint / create_checkpoint
def empty_checkpoint() -> Checkpoint:
    """Build a brand-new checkpoint that holds no channel data.

    Returns:
        A ``Checkpoint`` stamped with ``LATEST_VERSION``, a freshly generated
        ``uuid6`` id (requested with ``clock_seq=-2``), the current UTC time
        as an ISO-8601 string, and empty ``channel_values``,
        ``channel_versions`` and ``versions_seen`` mappings.
    """
    checkpoint_id = str(uuid6(clock_seq=-2))
    created_at = datetime.now(timezone.utc).isoformat()
    return Checkpoint(
        v=LATEST_VERSION,
        id=checkpoint_id,
        ts=created_at,
        channel_values={},
        channel_versions={},
        versions_seen={},
    )
def create_checkpoint(
    checkpoint: Checkpoint,
    channels: Mapping[str, BaseChannel] | None,
    step: int,
    *,
    id: str | None = None,
    updated_channels: set[str] | None = None,
) -> Checkpoint:
    """Derive a new checkpoint from a previous one and the live channels.

    Args:
        checkpoint: The checkpoint to build on. Its ``channel_versions`` and
            ``versions_seen`` objects are passed through to the result by
            reference (they are not copied here).
        channels: Live channels to snapshot. When ``None``, the previous
            checkpoint's ``channel_values`` object is reused as-is.
        step: Passed as ``clock_seq`` to ``uuid6`` when an id is generated.
        id: Explicit checkpoint id. Any falsy value (``None`` or ``""``)
            causes a new ``uuid6``-based id to be generated instead.
        updated_channels: Names of channels updated in this step; stored as a
            sorted list, or ``None`` when not provided.

    Returns:
        A ``Checkpoint`` stamped with ``LATEST_VERSION`` and the current UTC
        time in ISO-8601 form.
    """
    timestamp = datetime.now(timezone.utc).isoformat()

    if channels is not None:
        snapshot = {}
        for name in channels:
            # Only channels that already have a recorded version are saved.
            if name in checkpoint["channel_versions"]:
                saved = channels[name].checkpoint()
                # MISSING means the channel has nothing to persist.
                if saved is not MISSING:
                    snapshot[name] = saved
    else:
        snapshot = checkpoint["channel_values"]

    checkpoint_id = id or str(uuid6(clock_seq=step))
    channel_versions = checkpoint["channel_versions"]
    versions_seen = checkpoint["versions_seen"]
    if updated_channels is None:
        updated = None
    else:
        updated = sorted(updated_channels)

    return Checkpoint(
        v=LATEST_VERSION,
        ts=timestamp,
        id=checkpoint_id,
        channel_values=snapshot,
        channel_versions=channel_versions,
        versions_seen=versions_seen,
        updated_channels=updated,
    )
def channels_from_checkpoint(
    specs: Mapping[str, BaseChannel | ManagedValueSpec],
    checkpoint: Checkpoint,
) -> tuple[Mapping[str, BaseChannel], ManagedValueMapping]:
    """Split ``specs`` into restored channels and managed-value specs.

    Args:
        specs: Mapping of name to either a ``BaseChannel`` or a managed-value
            spec. Anything that is not a ``BaseChannel`` instance is treated
            as a managed-value spec.
        checkpoint: Checkpoint whose ``channel_values`` supply the saved state.

    Returns:
        A 2-tuple ``(channels, managed)``. ``channels`` maps each channel name
        to the result of ``spec.from_checkpoint(saved_value)``, where the
        saved value is ``MISSING`` if the checkpoint has no entry for that
        name. ``managed`` maps the remaining names to their specs unchanged.
    """
    managed: dict[str, ManagedValueSpec] = {
        name: spec
        for name, spec in specs.items()
        if not isinstance(spec, BaseChannel)
    }

    restored: dict[str, BaseChannel] = {}
    for name, spec in specs.items():
        if not isinstance(spec, BaseChannel):
            continue
        saved = checkpoint["channel_values"].get(name, MISSING)
        restored[name] = spec.from_checkpoint(saved)

    return restored, managed
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
    """Return a copy of ``checkpoint`` with its mapping containers duplicated.

    ``channel_values`` and ``channel_versions`` are copied one level deep via
    ``.copy()``; ``versions_seen`` is copied two levels deep (the outer mapping
    and each inner mapping). The stored values themselves are not copied, and
    ``updated_channels`` (``None`` when absent) is carried over by reference.
    """
    version = checkpoint["v"]
    timestamp = checkpoint["ts"]
    checkpoint_id = checkpoint["id"]
    values = checkpoint["channel_values"].copy()
    versions = checkpoint["channel_versions"].copy()

    seen = {}
    for key, per_key_versions in checkpoint["versions_seen"].items():
        seen[key] = per_key_versions.copy()

    updated = checkpoint.get("updated_channels")

    return Checkpoint(
        v=version,
        ts=timestamp,
        id=checkpoint_id,
        channel_values=values,
        channel_versions=versions,
        versions_seen=seen,
        updated_channels=updated,
    )
|