_checkpoint.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from __future__ import annotations
  2. from collections.abc import Mapping
  3. from datetime import datetime, timezone
  4. from langgraph.checkpoint.base import Checkpoint
  5. from langgraph.checkpoint.base.id import uuid6
  6. from langgraph._internal._typing import MISSING
  7. from langgraph.channels.base import BaseChannel
  8. from langgraph.managed.base import ManagedValueMapping, ManagedValueSpec
  9. LATEST_VERSION = 4
  10. def empty_checkpoint() -> Checkpoint:
  11. return Checkpoint(
  12. v=LATEST_VERSION,
  13. id=str(uuid6(clock_seq=-2)),
  14. ts=datetime.now(timezone.utc).isoformat(),
  15. channel_values={},
  16. channel_versions={},
  17. versions_seen={},
  18. )
  19. def create_checkpoint(
  20. checkpoint: Checkpoint,
  21. channels: Mapping[str, BaseChannel] | None,
  22. step: int,
  23. *,
  24. id: str | None = None,
  25. updated_channels: set[str] | None = None,
  26. ) -> Checkpoint:
  27. """Create a checkpoint for the given channels."""
  28. ts = datetime.now(timezone.utc).isoformat()
  29. if channels is None:
  30. values = checkpoint["channel_values"]
  31. else:
  32. values = {}
  33. for k in channels:
  34. if k not in checkpoint["channel_versions"]:
  35. continue
  36. v = channels[k].checkpoint()
  37. if v is not MISSING:
  38. values[k] = v
  39. return Checkpoint(
  40. v=LATEST_VERSION,
  41. ts=ts,
  42. id=id or str(uuid6(clock_seq=step)),
  43. channel_values=values,
  44. channel_versions=checkpoint["channel_versions"],
  45. versions_seen=checkpoint["versions_seen"],
  46. updated_channels=None if updated_channels is None else sorted(updated_channels),
  47. )
  48. def channels_from_checkpoint(
  49. specs: Mapping[str, BaseChannel | ManagedValueSpec],
  50. checkpoint: Checkpoint,
  51. ) -> tuple[Mapping[str, BaseChannel], ManagedValueMapping]:
  52. """Get channels from a checkpoint."""
  53. channel_specs: dict[str, BaseChannel] = {}
  54. managed_specs: dict[str, ManagedValueSpec] = {}
  55. for k, v in specs.items():
  56. if isinstance(v, BaseChannel):
  57. channel_specs[k] = v
  58. else:
  59. managed_specs[k] = v
  60. return (
  61. {
  62. k: v.from_checkpoint(checkpoint["channel_values"].get(k, MISSING))
  63. for k, v in channel_specs.items()
  64. },
  65. managed_specs,
  66. )
  67. def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
  68. return Checkpoint(
  69. v=checkpoint["v"],
  70. ts=checkpoint["ts"],
  71. id=checkpoint["id"],
  72. channel_values=checkpoint["channel_values"].copy(),
  73. channel_versions=checkpoint["channel_versions"].copy(),
  74. versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
  75. updated_channels=checkpoint.get("updated_channels", None),
  76. )