Browse source code

Merge pull request #1910 from icecraft/fix/parallel_split

Fix/parallel split
Xiaomeng Zhao 8 months ago
parent
commit
ecdd162f11

+ 156 - 0
magic_pdf/data/batch_build_dataset.py

@@ -0,0 +1,156 @@
+import concurrent.futures
+
+import fitz  # PyMuPDF
+
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.utils import fitz_doc_to_image
+
+
+def partition_array_greedy(arr, k):
+    """Partition an array into k parts using a simple greedy approach.
+
+    Parameters:
+    -----------
+    arr : list
+        The input list of (pdf_path, page_count) tuples
+    k : int
+        Number of partitions to create
+
+    Returns:
+    --------
+    partitions : list of lists
+        k partitions, each a list of indices into arr
+    """
+    # Handle edge cases
+    if k <= 0:
+        raise ValueError('k must be a positive integer')
+    if k > len(arr):
+        k = len(arr)  # Adjust k if it's too large
+    if k == 1:
+        return [list(range(len(arr)))]
+    if k == len(arr):
+        return [[i] for i in range(len(arr))]
+
+    # Sort indices by page count in descending order
+    sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
+
+    # Initialize k empty partitions
+    partitions = [[] for _ in range(k)]
+    partition_sums = [0] * k
+
+    # Assign each element to the partition with the smallest current sum
+    for idx in sorted_indices:
+        # Find the partition with the smallest sum
+        min_sum_idx = partition_sums.index(min(partition_sums))
+
+        # Add the element to this partition
+        partitions[min_sum_idx].append(idx)  # Store the original index
+        partition_sums[min_sum_idx] += arr[idx][1]
+
+    return partitions
+
+
+def process_pdf_batch(pdf_jobs, idx):
+    """Process a batch of PDF pages using multiple threads.
+
+    Parameters:
+    -----------
+    pdf_jobs : list of tuples
+        List of (pdf_path, page_num) tuples
+    output_dir : str or None
+        Directory to save images to
+    num_threads : int
+        Number of threads to use
+    **kwargs :
+        Additional arguments for process_pdf_page
+
+    Returns:
+    --------
+    images : list
+        List of processed images
+    """
+    images = []
+
+    for pdf_path, _ in pdf_jobs:
+        doc = fitz.open(pdf_path)
+        tmp = []
+        for page_num in range(len(doc)):
+            page = doc[page_num]
+            tmp.append(fitz_doc_to_image(page))
+        doc.close()
+        images.append(tmp)
+    return (idx, images)
+
+
+def batch_build_dataset(pdf_paths, k, lang=None):
+    """Process multiple PDFs by partitioning them into k balanced parts and
+    processing each part in parallel.
+
+    Parameters:
+    -----------
+    pdf_paths : list
+        List of paths to PDF files
+    k : int
+        Number of partitions to create
+    output_dir : str or None
+        Directory to save images to
+    threads_per_worker : int
+        Number of threads to use per worker
+    **kwargs :
+        Additional arguments for process_pdf_page
+
+    Returns:
+    --------
+    all_images : list
+        List of all processed images
+    """
+    # Get page counts for each PDF
+    pdf_info = []
+    total_pages = 0
+
+    for pdf_path in pdf_paths:
+        try:
+            doc = fitz.open(pdf_path)
+            num_pages = len(doc)
+            pdf_info.append((pdf_path, num_pages))
+            total_pages += num_pages
+            doc.close()
+        except Exception as e:
+            print(f'Error opening {pdf_path}: {e}')
+
+    # Partition the jobs based on page count
+    partitions = partition_array_greedy(pdf_info, k)
+
+    # Process each partition in parallel
+    all_images_h = {}
+
+    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
+        # Submit one task per partition
+        futures = []
+        for sn, partition in enumerate(partitions):
+            # Get the jobs for this partition
+            partition_jobs = [pdf_info[idx] for idx in partition]
+
+            # Submit the task
+            future = executor.submit(
+                process_pdf_batch,
+                partition_jobs,
+                sn
+            )
+            futures.append(future)
+        # Process results as they complete
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                idx, images = future.result()
+                all_images_h[idx] = images
+            except Exception as e:
+                print(f'Error processing partition: {e}')
+    results = [None] * len(pdf_paths)
+    for i in range(len(partitions)):
+        partition = partitions[i]
+        for j in range(len(partition)):
+            with open(pdf_info[partition[j]][0], 'rb') as f:
+                pdf_bytes = f.read()
+            dataset = PymuDocDataset(pdf_bytes, lang=lang)
+            dataset.set_images(all_images_h[i][j])
+            results[partition[j]] = dataset
+    return results
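
For reference, a minimal sketch of what the greedy partitioner produces (the paths and page counts are made up):

    pdf_info = [('a.pdf', 120), ('b.pdf', 30), ('c.pdf', 45), ('d.pdf', 90)]
    partitions = partition_array_greedy(pdf_info, k=2)
    # Each job goes to the partition with the smallest running page total:
    # partitions == [[0, 1], [3, 2]]  ->  page sums 150 vs 135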

+ 40 - 23
magic_pdf/data/dataset.py

@@ -97,10 +97,10 @@ class Dataset(ABC):
 
     @abstractmethod
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
         pass
 
@@ -119,7 +119,7 @@ class Dataset(ABC):
 
     @abstractmethod
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -128,8 +128,7 @@ class Dataset(ABC):
 
     @abstractmethod
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         pass
 
 
@@ -148,12 +147,14 @@ class PymuDocDataset(Dataset):
         if lang == '':
             self._lang = None
         elif lang == 'auto':
-            from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
+            from magic_pdf.model.sub_modules.language_detection.utils import \
+                auto_detect_lang
             self._lang = auto_detect_lang(bits)
-            logger.info(f"lang: {lang}, detect_lang: {self._lang}")
+            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
-            logger.info(f"lang: {lang}")
+            logger.info(f'lang: {lang}')
+
     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)
@@ -186,12 +187,12 @@ class PymuDocDataset(Dataset):
         return self._records[page_id]
 
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
-        
+
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
             os.makedirs(dir_name, exist_ok=True)
@@ -212,7 +213,7 @@ class PymuDocDataset(Dataset):
         return proc(self, *args, **kwargs)
 
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -220,10 +221,12 @@ class PymuDocDataset(Dataset):
         return classify(self._data_bits)
 
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return PymuDocDataset(self._raw_data)
 
+    def set_images(self, images):
+        """Attach pre-rendered page images so get_image() can skip rasterization."""
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])
 
 class ImageDataset(Dataset):
     def __init__(self, bits: bytes):
@@ -270,10 +273,10 @@ class ImageDataset(Dataset):
         return self._records[page_id]
 
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
@@ -293,7 +296,7 @@ class ImageDataset(Dataset):
         return proc(self, *args, **kwargs)
 
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -301,15 +304,19 @@ class ImageDataset(Dataset):
         return SupportedPdfParseMethod.OCR
 
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return ImageDataset(self._raw_data)
 
+    def set_images(self, images):
+        """Attach pre-rendered page images so get_image() can skip rasterization."""
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])
+
 class Doc(PageableData):
     """Initialized with pymudoc object."""
 
     def __init__(self, doc: fitz.Page):
         self._doc = doc
+        self._img = None
 
     def get_image(self):
         """Return the image info.
@@ -321,7 +328,17 @@ class Doc(PageableData):
                 height: int
             }
         """
-        return fitz_doc_to_image(self._doc)
+        if self._img is None:
+            self._img = fitz_doc_to_image(self._doc)
+        return self._img
+
+    def set_image(self, img):
+        """
+        Args:
+            img (np.ndarray): the image
+        """
+        if self._img is None:
+            self._img = img
 
     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.

+ 103 - 0
magic_pdf/data/utils.py

@@ -1,4 +1,9 @@
 
+import multiprocessing as mp
+from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
+                                as_completed)
+
 import fitz
 import numpy as np
 from loguru import logger
@@ -65,3 +70,101 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
 
             images.append(img_dict)
     return images
+
+
+def convert_page(bytes_page):
+    pdfs = fitz.open('pdf', bytes_page)
+    page = pdfs[0]
+    return fitz_doc_to_image(page)
+
+def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
+    """Process PDF pages in parallel with serialization-safe approach."""
+    if num_workers is None:
+        num_workers = mp.cpu_count()
+
+
+    # Process the extracted page data in parallel
+    with ProcessPoolExecutor(max_workers=num_workers) as executor:
+        # Process the page data
+        results = list(
+            executor.map(convert_page, pages)
+        )
+
+    return results
+
+
+def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
+    """Process all pages of a PDF using multiple threads.
+
+    Parameters:
+    -----------
+    pdf_path : str
+        Path to the PDF file
+    num_threads : int
+        Number of threads to use
+    **kwargs :
+        Additional arguments for fitz_doc_to_image
+
+    Returns:
+    --------
+    images : list
+        List of processed images, in page order
+    """
+    # Open the PDF
+    doc = fitz.open(pdf_path)
+    num_pages = len(doc)
+
+    # Create a list to store results in the correct order
+    results = [None] * num_pages
+
+    # Create a thread pool
+    with ThreadPoolExecutor(max_workers=num_threads) as executor:
+        # Submit all tasks
+        futures = {}
+        for page_num in range(num_pages):
+            page = doc[page_num]
+            future = executor.submit(fitz_doc_to_image, page, **kwargs)
+            futures[future] = page_num
+        # Process results as they complete
+        for future in as_completed(futures):
+            page_num = futures[future]
+            try:
+                results[page_num] = future.result()
+            except Exception as e:
+                logger.warning(f'Error processing page {page_num}: {e}')
+                results[page_num] = None
+
+    # Close the document and hand back the ordered images
+    doc.close()
+
+    return results
+
+if __name__ == '__main__':
+    pdf = fitz.open('/tmp/[MS-DOC].pdf')
+
+    # Split into one-page PDFs so each job is a picklable bytes object
+    pdf_page = [fitz.open() for _ in range(pdf.page_count)]
+    for i in range(pdf.page_count):
+        pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i)
+
+    pdf_page = [v.tobytes() for v in pdf_page]
+    results = parallel_process_pdf_safe(pdf_page, num_workers=16)
+
+    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
+
+    """ benchmark results of multi-threaded processing (fitz page to image)
+    total page nums: 578
+    thread nums,    time cost
+    1               7.351 sec
+    2               6.334 sec
+    4               5.968 sec
+    8               6.728 sec
+    16              8.085 sec
+    """
+
+    """ benchmark results of multi-processor processing (fitz page to image)
+    total page nums: 578
+    processor nums,    time cost
+    1                  17.170 sec
+    2                  10.170 sec
+    4                  7.841 sec
+    8                  7.900 sec
+    16                 7.984 sec
+    """

+ 179 - 22
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,18 +1,19 @@
+import multiprocessing as mp
 import os
 import time
+
+import numpy as np
 import torch
 
 os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle JIT compilation
 os.environ['FLAGS_use_stride_kernel'] = '0'
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS fallback
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks
-# disable paddle's signal handler
-import paddle
-paddle.disable_signal_handler()
+
 
 from loguru import logger
 
-from magic_pdf.model.batch_analyze import BatchAnalyze
 from magic_pdf.model.sub_modules.model_utils import get_vram
 
 try:
@@ -30,8 +31,10 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-from magic_pdf.operators.models import InferenceResult
 
+# NOTE: InferenceResult is now imported lazily inside the functions below.
+
+MIN_BATCH_INFERENCE_SIZE = 100
 
 class ModelSingleton:
     _instance = None
@@ -72,9 +75,7 @@ def custom_model_init(
     formula_enable=None,
     table_enable=None,
 ):
-
     model = None
-
     if model_config.__model_mode__ == 'lite':
         logger.warning(
             'The Lite mode is provided for developers to conduct testing only, and the output quality is '
@@ -132,7 +133,6 @@ def custom_model_init(
 
     return custom_model
 
-
 def doc_analyze(
     dataset: Dataset,
     ocr: bool = False,
@@ -143,13 +143,165 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-) -> InferenceResult:
-
+    one_shot: bool = True,
+):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    images = []
+    page_wh_list = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # TODO: derive this from available GPU memory
+        # split images into parallel_count batches
+        if parallel_count > 1:
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        parallel_count = len(batch_images)  # adjust to the real batch count
+        results = []
+        # Fan the batches out to worker processes; starmap returns results in
+        # submission order, so no re-sorting by sn is needed.
+        with mp.Pool(processes=parallel_count) as pool:
+            mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
+        for _, result in mapped_results:
+            results.extend(result)
+
+    else:
+        _, results = may_batch_image_analyze(
+            images,
+            0,
+            ocr,
+            show_log,
+            lang, layout_model, formula_enable, table_enable)
+
+    model_json = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+        else:
+            result = []
+            page_height = 0
+            page_width = 0
+
+        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
+        model_json.append(page_dict)
+
+    from magic_pdf.operators.models import InferenceResult
+    return InferenceResult(model_json, dataset)
+
+def batch_doc_analyze(
+    datasets: list[Dataset],
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+    one_shot: bool = True,
+):
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    images = []
+    page_wh_list = []
+    for dataset in datasets:
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # TODO: derive this from available GPU memory
+        # split images into parallel_count batches
+        if parallel_count > 1:
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        parallel_count = len(batch_images)  # adjust to the real batch count
+        results = []
+        # Fan the batches out to worker processes; starmap returns results in
+        # submission order, so no re-sorting by sn is needed.
+        with mp.Pool(processes=parallel_count) as pool:
+            mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
+        for _, result in mapped_results:
+            results.extend(result)
+    else:
+        _, results = may_batch_image_analyze(
+            images,
+            0,
+            ocr,
+            show_log,
+            lang, layout_model, formula_enable, table_enable)
+    infer_results = []
+
+    from magic_pdf.operators.models import InferenceResult
+    for index in range(len(datasets)):
+        dataset = datasets[index]
+        model_json = []
+        for i in range(len(dataset)):
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
+        infer_results.append(InferenceResult(model_json, dataset))
+    return infer_results
+
+
+def may_batch_image_analyze(
+        images: list[np.ndarray],
+        idx: int,
+        ocr: bool = False,
+        show_log: bool = False,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None):
+    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
+    # disable paddle's signal handler
+    import paddle
+    paddle.disable_signal_handler()
+    from magic_pdf.model.batch_analyze import BatchAnalyze
 
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
@@ -161,14 +313,14 @@ def doc_analyze(
     device = get_device()
 
     npu_support = False
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         import torch_npu
         if torch_npu.npu.is_available():
             npu_support = True
             torch.npu.set_compile_mode(jit_compile=False)
 
     if torch.cuda.is_available() and device != 'cpu' or npu_support:
-        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
         if gpu_memory is not None and gpu_memory >= 8:
             if gpu_memory >= 20:
                 batch_ratio = 16
@@ -181,12 +333,10 @@ def doc_analyze(
 
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
             batch_analyze = True
-
-    model_json = []
     doc_analyze_start = time.time()
 
     if batch_analyze:
-        # batch analyze
+        """# batch analyze
         images = []
         page_wh_list = []
         for index in range(len(dataset)):
@@ -195,9 +345,10 @@ def doc_analyze(
                 img_dict = page_data.get_image()
                 images.append(img_dict['img'])
                 page_wh_list.append((img_dict['width'], img_dict['height']))
+        """
         batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-        analyze_result = batch_model(images)
-
+        results = batch_model(images)
+        """
         for index in range(len(dataset)):
             if start_page_id <= index <= end_page_id:
                 result = analyze_result.pop(0)
@@ -210,10 +361,10 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
-
+        """
     else:
         # single analyze
-
+        """
         for index in range(len(dataset)):
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
@@ -230,6 +381,13 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
+        """
+        results = []
+        for img_idx, img in enumerate(images):
+            inference_start = time.time()
+            result = custom_model(img)
+            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
+            results.append(result)
 
     gc_start = time.time()
     clean_memory(get_device())
@@ -237,10 +395,9 @@ def doc_analyze(
     logger.info(f'gc time: {gc_time}')
 
     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
     logger.info(
         f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
         f' speed: {doc_analyze_speed} pages/second'
     )
-
-    return InferenceResult(model_json, dataset)
+    return (idx, results)
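
Putting the pieces together, a plausible driver for the new batch path (the file paths and worker counts are illustrative; MINERU_PARALLEL_INFERENCE_COUNT is the environment variable read above):

    import os
    from magic_pdf.data.batch_build_dataset import batch_build_dataset
    from magic_pdf.model.doc_analyze_by_custom_model import batch_doc_analyze

    os.environ['MINERU_PARALLEL_INFERENCE_COUNT'] = '2'  # mp.Pool workers
    datasets = batch_build_dataset(['a.pdf', 'b.pdf'], k=4)
    infer_results = batch_doc_analyze(datasets, lang=None, one_shot=True)
    # one InferenceResult per input dataset, in input order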

+ 28 - 16
magic_pdf/model/sub_modules/model_init.py

@@ -1,18 +1,27 @@
+import os
+
 import torch
 from loguru import logger
 
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
-from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
-from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import \
+    YOLOv11LangDetModel
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
+    DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
+    Layoutlmv3_Predictor
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
 
 try:
-    from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
-    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
-    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
+    from magic_pdf_ascend_plugin.libs.license_verifier import (
+        LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
+        load_license)
+    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import \
+        ModifiedPaddleOCR
+    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import \
+        RapidTableModel
     license_key = load_license()
     logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
                 f' License expired at {license_key["payload"]["date"]["end_date"]}')
@@ -20,21 +29,24 @@ except Exception as e:
     if isinstance(e, ImportError):
         pass
     elif isinstance(e, LicenseFormatError):
-        logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
+        logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
     elif isinstance(e, LicenseSignatureError):
-        logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
+        logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
     elif isinstance(e, LicenseExpiredError):
-        logger.error("Ascend Plugin: License has expired. Please renew your license.")
+        logger.error('Ascend Plugin: License has expired. Please renew your license.')
     elif isinstance(e, FileNotFoundError):
-        logger.error("Ascend Plugin: Not found License file.")
+        logger.error('Ascend Plugin: License file not found.')
     else:
-        logger.error(f"Ascend Plugin: {e}")
+        logger.error(f'Ascend Plugin: {e}')
     from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
     # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
     from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
 
-from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
-from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
+    StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
+    TableMasterPaddleModel
+
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
@@ -55,7 +67,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
 
 
 def mfd_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     mfd_model = YOLOv8MFDModel(weight, device)
     return mfd_model
@@ -72,14 +84,14 @@ def layout_model_init(weight, config_file, device):
 
 
 def doclayout_yolo_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = DocLayoutYOLOModel(weight, device)
     return model
 
 
 def langdetect_model_init(langdetect_model_weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = YOLOv11LangDetModel(langdetect_model_weight, device)
     return model

+ 1 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -5,6 +5,7 @@ import cv2
 import numpy as np
 import torch
 
+
 from paddleocr import PaddleOCR
 from ppocr.utils.logging import get_logger
 from ppocr.utils.utility import alpha_to_color, binarize_img

+ 1 - 0
magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py

@@ -2,6 +2,7 @@ import os
 
 import cv2
 import numpy as np
+from paddleocr import PaddleOCR
 from ppstructure.table.predict_table import TableSystem
 from ppstructure.utility import init_args
 from PIL import Image

+ 19 - 10
magic_pdf/tools/cli.py

@@ -1,15 +1,18 @@
 import os
 import shutil
 import tempfile
+from pathlib import Path
+
 import click
 import fitz
 from loguru import logger
-from pathlib import Path
 
 import magic_pdf.model as model_config
+from magic_pdf.data.batch_build_dataset import batch_build_dataset
 from magic_pdf.data.data_reader_writer import FileBasedDataReader
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.version import __version__
-from magic_pdf.tools.common import do_parse, parse_pdf_methods
+from magic_pdf.tools.common import batch_do_parse, do_parse, parse_pdf_methods
 from magic_pdf.utils.office_to_pdf import convert_file_to_pdf
 
 pdf_suffixes = ['.pdf']
@@ -94,30 +97,33 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     def read_fn(path: Path):
         if path.suffix in ms_office_suffixes:
             convert_file_to_pdf(str(path), temp_dir)
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
         elif path.suffix in image_suffixes:
             with open(str(path), 'rb') as f:
                 bits = f.read()
             pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
             with open(fn, 'wb') as f:
                 f.write(pdf_bytes)
         elif path.suffix in pdf_suffixes:
             fn = str(path)
         else:
-            raise Exception(f"Unknown file suffix: {path.suffix}")
-        
+            raise Exception(f'Unknown file suffix: {path.suffix}')
+
         disk_rw = FileBasedDataReader(os.path.dirname(fn))
         return disk_rw.read(os.path.basename(fn))
 
-    def parse_doc(doc_path: Path):
+    def parse_doc(doc_path: Path, dataset: Dataset | None = None):
         try:
             file_name = str(Path(doc_path).stem)
-            pdf_data = read_fn(doc_path)
+            if dataset is None:
+                pdf_data_or_dataset = read_fn(doc_path)
+            else:
+                pdf_data_or_dataset = dataset
             do_parse(
                 output_dir,
                 file_name,
-                pdf_data,
+                pdf_data_or_dataset,
                 [],
                 method,
                 debug_able,
@@ -130,9 +136,12 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
             logger.exception(e)
 
     if os.path.isdir(path):
+        doc_paths = []
         for doc_path in Path(path).glob('*'):
             if doc_path.suffix in pdf_suffixes + image_suffixes + ms_office_suffixes:
-                parse_doc(doc_path)
+                doc_paths.append(doc_path)
+        datasets = batch_build_dataset(doc_paths, 4, lang)  # 4 rendering processes
+        batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths],
+                       datasets, method, debug_able, lang=lang)
     else:
         parse_doc(Path(path))
 

+ 88 - 11
magic_pdf/tools/common.py

@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset, PymuDocDataset
 from magic_pdf.libs.draw_bbox import draw_char_bbox
-from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.operators.models import InferenceResult
+from magic_pdf.model.doc_analyze_by_custom_model import (batch_doc_analyze,
+                                                         doc_analyze)
 
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
@@ -67,10 +67,10 @@ def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_i
     return output_bytes
 
 
-def do_parse(
+def _do_parse(
     output_dir,
     pdf_file_name,
-    pdf_bytes,
+    pdf_bytes_or_dataset,
     model_list,
     parse_method,
     debug_able,
@@ -92,16 +92,21 @@ def do_parse(
     formula_enable=None,
     table_enable=None,
 ):
+    from magic_pdf.operators.models import InferenceResult
     if debug_able:
         logger.warning('debug mode is on')
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
         # f_draw_char_bbox = True
 
-    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
-        pdf_bytes, start_page_id, end_page_id
-    )
-
+    if isinstance(pdf_bytes_or_dataset, bytes):
+        pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+            pdf_bytes_or_dataset, start_page_id, end_page_id
+        )
+        ds = PymuDocDataset(pdf_bytes, lang=lang)
+    else:
+        ds = pdf_bytes_or_dataset
+    pdf_bytes = ds._raw_data
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
 
     image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
@@ -109,8 +114,6 @@ def do_parse(
     )
     image_dir = str(os.path.basename(local_image_dir))
 
-    ds = PymuDocDataset(pdf_bytes, lang=lang)
-
     if len(model_list) == 0:
         if model_config.__use_inside_model__:
             if parse_method == 'auto':
@@ -241,5 +244,79 @@ def do_parse(
 
     logger.info(f'local output dir is {local_md_dir}')
 
+def do_parse(
+    output_dir,
+    pdf_file_name,
+    pdf_bytes_or_dataset,
+    model_list,
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    parallel_count = 1
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    if parallel_count > 1:
+        if isinstance(pdf_bytes_or_dataset, bytes):
+            pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+                pdf_bytes_or_dataset, start_page_id, end_page_id
+            )
+            ds = PymuDocDataset(pdf_bytes, lang=lang)
+        else:
+            ds = pdf_bytes_or_dataset
+        batch_do_parse(
+            output_dir, [pdf_file_name], [ds], parse_method, debug_able,
+            f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
+            f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
+            f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
+            f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
+            f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+            f_draw_char_bbox=f_draw_char_bbox)
+    else:
+        _do_parse(
+            output_dir, pdf_file_name, pdf_bytes_or_dataset, model_list, parse_method,
+            debug_able, start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+            layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable,
+            f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
+            f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
+            f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
+            f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
+            f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+            f_draw_char_bbox=f_draw_char_bbox)
+
+
+def batch_do_parse(
+    output_dir,
+    pdf_file_names: list[str],
+    pdf_bytes_or_datasets: list[bytes | Dataset],
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    dss = []
+    for v in pdf_bytes_or_datasets:
+        if isinstance(v, bytes):
+            dss.append(PymuDocDataset(v, lang=lang))
+        else:
+            dss.append(v)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
+    for idx, infer_result in enumerate(infer_results):
+        _do_parse(
+            output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(),
+            parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox,
+            f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md,
+            f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json,
+            f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list,
+            f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox,
+            f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
+
 
 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])
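
At the tools level the same flow looks like this; a minimal sketch with illustrative file names and output directory:

    from magic_pdf.data.batch_build_dataset import batch_build_dataset
    from magic_pdf.tools.common import batch_do_parse

    paths = ['a.pdf', 'b.pdf']
    datasets = batch_build_dataset(paths, 4, lang=None)
    batch_do_parse('output', ['a', 'b'], datasets, 'auto', debug_able=False)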