Przeglądaj źródła

refactor: add GZip middleware and refactor get_infer_result function in fast_api.py

myhloli 4 miesięcy temu
rodzic
commit
a55c47f184
1 zmieniony plik, 17 dodań i 18 usunięć
  1. 17 18
      mineru/cli/fast_api.py

+ 17 - 18
mineru/cli/fast_api.py

@@ -1,22 +1,21 @@
 import uuid
 import os
-from base64 import b64encode
-
 import uvicorn
 import argparse
 from pathlib import Path
 from glob import glob
 from fastapi import FastAPI, UploadFile, File, Form
+from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.responses import JSONResponse
 from typing import List, Optional
-
 from loguru import logger
+from base64 import b64encode
 
 from mineru.cli.common import aio_do_parse, read_fn
 from mineru.version import __version__
 
 app = FastAPI()
-
+app.add_middleware(GZipMiddleware, minimum_size=1000)
 
 def encode_image(image_path: str) -> str:
     """Encode image using base64"""
@@ -24,6 +23,15 @@ def encode_image(image_path: str) -> str:
         return b64encode(f.read()).decode()
 
 
+def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
+    """Return the UTF-8 text of ``<parse_dir>/<pdf_name><file_suffix_identifier>``, or None if that file does not exist."""
+    result_file_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
+    if os.path.exists(result_file_path):
+        with open(result_file_path, "r", encoding="utf-8") as fp:
+            return fp.read()
+    return None
+
+
 @app.post(path="/file_parse",)
 async def parse_pdf(
         files: List[UploadFile] = File(...),
@@ -118,27 +126,18 @@ async def parse_pdf(
             else:
                 parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
 
-            def get_infer_result(file_suffix_identifier: str):
-                """从结果文件中读取推理结果"""
-                result_file_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
-                if os.path.exists(result_file_path):
-                    with open(result_file_path, "r", encoding="utf-8") as fp:
-                        return fp.read()
-                return None
-
-
             if os.path.exists(parse_dir):
                 if return_md:
-                    data["md_content"] = get_infer_result(".md")
+                    data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
                 if return_middle_json:
-                    data["middle_json"] = get_infer_result("_middle.json")
+                    data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
                 if return_model_output:
                     if backend.startswith("pipeline"):
-                        data["model_output"] = get_infer_result("_model.json")
+                        data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
                     else:
-                        data["model_output"] = get_infer_result("_model_output.txt")
+                        data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
                 if return_content_list:
-                    data["content_list"] = get_infer_result("_content_list.json")
+                    data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
                 if return_images:
                     image_paths = glob(f"{parse_dir}/images/*.jpg")
                     data["images"] = {