| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- from __future__ import annotations
- from unittest.mock import AsyncMock
- import pytest
- from finrep_algo_agent.config import Settings
- from finrep_algo_agent.schemas.rag import RagHit, RagRetrieveResponse
- from finrep_algo_agent.schemas.section import SectionRequest
- from finrep_algo_agent.skills.section_gen.section_gen import run_section
@pytest.mark.asyncio
async def test_section_uses_upstream_data_package_without_internal_rag_recall() -> None:
    """An upstream ``data_package`` must bypass the internal RAG recall.

    The request asks for recall (``rag_recall=True`` plus a ``rag_query``) but
    also ships a ``data_package``. ``run_section`` is expected to never await
    ``rag.retrieve``, still await the LLM, and still build the prompt from the
    request ``template``.
    """
    settings = Settings(stub_skills=False, llm_api_key="test-key")
    mock_llm = AsyncMock()
    # The return value has surrounding whitespace; the assertion below only
    # checks containment, so stripping (or not) by run_section is tolerated.
    mock_llm.chat_completion = AsyncMock(return_value=" 生成正文片段 ")
    mock_rag = AsyncMock()
    # Give retrieve a realistic payload so that a regression which does await
    # it would have recall text available to inject into the prompt.
    mock_rag.retrieve = AsyncMock(
        return_value=RagRetrieveResponse(
            hits=[
                RagHit(
                    chunk_id="c1",
                    text="召回片段A",
                    score=0.9,
                    doc_id="d1",
                )
            ],
            formatted_context="[RAG] 召回片段A",
        )
    )
    req = SectionRequest(
        knowledge_unit_id="ku-1",
        template_type="info",
        task_id="task-99",
        rag_recall=True,
        rag_query="融资主体",
        template={"paragraph_position": "定位", "paragraph_logic": "撰写逻辑"},
        data_package={"api": {"x": 1}},
    )

    resp = await run_section(req, settings=settings, llm=mock_llm, rag=mock_rag)

    assert "生成正文片段" in resp.generated_text
    mock_rag.retrieve.assert_not_awaited()
    # Guard before reading await_args: it is None when the mock was never
    # awaited, which would otherwise surface as an opaque TypeError.
    mock_llm.chat_completion.assert_awaited()
    # First positional argument of the (last) call is the messages list;
    # the assertions inspect the first message's content.
    messages = mock_llm.chat_completion.await_args.args[0]
    prompt = messages[0]["content"]
    assert "定位" in prompt
    # The mocked recall text only exists in mock_rag's payload, so it must
    # not appear in the prompt when recall is skipped.
    assert "召回片段A" not in prompt
    # NOTE(review): the test name says the upstream data package is *used*,
    # but nothing asserts its content reaches the prompt — confirm how
    # run_section renders data_package and assert on it.
@pytest.mark.asyncio
async def test_section_uses_template_fields_from_top_level_template() -> None:
    """Fields of the request's top-level ``template`` must reach the LLM prompt.

    ``paragraph_position``, ``paragraph_logic`` and ``notes`` supplied via
    ``SectionRequest.template`` are expected to appear in the first message's
    content sent to ``llm.chat_completion``.
    """
    settings = Settings(stub_skills=False, llm_api_key="test-key")
    mock_llm = AsyncMock()
    mock_llm.chat_completion = AsyncMock(return_value="模板注入正文")
    req = SectionRequest(
        knowledge_unit_id="ku-2",
        template_type="info",
        template={
            "paragraph_position": "来自知识单元模板-段落定位",
            "paragraph_logic": "来自知识单元模板-段落逻辑",
            # NOTE(review): "example" is supplied but never asserted below —
            # confirm whether run_section is expected to inject it.
            "example": "示例文本",
            "notes": "模板注意事项",
        },
    )

    resp = await run_section(req, settings=settings, llm=mock_llm, rag=AsyncMock())

    assert "模板注入正文" in resp.generated_text
    # Guard before reading await_args: it is None when the mock was never
    # awaited, which would otherwise surface as an opaque TypeError.
    mock_llm.chat_completion.assert_awaited()
    # First positional argument of the (last) call is the messages list.
    messages = mock_llm.chat_completion.await_args.args[0]
    prompt = messages[0]["content"]
    assert "来自知识单元模板-段落定位" in prompt
    assert "来自知识单元模板-段落逻辑" in prompt
    assert "模板注意事项" in prompt
|