rag.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from __future__ import annotations
  2. from fastapi import APIRouter, File, Form, HTTPException, UploadFile
  3. from finrep_algo_agent.api.deps import RagServiceDep, SettingsDep
  4. from finrep_algo_agent.rag.ingestion import extract_text_from_upload
  5. from finrep_algo_agent.schemas.rag import (
  6. RagDeleteResponse,
  7. RagDocumentIn,
  8. RagFileProcessResult,
  9. RagIngestFilesResponse,
  10. RagIngestRequest,
  11. RagIngestResponse,
  12. RagRetrieveRequest,
  13. RagRetrieveResponse,
  14. )
  15. router = APIRouter()
  16. @router.post("/ingest-files", response_model=RagIngestFilesResponse)
  17. async def rag_ingest_files(
  18. settings: SettingsDep,
  19. rag: RagServiceDep,
  20. task_id: str = Form(..., description="报告任务 ID,材料向量按 task 隔离"),
  21. replace: bool = Form(True),
  22. files: list[UploadFile] = File(..., description="业务上传材料,服务端解析文本后自动分块入库"),
  23. ) -> RagIngestFilesResponse:
  24. """文件直达 Python 时走本接口:解析 → 向量化 → 写入本任务 RAG 索引(无需 Java 先抽正文)。"""
  25. if not (settings.embedding_api_key or settings.llm_api_key):
  26. raise HTTPException(
  27. status_code=400,
  28. detail="RAG 入库需配置 FINREP_EMBEDDING_API_KEY 或 FINREP_LLM_API_KEY",
  29. )
  30. if not files:
  31. raise HTTPException(status_code=422, detail="files 不能为空")
  32. file_results: list[RagFileProcessResult] = []
  33. documents: list[RagDocumentIn] = []
  34. for uf in files:
  35. raw = await uf.read()
  36. ex = extract_text_from_upload(filename=uf.filename or "upload.bin", data=raw)
  37. fr = RagFileProcessResult(
  38. filename=ex.source_label,
  39. doc_id=ex.doc_id,
  40. characters=len(ex.text),
  41. skipped=not bool(ex.text),
  42. warning=ex.warning,
  43. )
  44. file_results.append(fr)
  45. if ex.text:
  46. documents.append(
  47. RagDocumentIn(
  48. doc_id=ex.doc_id,
  49. title="",
  50. text=ex.text,
  51. source_label=ex.source_label,
  52. )
  53. )
  54. if not documents:
  55. raise HTTPException(
  56. status_code=422,
  57. detail="所有文件均未解析出有效文本,未写入索引",
  58. )
  59. try:
  60. ing = await rag.ingest(task_id, documents, replace=replace)
  61. except ValueError as e:
  62. raise HTTPException(status_code=422, detail=str(e)) from e
  63. except Exception as e:
  64. raise HTTPException(status_code=502, detail=f"向量化服务异常: {e}") from e
  65. return RagIngestFilesResponse(
  66. task_id=ing.task_id,
  67. document_count=ing.document_count,
  68. chunk_count=ing.chunk_count,
  69. files=file_results,
  70. )
  71. @router.post("/ingest", response_model=RagIngestResponse)
  72. async def rag_ingest(
  73. body: RagIngestRequest,
  74. settings: SettingsDep,
  75. rag: RagServiceDep,
  76. ) -> RagIngestResponse:
  77. if not (settings.embedding_api_key or settings.llm_api_key):
  78. raise HTTPException(
  79. status_code=400,
  80. detail="RAG 入库需配置 FINREP_EMBEDDING_API_KEY 或 FINREP_LLM_API_KEY",
  81. )
  82. try:
  83. return await rag.ingest(
  84. body.task_id,
  85. body.documents,
  86. replace=body.replace,
  87. )
  88. except ValueError as e:
  89. raise HTTPException(status_code=422, detail=str(e)) from e
  90. except Exception as e:
  91. raise HTTPException(status_code=502, detail=f"向量化服务异常: {e}") from e
  92. @router.post("/retrieve", response_model=RagRetrieveResponse)
  93. async def rag_retrieve(
  94. body: RagRetrieveRequest,
  95. settings: SettingsDep,
  96. rag: RagServiceDep,
  97. ) -> RagRetrieveResponse:
  98. if not (settings.embedding_api_key or settings.llm_api_key):
  99. raise HTTPException(
  100. status_code=400,
  101. detail="RAG 检索需配置 FINREP_EMBEDDING_API_KEY 或 FINREP_LLM_API_KEY",
  102. )
  103. try:
  104. return await rag.retrieve(
  105. body.task_id,
  106. body.query,
  107. top_k=body.top_k,
  108. min_score=body.min_score,
  109. )
  110. except Exception as e:
  111. raise HTTPException(status_code=502, detail=f"检索服务异常: {e}") from e
  112. @router.delete("/{task_id}", response_model=RagDeleteResponse)
  113. async def rag_delete_index(task_id: str, rag: RagServiceDep) -> RagDeleteResponse:
  114. deleted = rag.delete_index(task_id)
  115. return RagDeleteResponse(task_id=task_id, deleted=deleted)