test_rag.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from __future__ import annotations
  2. from unittest.mock import AsyncMock
  3. import pytest
  4. from finrep_algo_agent.config import Settings
  5. from finrep_algo_agent.rag.ingestion import chunk_text, extract_text_from_upload
  6. from finrep_algo_agent.rag.vectorstore import InMemoryRagStore
  7. from finrep_algo_agent.schemas.rag import RagDocumentIn
  8. from finrep_algo_agent.skills.rag_retrieve import RagService
  9. def test_extract_plain_text_utf8() -> None:
  10. ex = extract_text_from_upload(filename="a.txt", data="融资说明\n第二行".encode())
  11. assert "融资说明" in ex.text
  12. assert not ex.warning
  13. def test_chunk_text_splits_and_non_empty() -> None:
  14. long = "第一段。\n\n" + "字" * 500 + "\n\n尾段"
  15. chunks = chunk_text(long, chunk_size=120, overlap=20)
  16. assert len(chunks) >= 2
  17. assert all(c for c in chunks)
  18. @pytest.mark.asyncio
  19. async def test_rag_service_ingest_retrieve() -> None:
  20. settings = Settings(
  21. rag_chunk_size=200,
  22. rag_chunk_overlap=40,
  23. rag_default_top_k=3,
  24. rag_embedding_batch_size=8,
  25. )
  26. store = InMemoryRagStore()
  27. async def fake_embeddings(texts: list[str]) -> list[list[float]]:
  28. return [[float(i % 5), float(len(t) % 3), 0.0, 1.0] for i, t in enumerate(texts)]
  29. async def fake_embedding(q: str) -> list[float]:
  30. return [1.0, 0.0, 0.0, 0.0]
  31. mock_llm = AsyncMock()
  32. mock_llm.embeddings = AsyncMock(side_effect=fake_embeddings)
  33. mock_llm.embedding = AsyncMock(side_effect=fake_embedding)
  34. svc = RagService(settings=settings, llm=mock_llm, store=store)
  35. await svc.ingest(
  36. "t1",
  37. [
  38. RagDocumentIn(
  39. doc_id="d1",
  40. title="测试",
  41. text="融资主体基本情况说明。" * 30,
  42. source_label="上传材料.pdf",
  43. )
  44. ],
  45. replace=True,
  46. )
  47. out = await svc.retrieve("t1", "融资 主体", top_k=2, min_score=None)
  48. assert out.hits
  49. assert "RAG片段" in out.formatted_context or out.hits[0].text
  50. @pytest.mark.asyncio
  51. async def test_rag_delete_index() -> None:
  52. settings = Settings(rag_chunk_size=500, rag_chunk_overlap=0)
  53. store = InMemoryRagStore()
  54. mock_llm = AsyncMock()
  55. mock_llm.embeddings = AsyncMock(return_value=[[0.0, 1.0]])
  56. mock_llm.embedding = AsyncMock(return_value=[0.0, 1.0])
  57. svc = RagService(settings=settings, llm=mock_llm, store=store)
  58. await svc.ingest("tx", [RagDocumentIn(doc_id="a", text="短文本")], replace=True)
  59. assert store.list_task_chunks("tx")
  60. assert svc.delete_index("tx")
  61. assert not store.list_task_chunks("tx")