| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- from __future__ import annotations
- from unittest.mock import AsyncMock
- import pytest
- from finrep_algo_agent.config import Settings
- from finrep_algo_agent.schemas.rag import RagHit, RagRetrieveResponse
- from finrep_algo_agent.schemas.section import SectionRequest
- from finrep_algo_agent.skills.section_gen.section_gen import run_section
@pytest.mark.asyncio
async def test_section_rag_recall_merges_into_data_package() -> None:
    """RAG recall output should reach the LLM prompt when ``rag_recall`` is on.

    Both the LLM client and the RAG client are ``AsyncMock`` stubs. The test
    checks that:

    * ``run_section`` returns text containing the stubbed LLM output,
    * the RAG client's ``retrieve`` coroutine is awaited exactly once, and
    * the first message sent to ``chat_completion`` mentions either the
      ``rag_recall`` key or the ``[RAG]`` formatted context.
    """
    settings = Settings(stub_skills=False, llm_api_key="test-key")

    # LLM stub. The output is padded with spaces; because the assertion below
    # uses ``in``, it passes whether or not run_section strips the text.
    mock_llm = AsyncMock()
    mock_llm.chat_completion = AsyncMock(return_value=" 生成正文片段 ")

    # RAG stub: a single hit plus a pre-formatted context string.
    mock_rag = AsyncMock()
    mock_rag.retrieve = AsyncMock(
        return_value=RagRetrieveResponse(
            hits=[
                RagHit(
                    chunk_id="c1",
                    text="召回片段A",
                    score=0.9,
                    doc_id="d1",
                )
            ],
            formatted_context="[RAG] 召回片段A",
        )
    )

    req = SectionRequest(
        knowledge_unit_id="ku-1",
        template_type="info",
        task_id="task-99",
        rag_recall=True,
        rag_query="融资主体",
        paragraph_position="定位",
        paragraph_logic="撰写逻辑",
        data_package={"api": {"x": 1}},
    )

    resp = await run_section(req, settings=settings, llm=mock_llm, rag=mock_rag)

    assert "生成正文片段" in resp.generated_text
    mock_rag.retrieve.assert_awaited_once()

    # Fail with a clear mock assertion, rather than a TypeError on ``None``,
    # if the LLM was never awaited.
    mock_llm.chat_completion.assert_awaited()
    llm_call = mock_llm.chat_completion.await_args  # most recent await
    # Assumes the message list is the first positional argument (as the
    # original indexing did); adjust if run_section passes it by keyword.
    messages = llm_call.args[0]
    prompt = messages[0]["content"]
    assert "rag_recall" in prompt or "[RAG]" in prompt
|