瀏覽代碼

add support for more document types

JesseChen1031 8 月之前
父節點
當前提交
102fe27777
共有 1 個文件被更改,包括 59 次插入34 次删除
  1. 59 34
      projects/web_api/app.py

+ 59 - 34
projects/web_api/app.py

@@ -3,6 +3,7 @@ import os
 from base64 import b64encode
 from glob import glob
 from io import StringIO
+import tempfile
 from typing import Tuple, Union
 
 import uvicorn
@@ -10,11 +11,12 @@ from fastapi import FastAPI, HTTPException, UploadFile
 from fastapi.responses import JSONResponse
 from loguru import logger
 
+from magic_pdf.data.read_api import read_local_images, read_local_office
 import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
 from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
 from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.operators.models import InferenceResult
@@ -24,6 +26,9 @@ model_config.__use_inside_model__ = True
 
 app = FastAPI()
 
+pdf_extensions = [".pdf"]
+office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
+image_extensions = [".png", ".jpg"]
 
 class MemoryDataWriter(DataWriter):
     def __init__(self):
@@ -46,8 +51,8 @@ class MemoryDataWriter(DataWriter):
 
 
 def init_writers(
-    pdf_path: str = None,
-    pdf_file: UploadFile = None,
+    file_path: str = None,
+    file: UploadFile = None,
     output_path: str = None,
     output_image_path: str = None,
 ) -> Tuple[
@@ -68,10 +73,11 @@ def init_writers(
         Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
         file content
     """
-    if pdf_path:
-        is_s3_path = pdf_path.startswith("s3://")
+    file_extension:str = None
+    if file_path:
+        is_s3_path = file_path.startswith("s3://")
         if is_s3_path:
-            bucket = get_bucket_name(pdf_path)
+            bucket = get_bucket_name(file_path)
             ak, sk, endpoint = get_s3_config(bucket)
 
             writer = S3DataWriter(
@@ -84,25 +90,29 @@ def init_writers(
             temp_reader = S3DataReader(
                 "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
             )
-            pdf_bytes = temp_reader.read(pdf_path)
+            file_bytes = temp_reader.read(file_path)
+            file_extension = os.path.splitext(file_path)[1]
         else:
             writer = FileBasedDataWriter(output_path)
             image_writer = FileBasedDataWriter(output_image_path)
             os.makedirs(output_image_path, exist_ok=True)
-            with open(pdf_path, "rb") as f:
-                pdf_bytes = f.read()
+            with open(file_path, "rb") as f:
+                file_bytes = f.read()
+            file_extension = os.path.splitext(file_path)[1]
     else:
         # 处理上传的文件
-        pdf_bytes = pdf_file.file.read()
+        file_bytes = file.file.read()
+        file_extension = os.path.splitext(file.filename)[1]
         writer = FileBasedDataWriter(output_path)
         image_writer = FileBasedDataWriter(output_image_path)
         os.makedirs(output_image_path, exist_ok=True)
 
-    return writer, image_writer, pdf_bytes
+    return writer, image_writer, file_bytes, file_extension
 
 
-def process_pdf(
-    pdf_bytes: bytes,
+def process_file(
+    file_bytes: bytes,
+    file_extension: str,
     parse_method: str,
     image_writer: Union[S3DataWriter, FileBasedDataWriter],
 ) -> Tuple[InferenceResult, PipeResult]:
@@ -117,7 +127,22 @@ def process_pdf(
     Returns:
         Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
     """
-    ds = PymuDocDataset(pdf_bytes)
+
+    ds = Union[PymuDocDataset, ImageDataset]
+    if file_extension in pdf_extensions:
+        ds = PymuDocDataset(file_bytes)
+    elif file_extension in office_extensions:
+        # 需要使用office解析
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_office(temp_dir)[0]
+    elif file_extension in image_extensions:
+        # 需要使用ocr解析
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_images(temp_dir)[0]
     infer_result: InferenceResult = None
     pipe_result: PipeResult = None
 
@@ -149,9 +174,9 @@ def encode_image(image_path: str) -> str:
     tags=["projects"],
     summary="Parse PDF files (supports local files and S3)",
 )
-async def pdf_parse(
-    pdf_file: UploadFile = None,
-    pdf_path: str = None,
+async def file_parse(
+    file: UploadFile = None,
+    file_path: str = None,
     parse_method: str = "auto",
     is_json_md_dump: bool = False,
     output_dir: str = "output",
@@ -181,31 +206,31 @@ async def pdf_parse(
         return_content_list: Whether to return parsed PDF content list. Default to False
     """
     try:
-        if (pdf_file is None and pdf_path is None) or (
-            pdf_file is not None and pdf_path is not None
+        if (file is None and file_path is None) or (
+            file is not None and file_path is not None
         ):
             return JSONResponse(
-                content={"error": "Must provide either pdf_file or pdf_path"},
+                content={"error": "Must provide either file or file_path"},
                 status_code=400,
             )
 
         # Get PDF filename
-        pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
+        file_name = os.path.basename(file_path if file_path else file.filename).split(
             "."
         )[0]
-        output_path = f"{output_dir}/{pdf_name}"
+        output_path = f"{output_dir}/{file_name}"
         output_image_path = f"{output_path}/images"
 
         # Initialize readers/writers and get PDF content
-        writer, image_writer, pdf_bytes = init_writers(
-            pdf_path=pdf_path,
-            pdf_file=pdf_file,
+        writer, image_writer, file_bytes, file_extension = init_writers(
+            file_path=file_path,
+            file=file,
             output_path=output_path,
             output_image_path=output_image_path,
         )
 
         # Process PDF
-        infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
+        infer_result, pipe_result = process_file(file_bytes, file_extension, parse_method, image_writer)
 
         # Use MemoryDataWriter to get results
         content_list_writer = MemoryDataWriter()
@@ -226,23 +251,23 @@ async def pdf_parse(
         # If results need to be saved
         if is_json_md_dump:
             writer.write_string(
-                f"{pdf_name}_content_list.json", content_list_writer.get_value()
+                f"{file_name}_content_list.json", content_list_writer.get_value()
             )
-            writer.write_string(f"{pdf_name}.md", md_content)
+            writer.write_string(f"{file_name}.md", md_content)
             writer.write_string(
-                f"{pdf_name}_middle.json", middle_json_writer.get_value()
+                f"{file_name}_middle.json", middle_json_writer.get_value()
             )
             writer.write_string(
-                f"{pdf_name}_model.json",
+                f"{file_name}_model.json",
                 json.dumps(model_json, indent=4, ensure_ascii=False),
             )
             # Save visualization results
-            pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
-            pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
+            pipe_result.draw_layout(os.path.join(output_path, f"{file_name}_layout.pdf"))
+            pipe_result.draw_span(os.path.join(output_path, f"{file_name}_spans.pdf"))
             pipe_result.draw_line_sort(
-                os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
+                os.path.join(output_path, f"{file_name}_line_sort.pdf")
             )
-            infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
+            infer_result.draw_model(os.path.join(output_path, f"{file_name}_model.pdf"))
 
         # Build return data
         data = {}