Browse Source

Merge branch 'opendatalab:dev' into dev

Xiaomeng Zhao 8 months ago
parent
commit
eae0e6d8c4

+ 21 - 77
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -34,8 +34,6 @@ from magic_pdf.model.model_list import MODEL
 
 # from magic_pdf.operators.models import InferenceResult
 
-MIN_BATCH_INFERENCE_SIZE = 100
-
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -143,17 +141,14 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
 
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for index in range(len(dataset)):
@@ -163,41 +158,16 @@ def doc_analyze(
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
-
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
 
     model_json = []
     for index in range(len(dataset)):
@@ -224,11 +194,8 @@ def batch_doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for dataset in datasets:
@@ -238,40 +205,17 @@ def batch_doc_analyze(
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
     infer_results = []
 
     from magic_pdf.operators.models import InferenceResult

+ 1 - 1
magic_pdf/tools/common.py

@@ -314,7 +314,7 @@ def batch_do_parse(
             dss.append(PymuDocDataset(v, lang=lang))
         else:
             dss.append(v)
-    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
         _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
 

+ 70 - 45
projects/web_api/app.py

@@ -3,6 +3,7 @@ import os
 from base64 import b64encode
 from glob import glob
 from io import StringIO
+import tempfile
 from typing import Tuple, Union
 
 import uvicorn
@@ -10,11 +11,12 @@ from fastapi import FastAPI, HTTPException, UploadFile
 from fastapi.responses import JSONResponse
 from loguru import logger
 
+from magic_pdf.data.read_api import read_local_images, read_local_office
 import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
 from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
 from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.operators.models import InferenceResult
@@ -24,6 +26,9 @@ model_config.__use_inside_model__ = True
 
 app = FastAPI()
 
+pdf_extensions = [".pdf"]
+office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
+image_extensions = [".png", ".jpg"]
 
 class MemoryDataWriter(DataWriter):
     def __init__(self):
@@ -46,8 +51,8 @@ class MemoryDataWriter(DataWriter):
 
 
 def init_writers(
-    pdf_path: str = None,
-    pdf_file: UploadFile = None,
+    file_path: str = None,
+    file: UploadFile = None,
     output_path: str = None,
     output_image_path: str = None,
 ) -> Tuple[
@@ -59,19 +64,19 @@ def init_writers(
     Initialize writers based on path type
 
     Args:
-        pdf_path: PDF file path (local path or S3 path)
-        pdf_file: Uploaded PDF file object
+        file_path: file path (local path or S3 path)
+        file: Uploaded file object
         output_path: Output directory path
         output_image_path: Image output directory path
 
     Returns:
-        Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
-        file content
+        Tuple[writer, image_writer, file_bytes]: Returns initialized writer tuple and file content
     """
-    if pdf_path:
-        is_s3_path = pdf_path.startswith("s3://")
+    file_extension:str = None
+    if file_path:
+        is_s3_path = file_path.startswith("s3://")
         if is_s3_path:
-            bucket = get_bucket_name(pdf_path)
+            bucket = get_bucket_name(file_path)
             ak, sk, endpoint = get_s3_config(bucket)
 
             writer = S3DataWriter(
@@ -84,25 +89,29 @@ def init_writers(
             temp_reader = S3DataReader(
                 "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
             )
-            pdf_bytes = temp_reader.read(pdf_path)
+            file_bytes = temp_reader.read(file_path)
+            file_extension = os.path.splitext(file_path)[1]
         else:
             writer = FileBasedDataWriter(output_path)
             image_writer = FileBasedDataWriter(output_image_path)
             os.makedirs(output_image_path, exist_ok=True)
-            with open(pdf_path, "rb") as f:
-                pdf_bytes = f.read()
+            with open(file_path, "rb") as f:
+                file_bytes = f.read()
+            file_extension = os.path.splitext(file_path)[1]
     else:
         # 处理上传的文件
-        pdf_bytes = pdf_file.file.read()
+        file_bytes = file.file.read()
+        file_extension = os.path.splitext(file.filename)[1]
         writer = FileBasedDataWriter(output_path)
         image_writer = FileBasedDataWriter(output_image_path)
         os.makedirs(output_image_path, exist_ok=True)
 
-    return writer, image_writer, pdf_bytes
+    return writer, image_writer, file_bytes, file_extension
 
 
-def process_pdf(
-    pdf_bytes: bytes,
+def process_file(
+    file_bytes: bytes,
+    file_extension: str,
     parse_method: str,
     image_writer: Union[S3DataWriter, FileBasedDataWriter],
 ) -> Tuple[InferenceResult, PipeResult]:
@@ -110,14 +119,30 @@ def process_pdf(
     Process PDF file content
 
     Args:
-        pdf_bytes: Binary content of PDF file
+        file_bytes: Binary content of file
+        file_extension: file extension
         parse_method: Parse method ('ocr', 'txt', 'auto')
         image_writer: Image writer
 
     Returns:
         Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
     """
-    ds = PymuDocDataset(pdf_bytes)
+
+    ds = Union[PymuDocDataset, ImageDataset]
+    if file_extension in pdf_extensions:
+        ds = PymuDocDataset(file_bytes)
+    elif file_extension in office_extensions:
+        # 需要使用office解析
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_office(temp_dir)[0]
+    elif file_extension in image_extensions:
+        # 需要使用ocr解析
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_images(temp_dir)[0]
     infer_result: InferenceResult = None
     pipe_result: PipeResult = None
 
@@ -145,13 +170,13 @@ def encode_image(image_path: str) -> str:
 
 
 @app.post(
-    "/pdf_parse",
+    "/file_parse",
     tags=["projects"],
-    summary="Parse PDF files (supports local files and S3)",
+    summary="Parse files (supports local files and S3)",
 )
-async def pdf_parse(
-    pdf_file: UploadFile = None,
-    pdf_path: str = None,
+async def file_parse(
+    file: UploadFile = None,
+    file_path: str = None,
     parse_method: str = "auto",
     is_json_md_dump: bool = False,
     output_dir: str = "output",
@@ -165,10 +190,10 @@ async def pdf_parse(
     to the specified directory.
 
     Args:
-        pdf_file: The PDF file to be parsed. Must not be specified together with
-            `pdf_path`
-        pdf_path: The path to the PDF file to be parsed. Must not be specified together
-            with `pdf_file`
+        file: The PDF file to be parsed. Must not be specified together with
+            `file_path`
+        file_path: The path to the PDF file to be parsed. Must not be specified together
+            with `file`
         parse_method: Parsing method, can be auto, ocr, or txt. Default is auto. If
             results are not satisfactory, try ocr
         is_json_md_dump: Whether to write parsed data to .json and .md files. Default
@@ -181,31 +206,31 @@ async def pdf_parse(
         return_content_list: Whether to return parsed PDF content list. Default to False
     """
     try:
-        if (pdf_file is None and pdf_path is None) or (
-            pdf_file is not None and pdf_path is not None
+        if (file is None and file_path is None) or (
+            file is not None and file_path is not None
         ):
             return JSONResponse(
-                content={"error": "Must provide either pdf_file or pdf_path"},
+                content={"error": "Must provide either file or file_path"},
                 status_code=400,
             )
 
         # Get PDF filename
-        pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
+        file_name = os.path.basename(file_path if file_path else file.filename).split(
             "."
         )[0]
-        output_path = f"{output_dir}/{pdf_name}"
+        output_path = f"{output_dir}/{file_name}"
         output_image_path = f"{output_path}/images"
 
         # Initialize readers/writers and get PDF content
-        writer, image_writer, pdf_bytes = init_writers(
-            pdf_path=pdf_path,
-            pdf_file=pdf_file,
+        writer, image_writer, file_bytes, file_extension = init_writers(
+            file_path=file_path,
+            file=file,
             output_path=output_path,
             output_image_path=output_image_path,
         )
 
         # Process PDF
-        infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
+        infer_result, pipe_result = process_file(file_bytes, file_extension, parse_method, image_writer)
 
         # Use MemoryDataWriter to get results
         content_list_writer = MemoryDataWriter()
@@ -226,23 +251,23 @@ async def pdf_parse(
         # If results need to be saved
         if is_json_md_dump:
             writer.write_string(
-                f"{pdf_name}_content_list.json", content_list_writer.get_value()
+                f"{file_name}_content_list.json", content_list_writer.get_value()
             )
-            writer.write_string(f"{pdf_name}.md", md_content)
+            writer.write_string(f"{file_name}.md", md_content)
             writer.write_string(
-                f"{pdf_name}_middle.json", middle_json_writer.get_value()
+                f"{file_name}_middle.json", middle_json_writer.get_value()
             )
             writer.write_string(
-                f"{pdf_name}_model.json",
+                f"{file_name}_model.json",
                 json.dumps(model_json, indent=4, ensure_ascii=False),
             )
             # Save visualization results
-            pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
-            pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
+            pipe_result.draw_layout(os.path.join(output_path, f"{file_name}_layout.pdf"))
+            pipe_result.draw_span(os.path.join(output_path, f"{file_name}_spans.pdf"))
             pipe_result.draw_line_sort(
-                os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
+                os.path.join(output_path, f"{file_name}_line_sort.pdf")
             )
-            infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
+            infer_result.draw_model(os.path.join(output_path, f"{file_name}_model.pdf"))
 
         # Build return data
         data = {}

+ 24 - 0
signatures/version1/cla.json

@@ -183,6 +183,30 @@
       "created_at": "2025-02-26T09:23:25Z",
       "repoId": 765083837,
       "pullRequestNo": 1785
+    },
+    {
+      "name": "rschutski",
+      "id": 179498169,
+      "comment_id": 2705150371,
+      "created_at": "2025-03-06T23:16:30Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1863
+    },
+    {
+      "name": "qbit-",
+      "id": 4794088,
+      "comment_id": 2705914730,
+      "created_at": "2025-03-07T09:09:13Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1863
+    },
+    {
+      "name": "mauryaland",
+      "id": 22381129,
+      "comment_id": 2717322316,
+      "created_at": "2025-03-12T10:03:11Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1906
     }
   ]
 }