_validate.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from __future__ import annotations
  2. from collections.abc import Mapping, Sequence
  3. from typing import Any
  4. from langgraph._internal._constants import RESERVED
  5. from langgraph.channels.base import BaseChannel
  6. from langgraph.managed.base import ManagedValueMapping
  7. from langgraph.pregel._read import PregelNode
  8. from langgraph.types import All
  9. def validate_graph(
  10. nodes: Mapping[str, PregelNode],
  11. channels: dict[str, BaseChannel],
  12. managed: ManagedValueMapping,
  13. input_channels: str | Sequence[str],
  14. output_channels: str | Sequence[str],
  15. stream_channels: str | Sequence[str] | None,
  16. interrupt_after_nodes: All | Sequence[str],
  17. interrupt_before_nodes: All | Sequence[str],
  18. ) -> None:
  19. for chan in channels:
  20. if chan in RESERVED:
  21. raise ValueError(f"Channel name '{chan}' is reserved")
  22. for name in managed:
  23. if name in RESERVED:
  24. raise ValueError(f"Managed name '{name}' is reserved")
  25. subscribed_channels = set[str]()
  26. for name, node in nodes.items():
  27. if name in RESERVED:
  28. raise ValueError(f"Node name '{name}' is reserved")
  29. if isinstance(node, PregelNode):
  30. subscribed_channels.update(node.triggers)
  31. if isinstance(node.channels, str):
  32. if node.channels not in channels:
  33. raise ValueError(
  34. f"Node {name} reads channel '{node.channels}' "
  35. f"not in known channels: '{repr(sorted(channels))[:100]}'"
  36. )
  37. else:
  38. for chan in node.channels:
  39. if chan not in channels and chan not in managed:
  40. raise ValueError(
  41. f"Node {name} reads channel '{chan}' "
  42. f"not in known channels: '{repr(sorted(channels))[:100]}'"
  43. )
  44. else:
  45. raise TypeError(
  46. f"Invalid node type {type(node)}, expected PregelNode or NodeBuilder"
  47. )
  48. for chan in subscribed_channels:
  49. if chan not in channels:
  50. raise ValueError(
  51. f"Subscribed channel '{chan}' not "
  52. f"in known channels: '{repr(sorted(channels))[:100]}'"
  53. )
  54. if isinstance(input_channels, str):
  55. if input_channels not in channels:
  56. raise ValueError(
  57. f"Input channel '{input_channels}' not "
  58. f"in known channels: '{repr(sorted(channels))[:100]}'"
  59. )
  60. if input_channels not in subscribed_channels:
  61. raise ValueError(
  62. f"Input channel {input_channels} is not subscribed to by any node"
  63. )
  64. else:
  65. for chan in input_channels:
  66. if chan not in channels:
  67. raise ValueError(
  68. f"Input channel '{chan}' not in '{repr(sorted(channels))[:100]}'"
  69. )
  70. if all(chan not in subscribed_channels for chan in input_channels):
  71. raise ValueError(
  72. f"None of the input channels {input_channels} "
  73. f"are subscribed to by any node"
  74. )
  75. all_output_channels = set[str]()
  76. if isinstance(output_channels, str):
  77. all_output_channels.add(output_channels)
  78. else:
  79. all_output_channels.update(output_channels)
  80. if isinstance(stream_channels, str):
  81. all_output_channels.add(stream_channels)
  82. elif stream_channels is not None:
  83. all_output_channels.update(stream_channels)
  84. for chan in all_output_channels:
  85. if chan not in channels:
  86. raise ValueError(
  87. f"Output channel '{chan}' not "
  88. f"in known channels: '{repr(sorted(channels))[:100]}'"
  89. )
  90. if interrupt_after_nodes != "*":
  91. for n in interrupt_after_nodes:
  92. if n not in nodes:
  93. raise ValueError(f"Node {n} not in nodes")
  94. if interrupt_before_nodes != "*":
  95. for n in interrupt_before_nodes:
  96. if n not in nodes:
  97. raise ValueError(f"Node {n} not in nodes")
  98. def validate_keys(
  99. keys: str | Sequence[str] | None,
  100. channels: Mapping[str, Any],
  101. ) -> None:
  102. if isinstance(keys, str):
  103. if keys not in channels:
  104. raise ValueError(f"Key {keys} not in channels")
  105. elif keys is not None:
  106. for chan in keys:
  107. if chan not in channels:
  108. raise ValueError(f"Key {chan} not in channels")