瀏覽代碼

feat: add FastAPI server for PDF processing and file parsing

myhloli 4 月之前
父節點
當前提交
3f87f83fda
共有 2 個文件被更改,包括 201 次插入0 次删除
  1. 189 0
      mineru/cli/fast_api.py
  2. 12 0
      pyproject.toml

+ 189 - 0
mineru/cli/fast_api.py

@@ -0,0 +1,189 @@
+import uuid
+import os
+from base64 import b64encode
+
+import uvicorn
+import argparse
+from pathlib import Path
+from glob import glob
+from fastapi import FastAPI, UploadFile, File, Form
+from fastapi.responses import JSONResponse
+from typing import List, Optional
+
+from loguru import logger
+
+from mineru.cli.common import aio_do_parse, read_fn
+from mineru.version import __version__
+
+# ASGI application object: the /file_parse route below registers on it, and
+# main() hands it to uvicorn through the "mineru.cli.fast_api:app" import string.
+app = FastAPI()
+
+
+def encode_image(image_path: str) -> str:
+    """Return the contents of the file at ``image_path`` as base64 text.
+
+    The file is opened in binary mode, read completely, base64-encoded and
+    the resulting bytes are decoded to ``str``.
+    """
+    with open(image_path, "rb") as image_file:
+        raw_bytes = image_file.read()
+    encoded = b64encode(raw_bytes)
+    return encoded.decode()
+
+
+def _read_result_file(parse_dir: str, pdf_name: str, file_suffix_identifier: str) -> Optional[str]:
+    """Read ``<parse_dir>/<pdf_name><file_suffix_identifier>`` as UTF-8 text.
+
+    Returns ``None`` when that file does not exist.
+    """
+    result_file_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
+    if not os.path.exists(result_file_path):
+        return None
+    with open(result_file_path, "r", encoding="utf-8") as fp:
+        return fp.read()
+
+
+@app.post(path="/file_parse",)
+async def parse_pdf(
+        files: List[UploadFile] = File(...),
+        output_dir: str = Form("./output"),
+        lang_list: List[str] = Form(["ch"]),
+        backend: str = Form("pipeline"),
+        parse_method: str = Form("auto"),
+        formula_enable: bool = Form(True),
+        table_enable: bool = Form(True),
+        server_url: Optional[str] = Form(None),
+        reuturn_md: bool = Form(True),
+        reuturn_middle_json: bool = Form(False),
+        return_model_output: bool = Form(False),
+        reuturn_content_list: bool = Form(False),
+        return_images: bool = Form(False),
+        start_page_id: int = Form(0),
+        end_page_id: int = Form(99999),
+):
+    """Parse uploaded PDF/image files and return the requested outputs as JSON.
+
+    Each upload is written to a temporary file inside a fresh UUID-named
+    directory under ``output_dir``, loaded through ``read_fn`` and removed
+    again. All loaded documents are then passed to ``aio_do_parse`` and the
+    result files it leaves in that directory are read back into the response.
+
+    Args:
+        files: Uploaded files; only ``.pdf``, ``.png``, ``.jpeg`` and ``.jpg``
+            suffixes (case-insensitive) are accepted.
+        output_dir: Parent directory in which the per-request directory is created.
+        lang_list: Languages forwarded as ``p_lang_list``. When its length does
+            not match the number of files, its first entry (or ``"ch"`` when
+            empty) is used for every file.
+        backend: Forwarded to ``aio_do_parse``. Values starting with
+            ``"pipeline"`` read results from ``<name>/<parse_method>``, any
+            other value from ``<name>/vlm``.
+        parse_method: Forwarded to ``aio_do_parse``; also the result
+            sub-directory name for pipeline backends.
+        formula_enable: Forwarded as ``p_formula_enable``.
+        table_enable: Forwarded as ``p_table_enable``.
+        server_url: Forwarded to ``aio_do_parse``.
+        reuturn_md: Dump and return the ``.md`` file as ``md_content``.
+            (The misspelled field name is kept for client compatibility.)
+        reuturn_middle_json: Dump and return ``_middle.json`` as ``middle_json``.
+        return_model_output: Dump and return the model output
+            (``_model.json`` for pipeline backends, ``_model_output.txt`` otherwise).
+        reuturn_content_list: Dump and return ``_content_list.json`` as ``content_list``.
+        return_images: Return every ``images/*.jpg`` result file as a base64 data URI.
+        start_page_id: Forwarded to ``aio_do_parse``.
+        end_page_id: Forwarded to ``aio_do_parse``.
+
+    Returns:
+        A ``JSONResponse``: status 200 with ``status``, ``backend``, ``version``
+        and per-file ``results``; status 400 for an unsupported file type or a
+        file that ``read_fn`` fails on; status 500 for any other exception.
+    """
+    # Function-scope import: at module level the name ``glob`` is bound to the
+    # glob() function, so the escape helper is pulled in separately here.
+    from glob import escape as glob_escape
+
+    try:
+        # Create a unique output directory for this request.
+        # NOTE(review): output_dir is taken from the request form, so a client
+        # chooses where on the server files are written - confirm this is intended.
+        unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
+        os.makedirs(unique_dir, exist_ok=True)
+
+        # Load every uploaded file.
+        pdf_file_names = []
+        pdf_bytes_list = []
+
+        for file in files:
+            content = await file.read()
+            # UploadFile.filename may be None; an empty path has no suffix and
+            # is rejected below with a 400 instead of raising TypeError (500).
+            file_path = Path(file.filename or "")
+
+            if file_path.suffix.lower() not in (".pdf", ".png", ".jpeg", ".jpg"):
+                return JSONResponse(
+                    status_code=400,
+                    content={"error": f"不支持的文件类型: {file_path.suffix}"}
+                )
+
+            # read_fn works on a path, so the upload goes through a temporary
+            # file. Only the basename is used, keeping it inside unique_dir.
+            temp_path = Path(unique_dir) / file_path.name
+            with open(temp_path, "wb") as f:
+                f.write(content)
+
+            try:
+                pdf_bytes = read_fn(temp_path)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=400,
+                    content={"error": f"处理文件失败: {str(e)}"}
+                )
+            finally:
+                # Remove the temporary file on failure as well as on success.
+                if temp_path.exists():
+                    os.remove(temp_path)
+
+            pdf_bytes_list.append(pdf_bytes)
+            # NOTE(review): uploads sharing a stem (a.pdf and a.png) end up
+            # under the same result key - confirm whether that can occur.
+            pdf_file_names.append(file_path.stem)
+
+        # Make the language list line up with the number of files.
+        actual_lang_list = lang_list
+        if len(actual_lang_list) != len(pdf_file_names):
+            # On a length mismatch, repeat the first language (default "ch").
+            actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)
+
+        # Run the asynchronous parsing routine.
+        await aio_do_parse(
+            output_dir=unique_dir,
+            pdf_file_names=pdf_file_names,
+            pdf_bytes_list=pdf_bytes_list,
+            p_lang_list=actual_lang_list,
+            backend=backend,
+            parse_method=parse_method,
+            p_formula_enable=formula_enable,
+            p_table_enable=table_enable,
+            server_url=server_url,
+            f_draw_layout_bbox=False,
+            f_draw_span_bbox=False,
+            f_dump_md=reuturn_md,
+            f_dump_middle_json=reuturn_middle_json,
+            f_dump_model_output=return_model_output,
+            f_dump_orig_pdf=False,
+            f_dump_content_list=reuturn_content_list,
+            start_page_id=start_page_id,
+            end_page_id=end_page_id,
+        )
+
+        # Collect the result files for every document.
+        is_pipeline = backend.startswith("pipeline")
+        result_dict = {}
+        for pdf_name in pdf_file_names:
+            data = {}
+            result_dict[pdf_name] = data
+
+            parse_dir = os.path.join(unique_dir, pdf_name, parse_method if is_pipeline else "vlm")
+            if not os.path.exists(parse_dir):
+                continue
+
+            if reuturn_md:
+                data["md_content"] = _read_result_file(parse_dir, pdf_name, ".md")
+            if reuturn_middle_json:
+                data["middle_json"] = _read_result_file(parse_dir, pdf_name, "_middle.json")
+            if return_model_output:
+                model_suffix = "_model.json" if is_pipeline else "_model_output.txt"
+                data["model_output"] = _read_result_file(parse_dir, pdf_name, model_suffix)
+            if reuturn_content_list:
+                data["content_list"] = _read_result_file(parse_dir, pdf_name, "_content_list.json")
+            if return_images:
+                # parse_dir contains the uploaded file's stem; escape it so
+                # characters such as "[" or "*" are matched literally.
+                image_paths = glob(f"{glob_escape(parse_dir)}/images/*.jpg")
+                data["images"] = {
+                    os.path.basename(image_path): f"data:image/jpeg;base64,{encode_image(image_path)}"
+                    for image_path in image_paths
+                }
+        return JSONResponse(
+            status_code=200,
+            content={
+                "status": "success",
+                "backend": backend,
+                "version": __version__,
+                "results": result_dict
+            }
+        )
+    except Exception as e:
+        logger.exception(e)
+        return JSONResponse(
+            status_code=500,
+            content={"error": str(e)}
+        )
+
+
+def main():
+    """Command-line entry point: read host/port/reload options and serve the API with uvicorn."""
+    arg_parser = argparse.ArgumentParser(description='Start MinerU FastAPI Service')
+    arg_parser.add_argument(
+        '--host',
+        type=str,
+        default='127.0.0.1',
+        help='Server host (default: 127.0.0.1)',
+    )
+    arg_parser.add_argument(
+        '--port',
+        type=int,
+        default=8000,
+        help='Server port (default: 8000)',
+    )
+    arg_parser.add_argument(
+        '--reload',
+        action='store_true',
+        help='Enable auto-reload (development mode)',
+    )
+    args = arg_parser.parse_args()
+
+    # Tell the operator where the service and its generated docs are reachable.
+    print(f"Start MinerU FastAPI Service: http://{args.host}:{args.port}")
+    print("The API documentation can be accessed at the following address:")
+    print(f"- Swagger UI: http://{args.host}:{args.port}/docs")
+    print(f"- ReDoc: http://{args.host}:{args.port}/redoc")
+
+    # The app is passed as an import string rather than as an object.
+    server_options = {"host": args.host, "port": args.port, "reload": args.reload}
+    uvicorn.run("mineru.cli.fast_api:app", **server_options)
+
+
+# Start the server when this module is executed directly as a script.
+if __name__ == "__main__":
+    main()

+ 12 - 0
pyproject.toml

@@ -62,9 +62,20 @@ pipeline = [
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
     "fast-langdetect>=0.2.3,<0.3.0",
 ]
+api = [
+    "fastapi",
+    "python-multipart",
+    "uvicorn",
+]
+gradio = [
+    "gradio",
+    "gradio-pdf",
+]
 core = [
     "mineru[vlm]",
     "mineru[pipeline]",
+    "mineru[api]",
+    "mineru[gradio]",
 ]
 all = [
     "mineru[core]",
@@ -97,6 +108,7 @@ Repository = "https://github.com/opendatalab/MinerU"
 mineru = "mineru.cli:client.main"
 mineru-sglang-server = "mineru.cli.vlm_sglang_server:main"
 mineru-models-download = "mineru.cli.models_download:download_models"
+mineru-api = "mineru.cli.fast_api:main"
 
 [tool.setuptools.dynamic]
 version = {attr = "mineru.version.__version__"}