# test_section_rag.py

from __future__ import annotations

from unittest.mock import AsyncMock

import pytest

from finrep_algo_agent.config import Settings
from finrep_algo_agent.schemas.rag import RagHit, RagRetrieveResponse
from finrep_algo_agent.schemas.section import SectionRequest
from finrep_algo_agent.skills.section_gen.section_gen import run_section


  8. @pytest.mark.asyncio
  9. async def test_section_rag_recall_merges_into_data_package() -> None:
  10. settings = Settings(stub_skills=False, llm_api_key="test-key")
  11. mock_llm = AsyncMock()
  12. mock_llm.chat_completion = AsyncMock(return_value=" 生成正文片段 ")
  13. mock_rag = AsyncMock()
  14. mock_rag.retrieve = AsyncMock(
  15. return_value=RagRetrieveResponse(
  16. hits=[
  17. RagHit(
  18. chunk_id="c1",
  19. text="召回片段A",
  20. score=0.9,
  21. doc_id="d1",
  22. )
  23. ],
  24. formatted_context="[RAG] 召回片段A",
  25. )
  26. )
  27. req = SectionRequest(
  28. knowledge_unit_id="ku-1",
  29. template_type="info",
  30. task_id="task-99",
  31. rag_recall=True,
  32. rag_query="融资主体",
  33. paragraph_position="定位",
  34. paragraph_logic="撰写逻辑",
  35. data_package={"api": {"x": 1}},
  36. )
  37. resp = await run_section(req, settings=settings, llm=mock_llm, rag=mock_rag)
  38. assert "生成正文片段" in resp.generated_text
  39. mock_rag.retrieve.assert_awaited_once()
  40. call_kw = mock_llm.chat_completion.await_args
  41. prompt = call_kw[0][0][0]["content"]
  42. assert "rag_recall" in prompt or "[RAG]" in prompt