test_section_rag.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import annotations
  2. from unittest.mock import AsyncMock
  3. import pytest
  4. from finrep_algo_agent.config import Settings
  5. from finrep_algo_agent.schemas.rag import RagHit, RagRetrieveResponse
  6. from finrep_algo_agent.schemas.section import SectionRequest
  7. from finrep_algo_agent.skills.section_gen.section_gen import run_section
  8. @pytest.mark.asyncio
  9. async def test_section_uses_upstream_data_package_without_internal_rag_recall() -> None:
  10. settings = Settings(stub_skills=False, llm_api_key="test-key")
  11. mock_llm = AsyncMock()
  12. mock_llm.chat_completion = AsyncMock(return_value=" 生成正文片段 ")
  13. mock_rag = AsyncMock()
  14. mock_rag.retrieve = AsyncMock(
  15. return_value=RagRetrieveResponse(
  16. hits=[
  17. RagHit(
  18. chunk_id="c1",
  19. text="召回片段A",
  20. score=0.9,
  21. doc_id="d1",
  22. )
  23. ],
  24. formatted_context="[RAG] 召回片段A",
  25. )
  26. )
  27. req = SectionRequest(
  28. knowledge_unit_id="ku-1",
  29. template_type="info",
  30. task_id="task-99",
  31. rag_recall=True,
  32. rag_query="融资主体",
  33. template={"paragraph_position": "定位", "paragraph_logic": "撰写逻辑"},
  34. data_package={"api": {"x": 1}},
  35. )
  36. resp = await run_section(req, settings=settings, llm=mock_llm, rag=mock_rag)
  37. assert "生成正文片段" in resp.generated_text
  38. mock_rag.retrieve.assert_not_awaited()
  39. call_kw = mock_llm.chat_completion.await_args
  40. prompt = call_kw[0][0][0]["content"]
  41. assert "定位" in prompt
  42. @pytest.mark.asyncio
  43. async def test_section_uses_template_fields_from_top_level_template() -> None:
  44. settings = Settings(stub_skills=False, llm_api_key="test-key")
  45. mock_llm = AsyncMock()
  46. mock_llm.chat_completion = AsyncMock(return_value="模板注入正文")
  47. req = SectionRequest(
  48. knowledge_unit_id="ku-2",
  49. template_type="info",
  50. template={
  51. "paragraph_position": "来自知识单元模板-段落定位",
  52. "paragraph_logic": "来自知识单元模板-段落逻辑",
  53. "example": "示例文本",
  54. "notes": "模板注意事项",
  55. },
  56. )
  57. resp = await run_section(req, settings=settings, llm=mock_llm, rag=AsyncMock())
  58. assert "模板注入正文" in resp.generated_text
  59. prompt = mock_llm.chat_completion.await_args[0][0][0]["content"]
  60. assert "来自知识单元模板-段落定位" in prompt
  61. assert "来自知识单元模板-段落逻辑" in prompt
  62. assert "模板注意事项" in prompt