瀏覽代碼

feat: support return parse result by zip

lr90 3 月之前
父節點
當前提交
4a237eef36
共有 1 個文件被更改,包括 122 次插入39 次删除
  1. 122 39
      mineru/cli/fast_api.py

+ 122 - 39
mineru/cli/fast_api.py

@@ -1,12 +1,17 @@
 import uuid
 import os
+import re
+import tempfile
+import asyncio
 import uvicorn
 import click
+import zipfile
 from pathlib import Path
-from glob import glob
+import glob
 from fastapi import FastAPI, UploadFile, File, Form
 from fastapi.middleware.gzip import GZipMiddleware
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, FileResponse
+from starlette.background import BackgroundTask
 from typing import List, Optional
 from loguru import logger
 from base64 import b64encode
@@ -18,6 +23,27 @@ from mineru.version import __version__
 app = FastAPI()
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 
+
+def sanitize_filename(filename: str) -> str:
+    """
+    格式化压缩文件的文件名
+    移除路径遍历字符, 保留 Unicode 字母、数字、._- 
+    禁止隐藏文件
+    """
+    sanitized = re.sub(r'[/\\\.]{2,}|[/\\]', '', filename)
+    sanitized = re.sub(r'[^\w.-]', '_', sanitized, flags=re.UNICODE)
+    if sanitized.startswith('.'):
+        sanitized = '_' + sanitized[1:]
+    return sanitized or 'unnamed'
+
+def cleanup_file(file_path: str) -> None:
+    """清理临时 zip 文件"""
+    try:
+        if os.path.exists(file_path):
+            os.remove(file_path)
+    except Exception as e:
+        logger.warning(f"fail clean file {file_path}: {e}")
+
 def encode_image(image_path: str) -> str:
     """Encode image using base64"""
     with open(image_path, "rb") as f:
@@ -48,6 +74,7 @@ async def parse_pdf(
         return_model_output: bool = Form(False),
         return_content_list: bool = Form(False),
         return_images: bool = Form(False),
+        response_format_zip: bool = Form(False),
         start_page_id: int = Form(0),
         end_page_id: int = Form(99999),
 ):
@@ -121,45 +148,101 @@ async def parse_pdf(
             **config
         )
 
-        # 构建结果路径
-        result_dict = {}
-        for pdf_name in pdf_file_names:
-            result_dict[pdf_name] = {}
-            data = result_dict[pdf_name]
-
-            if backend.startswith("pipeline"):
-                parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
-            else:
-                parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
-
-            if os.path.exists(parse_dir):
-                if return_md:
-                    data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
-                if return_middle_json:
-                    data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
-                if return_model_output:
+        # 根据 response_format_zip 决定返回类型
+        if response_format_zip:
+            zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
+            os.close(zip_fd) 
+            with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
+                for pdf_name in pdf_file_names:
+                    safe_pdf_name = sanitize_filename(pdf_name)
                     if backend.startswith("pipeline"):
-                        data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
+                        parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
                     else:
-                        data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
-                if return_content_list:
-                    data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
-                if return_images:
-                    image_paths = glob(f"{parse_dir}/images/*.jpg")
-                    data["images"] = {
-                        os.path.basename(
-                            image_path
-                        ): f"data:image/jpeg;base64,{encode_image(image_path)}"
-                        for image_path in image_paths
-                    }
-        return JSONResponse(
-            status_code=200,
-            content={
-                "backend": backend,
-                "version": __version__,
-                "results": result_dict
-            }
-        )
+                        parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
+
+                    if not os.path.exists(parse_dir):
+                        continue
+
+                    # 写入文本类结果
+                    if return_md:
+                        path = os.path.join(parse_dir, f"{pdf_name}.md")
+                        if os.path.exists(path):
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}.md"))
+
+                    if return_middle_json:
+                        path = os.path.join(parse_dir, f"{pdf_name}_middle.json")
+                        if os.path.exists(path):
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}_middle.json"))
+
+                    if return_model_output:
+                        if backend.startswith("pipeline"):
+                            path = os.path.join(parse_dir, f"{pdf_name}_model.json")
+                        else:
+                            path = os.path.join(parse_dir, f"{pdf_name}_model_output.txt")
+                        if os.path.exists(path): 
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, os.path.basename(path)))
+
+                    if return_content_list:
+                        path = os.path.join(parse_dir, f"{pdf_name}_content_list.json")
+                        if os.path.exists(path):
+                            zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}_content_list.json"))
+
+                    # 写入图片
+                    if return_images:
+                        images_dir = os.path.join(parse_dir, "images")
+                        image_paths = glob.glob(os.path.join(glob.escape(images_dir), "*.jpg"))
+                        for image_path in image_paths:
+                            zf.write(image_path, arcname=os.path.join(safe_pdf_name, "images", os.path.basename(image_path)))
+
+            return FileResponse(
+                path=zip_path,
+                media_type="application/zip",
+                filename="results.zip",
+                background=BackgroundTask(cleanup_file, zip_path)
+            )
+        else:
+            # 构建 JSON 结果
+            result_dict = {}
+            for pdf_name in pdf_file_names:
+                result_dict[pdf_name] = {}
+                data = result_dict[pdf_name]
+
+                if backend.startswith("pipeline"):
+                    parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
+                else:
+                    parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
+
+                if os.path.exists(parse_dir):
+                    if return_md:
+                        data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
+                    if return_middle_json:
+                        data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
+                    if return_model_output:
+                        if backend.startswith("pipeline"):
+                            data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
+                        else:
+                            data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
+                    if return_content_list:
+                        data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
+                    if return_images:
+                        images_dir = os.path.join(parse_dir, "images")
+                        safe_pattern = os.path.join(glob.escape(images_dir), "*.jpg")
+                        image_paths = glob.glob(safe_pattern)
+                        data["images"] = {
+                            os.path.basename(
+                                image_path
+                            ): f"data:image/jpeg;base64,{encode_image(image_path)}"
+                            for image_path in image_paths
+                        }
+
+            return JSONResponse(
+                status_code=200,
+                content={
+                    "backend": backend,
+                    "version": __version__,
+                    "results": result_dict
+                }
+            )
     except Exception as e:
         logger.exception(e)
         return JSONResponse(