
feat: add parallel evaluation

icecraft 8 months ago
parent commit 3a2f86a1b6

+ 161 - 0
magic_pdf/data/batch_build_dataset.py

@@ -0,0 +1,161 @@
+import concurrent.futures
+import fitz  # PyMuPDF
+from magic_pdf.data.utils import fitz_doc_to_image
+from magic_pdf.data.dataset import PymuDocDataset
+
+def partition_array_greedy(arr, k):
+    """
+    Partition an array into k parts using a simple greedy approach.
+
+    Parameters:
+    -----------
+    arr : list of tuples
+        The input items as (pdf_path, page_count) tuples; partitions are
+        balanced on the page count (the second element)
+    k : int
+        Number of partitions to create
+
+    Returns:
+    --------
+    partitions : list of lists
+        k lists of indices into arr
+    """
+    # Handle edge cases
+    if k <= 0:
+        raise ValueError("k must be a positive integer")
+    if k > len(arr):
+        k = len(arr)  # Adjust k if it's too large
+    if k == 1:
+        return [list(range(len(arr)))]
+    if k == len(arr):
+        return [[i] for i in range(len(arr))]
+    
+    # Sort indices by page count, descending
+    sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
+    
+    # Initialize k empty partitions
+    partitions = [[] for _ in range(k)]
+    partition_sums = [0] * k
+    
+    # Assign each element to the partition with the smallest current sum
+    for idx in sorted_indices:
+        # Find the partition with the smallest sum
+        min_sum_idx = partition_sums.index(min(partition_sums))
+        
+        # Add the element to this partition
+        partitions[min_sum_idx].append(idx)  # Store the original index
+        partition_sums[min_sum_idx] += arr[idx][1]
+    
+    return partitions
+
+
+def process_pdf_batch(pdf_jobs, idx):
+    """
+    Render every page of a batch of PDFs to images (sequentially, inside one
+    worker process).
+
+    Parameters:
+    -----------
+    pdf_jobs : list of tuples
+        List of (pdf_path, page_count) tuples
+    idx : int
+        Index of this partition, echoed back so the caller can reorder results
+
+    Returns:
+    --------
+    (idx, images) : tuple
+        images holds one list of page images per input PDF
+    """
+    images = []
+    
+    for pdf_path, _ in pdf_jobs:
+        doc = fitz.open(pdf_path)
+        tmp = []
+        for page_num in range(len(doc)):
+            page = doc[page_num]
+            tmp.append(fitz_doc_to_image(page))
+        doc.close()
+        images.append(tmp)
+    return (idx, images)
+
+def batch_build_dataset(pdf_paths, k, lang=None):
+    """
+    Build a PymuDocDataset per PDF, rendering page images in parallel by
+    partitioning the PDFs into k balanced parts (one worker process per part).
+
+    Parameters:
+    -----------
+    pdf_paths : list
+        List of paths to PDF files
+    k : int
+        Number of partitions / worker processes
+    lang : str or None
+        Optional language hint forwarded to PymuDocDataset
+
+    Returns:
+    --------
+    results : list
+        One PymuDocDataset per successfully opened PDF, with images attached
+    """
+    # Get page counts for each PDF
+    pdf_info = []
+    total_pages = 0
+    
+    for pdf_path in pdf_paths:
+        try:
+            doc = fitz.open(pdf_path)
+            num_pages = len(doc)
+            pdf_info.append((pdf_path, num_pages))
+            total_pages += num_pages
+            doc.close()
+        except Exception as e:
+            # Failed files are skipped; everything below indexes by position in pdf_info.
+            print(f"Error opening {pdf_path}: {e}")
+    
+    # Partition the PDFs by page count so each worker gets a balanced load
+    partitions = partition_array_greedy(pdf_info, k)
+
+    for i, partition in enumerate(partitions):
+        print(f"Partition {i+1}: {len(partition)} PDFs")
+    
+    # Process each partition in parallel
+    all_images_h = {}
+    
+    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
+        # Submit one task per partition
+        futures = []
+        for sn, partition in enumerate(partitions):
+            # Get the jobs for this partition
+            partition_jobs = [pdf_info[idx] for idx in partition]
+            
+            # Submit the task
+            future = executor.submit(
+                process_pdf_batch,
+                partition_jobs,
+                sn
+            )
+            futures.append(future)
+        # Process results as they complete
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                idx, images = future.result()
+                print(f"Partition {idx+1} completed: processed {len(images)} documents")
+                all_images_h[idx] = images
+            except Exception as e:
+                print(f"Error processing partition: {e}")
+    results = [None] * len(pdf_info)  # one slot per successfully opened PDF
+    for i in range(len(partitions)):
+        partition = partitions[i]
+        for j in range(len(partition)):
+            with open(pdf_info[partition[j]][0], "rb") as f:
+                pdf_bytes = f.read()
+            dataset = PymuDocDataset(pdf_bytes, lang=lang)
+            dataset.set_images(all_images_h[i][j])
+            results[partition[j]] = dataset
+    return results

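A minimal usage sketch of the two helpers above (file names and page counts are illustrative):

    from magic_pdf.data.batch_build_dataset import batch_build_dataset, partition_array_greedy

    # Greedy balancing: items are taken heaviest-first and always dropped into the
    # currently lightest partition, so page totals stay close across partitions.
    pdf_info = [('a.pdf', 120), ('b.pdf', 30), ('c.pdf', 90), ('d.pdf', 10)]
    print(partition_array_greedy(pdf_info, 2))  # [[0, 3], [2, 1]] -> 130 vs 120 pages

    # End to end: render page images in 2 worker processes, get one dataset per PDF.
    datasets = batch_build_dataset(['a.pdf', 'b.pdf', 'c.pdf', 'd.pdf'], 2)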
+ 20 - 1
magic_pdf/data/dataset.py

@@ -154,6 +154,7 @@ class PymuDocDataset(Dataset):
         else:
             self._lang = lang
             logger.info(f"lang: {lang}")
+
     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)
@@ -224,6 +225,9 @@ class PymuDocDataset(Dataset):
         """
         return PymuDocDataset(self._raw_data)
 
+    def set_images(self, images):
+        """Attach pre-rendered page images, one per page, in page order."""
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])
 
 class ImageDataset(Dataset):
     def __init__(self, bits: bytes):
@@ -304,12 +308,17 @@ class ImageDataset(Dataset):
         """clone this dataset
         """
         return ImageDataset(self._raw_data)
+    
+    def set_images(self, images):
+        """Attach pre-rendered page images, one per page, in page order."""
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])
 
 class Doc(PageableData):
     """Initialized with pymudoc object."""
 
     def __init__(self, doc: fitz.Page):
         self._doc = doc
+        self._img = None
 
     def get_image(self):
         """Return the image info.
@@ -321,7 +330,17 @@ class Doc(PageableData):
                 height: int
             }
         """
-        return fitz_doc_to_image(self._doc)
+        if self._img is None:
+            self._img = fitz_doc_to_image(self._doc)
+        return self._img
+
+    def set_image(self, img):
+        """Cache the page image; an already-cached image is kept.
+
+        Args:
+            img (np.ndarray): the image
+        """
+        if self._img is None:
+            self._img = img
 
     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.

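A short sketch of the caching contract these methods add (the path is illustrative): set_images() attaches pre-rendered arrays, and get_image() then returns them instead of re-rendering.

    import fitz
    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.data.utils import fitz_doc_to_image

    with open('a.pdf', 'rb') as f:
        ds = PymuDocDataset(f.read())

    # Pre-render once (e.g. in a worker process) and attach the images ...
    doc = fitz.open('a.pdf')
    ds.set_images([fitz_doc_to_image(doc[i]) for i in range(len(doc))])
    doc.close()

    # ... so later calls hit the cache instead of rasterizing again.
    img_dict = ds.get_page(0).get_image()  # {'img': ..., 'width': ..., 'height': ...}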
+ 105 - 0
magic_pdf/data/utils.py

@@ -1,9 +1,12 @@
 
+import multiprocessing as mp
+import threading
 import fitz
 import numpy as np
 from loguru import logger
 
 from magic_pdf.utils.annotations import ImportPIL
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 
 
 @ImportPIL
@@ -65,3 +68,105 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
 
             images.append(img_dict)
     return images
+
+
+def convert_page(bytes_page):
+    """Render a single-page PDF, given as raw bytes, to an image dict."""
+    pdfs = fitz.open('pdf', bytes_page)
+    page = pdfs[0]
+    return fitz_doc_to_image(page)
+
+def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
+    """Process PDF pages in parallel with a serialization-safe approach.
+
+    Each page arrives as standalone PDF bytes, so jobs can be pickled and
+    shipped to worker processes.
+    """
+    if num_workers is None:
+        num_workers = mp.cpu_count()
+
+    # Process the extracted page data in parallel
+    with ProcessPoolExecutor(max_workers=num_workers) as executor:
+        # Process the page data
+        results = list(
+            executor.map(convert_page, pages)
+        )
+    
+    return results
+
+
+def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
+    """
+    Process all pages of a PDF using multiple threads
+    
+    Parameters:
+    -----------
+    pdf_path : str
+        Path to the PDF file
+    num_threads : int
+        Number of threads to use
+    **kwargs :
+        Additional arguments for fitz_doc_to_image
+        
+    Returns:
+    --------
+    images : list
+        List of processed images, in page order
+    """
+    # Open the PDF
+    doc = fitz.open(pdf_path)
+    num_pages = len(doc)
+    
+    # Create a list to store results in the correct order
+    results = [None] * num_pages
+    
+    # Create a thread pool
+    with ThreadPoolExecutor(max_workers=num_threads) as executor:
+        # Submit all tasks
+        futures = {}
+        for page_num in range(num_pages):
+            page = doc[page_num]
+            future = executor.submit(fitz_doc_to_image, page, **kwargs)
+            futures[future] = page_num
+        # Process results as they complete
+        for future in as_completed(futures):
+            page_num = futures[future]
+            try:
+                results[page_num] = future.result()
+            except Exception as e:
+                print(f"Error processing page {page_num}: {e}")
+                results[page_num] = None
+
+    # Close the document and return the page images
+    doc.close()
+    return results
+
+if __name__ == "__main__":
+    pdf = fitz.open('/tmp/[MS-DOC].pdf')
+    
+    
+    # Split the document into single-page PDFs and serialize each to bytes
+    pdf_page = [fitz.open() for _ in range(pdf.page_count)]
+    for i in range(pdf.page_count):
+        pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i)
+
+    pdf_page = [v.tobytes() for v in pdf_page]
+    results = parallel_process_pdf_safe(pdf_page, num_workers=16)
+    
+    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
+
+    """ benchmark results of multi-threaded processing (fitz page to image)
+    total page count: 578
+    threads    time cost
+    1          7.351 sec
+    2          6.334 sec
+    4          5.968 sec
+    8          6.728 sec
+    16         8.085 sec
+    """
+
+    """ benchmark results of multi-process processing (fitz page to image)
+    total page count: 578
+    processes    time cost
+    1            17.170 sec
+    2            10.170 sec
+    4            7.841 sec
+    8            7.900 sec
+    16           7.984 sec
+    """

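The __main__ block above doubles as the usage pattern: fitz.Page objects cannot be pickled, so the process-safe path ships each page as standalone single-page PDF bytes. A hedged helper capturing that pattern (the helper name is ours, not part of this commit):

    import fitz
    from magic_pdf.data.utils import parallel_process_pdf_safe

    def pdf_to_single_page_bytes(path):
        # Re-wrap every page as its own one-page PDF so it can cross a process boundary.
        src = fitz.open(path)
        pages = []
        for i in range(src.page_count):
            one = fitz.open()
            one.insert_pdf(src, from_page=i, to_page=i)
            pages.append(one.tobytes())
        src.close()
        return pages

    images = parallel_process_pdf_safe(pdf_to_single_page_bytes('a.pdf'), num_workers=8)

Per the benchmark comments above, threads bottom out around 4 workers (~6 s for 578 pages) while processes pay a higher fixed cost but level off at roughly the same throughput from 4 workers up.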
+ 175 - 20
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,18 +1,17 @@
 import os
 import time
 import torch
-
+import numpy as np
+import multiprocessing as mp
+import concurrent.futures as fut
os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
-# disable paddle's signal handling
-import paddle
-paddle.disable_signal_handler()
+
 
 from loguru import logger
 
-from magic_pdf.model.batch_analyze import BatchAnalyze
 from magic_pdf.model.sub_modules.model_utils import get_vram
 
 try:
@@ -30,8 +29,9 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
 
+MIN_BATCH_INFERENCE_SIZE = 100
 
 class ModelSingleton:
     _instance = None
@@ -72,9 +72,7 @@ def custom_model_init(
     formula_enable=None,
     table_enable=None,
 ):
-
     model = None
-
     if model_config.__model_mode__ == 'lite':
         logger.warning(
             'The Lite mode is provided for developers to conduct testing only, and the output quality is '
@@ -132,7 +130,6 @@ def custom_model_init(
 
     return custom_model
 
-
 def doc_analyze(
     dataset: Dataset,
     ocr: bool = False,
@@ -143,14 +140,166 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-) -> InferenceResult:
-
+    one_shot: bool = True,
+):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    images = []
+    page_wh_list = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # TODO: derive this from available GPU memory
+        # split images into parallel_count batches
+        if parallel_count > 1:
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        parallel_count = len(batch_images)  # adjust to the actual number of batches
+        # Fan the batches out to worker processes; starmap preserves submission order.
+        results = []
+        with mp.Pool(processes=parallel_count) as pool:
+            mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
+        for _, result in mapped_results:
+            results.extend(result)
 
+    else:
+        _, results = may_batch_image_analyze(
+            images,
+            0,
+            ocr, 
+            show_log, 
+            lang, layout_model, formula_enable, table_enable)
+
+    model_json = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+        else:
+            result = []
+            page_height = 0
+            page_width = 0
+
+        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
+        model_json.append(page_dict)
+
+    from magic_pdf.operators.models import InferenceResult
+    return InferenceResult(model_json, dataset)
+
+def batch_doc_analyze(
+    datasets: list[Dataset],
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+    one_shot: bool = True,
+):
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    images = []
+    page_wh_list = []
+    for dataset in datasets:
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+    
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # TODO: derive this from available GPU memory
+        # split images into parallel_count batches
+        if parallel_count > 1:
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        parallel_count = len(batch_images)  # adjust to the actual number of batches
+        # Fan the batches out to worker processes; starmap preserves submission order.
+        results = []
+        with mp.Pool(processes=parallel_count) as pool:
+            mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
+        for _, result in mapped_results:
+            results.extend(result)
+    else:
+        _, results = may_batch_image_analyze(
+            images,
+            0,
+            ocr, 
+            show_log, 
+            lang, layout_model, formula_enable, table_enable)
+    infer_results = []
+
+    from magic_pdf.operators.models import InferenceResult
+    for index in range(len(datasets)):
+        dataset = datasets[index]
+        model_json = []
+        for i in range(len(dataset)):
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
+        infer_results.append(InferenceResult(model_json, dataset))
+    return infer_results
+
+
+def may_batch_image_analyze(
+        images: list[np.ndarray], 
+        idx: int,
+        ocr: bool = False, 
+        show_log: bool = False, 
+        lang=None, 
+        layout_model=None, 
+        formula_enable=None, 
+        table_enable=None):
+    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
+    # Disable paddle's signal handling. paddle and BatchAnalyze are imported
+    # here, inside the worker, so each spawned process initializes them itself.
+    import paddle
+    paddle.disable_signal_handler()
+    from magic_pdf.model.batch_analyze import BatchAnalyze
+    
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
@@ -181,12 +330,10 @@ def doc_analyze(
 
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
             batch_analyze = True
-
-    model_json = []
     doc_analyze_start = time.time()
 
     if batch_analyze:
-        # batch analyze
+        """# batch analyze
         images = []
         page_wh_list = []
         for index in range(len(dataset)):
@@ -195,9 +342,10 @@ def doc_analyze(
                 img_dict = page_data.get_image()
                 images.append(img_dict['img'])
                 page_wh_list.append((img_dict['width'], img_dict['height']))
+        """
         batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-        analyze_result = batch_model(images)
-
+        results = batch_model(images)
+        """
         for index in range(len(dataset)):
             if start_page_id <= index <= end_page_id:
                 result = analyze_result.pop(0)
@@ -210,10 +358,10 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
-
+        """
     else:
         # single analyze
-
+        """
         for index in range(len(dataset)):
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
@@ -230,6 +378,13 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
+        """
+        results = []
+        for img_idx, img in enumerate(images):
+            inference_start = time.time()
+            result = custom_model(img)
+            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
+            results.append(result)
 
     gc_start = time.time()
     clean_memory(get_device())
@@ -237,10 +392,10 @@ def doc_analyze(
     logger.info(f'gc time: {gc_time}')
 
     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
     logger.info(
         f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
         f' speed: {doc_analyze_speed} pages/second'
     )
+    return (idx, results)
 
-    return InferenceResult(model_json, dataset)

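A hedged sketch of driving the new parallel path (paths are illustrative; with fewer than MIN_BATCH_INFERENCE_SIZE pages in total, the call falls through to a single in-process batch):

    import os
    from magic_pdf.data.batch_build_dataset import batch_build_dataset
    from magic_pdf.model.doc_analyze_by_custom_model import batch_doc_analyze

    # Number of inference worker processes; the batch branch defaults to 2 when unset.
    os.environ['MINERU_PARALLEL_INFERENCE_COUNT'] = '2'

    datasets = batch_build_dataset(['a.pdf', 'b.pdf'], 2)
    infer_results = batch_doc_analyze(datasets)  # one InferenceResult per dataset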
+ 1 - 0
magic_pdf/model/sub_modules/model_init.py

@@ -1,3 +1,4 @@
+import os 
 import torch
 from loguru import logger
 

+ 14 - 5
magic_pdf/tools/cli.py

@@ -8,10 +8,13 @@ from pathlib import Path
 
 import magic_pdf.model as model_config
 from magic_pdf.data.data_reader_writer import FileBasedDataReader
+from magic_pdf.data.batch_build_dataset import batch_build_dataset
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.version import __version__
-from magic_pdf.tools.common import do_parse, parse_pdf_methods
+from magic_pdf.tools.common import do_parse, parse_pdf_methods, batch_do_parse
 from magic_pdf.utils.office_to_pdf import convert_file_to_pdf
 
+
 pdf_suffixes = ['.pdf']
 ms_office_suffixes = ['.ppt', '.pptx', '.doc', '.docx']
 image_suffixes = ['.png', '.jpeg', '.jpg']
@@ -110,14 +113,17 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
         disk_rw = FileBasedDataReader(os.path.dirname(fn))
         return disk_rw.read(os.path.basename(fn))
 
-    def parse_doc(doc_path: Path):
+    def parse_doc(doc_path: Path, dataset: Dataset | None = None):
         try:
             file_name = str(Path(doc_path).stem)
-            pdf_data = read_fn(doc_path)
+            if dataset is None:
+                pdf_data_or_dataset = read_fn(doc_path)
+            else:
+                pdf_data_or_dataset = dataset
             do_parse(
                 output_dir,
                 file_name,
-                pdf_data,
+                pdf_data_or_dataset,
                 [],
                 method,
                 debug_able,
@@ -130,9 +136,12 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
             logger.exception(e)
 
     if os.path.isdir(path):
+        doc_paths = []
         for doc_path in Path(path).glob('*'):
             if doc_path.suffix in pdf_suffixes + image_suffixes + ms_office_suffixes:
-                parse_doc(doc_path)
+                doc_paths.append(doc_path)
+        datasets = batch_build_dataset(doc_paths, 4, lang)  # page images rendered in 4 processes
+        batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths],
+                       datasets, method, debug_able, lang=lang)
     else:
         parse_doc(Path(path))
 

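For reference, the directory branch above boils down to this hedged sketch (directory names are illustrative):

    from pathlib import Path
    from magic_pdf.data.batch_build_dataset import batch_build_dataset
    from magic_pdf.tools.common import batch_do_parse

    doc_paths = sorted(Path('./pdfs').glob('*.pdf'))
    # Pre-render page images in 4 processes, then parse every document in one batch.
    datasets = batch_build_dataset(doc_paths, 4, None)
    batch_do_parse('./output', [p.stem for p in doc_paths], datasets, 'auto', False)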
+ 88 - 11
magic_pdf/tools/common.py

@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import PymuDocDataset, Dataset
 from magic_pdf.libs.draw_bbox import draw_char_bbox
-from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.operators.models import InferenceResult
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze, batch_doc_analyze
+
 
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
@@ -67,10 +67,10 @@ def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_i
     return output_bytes
 
 
-def do_parse(
+def _do_parse(
     output_dir,
     pdf_file_name,
-    pdf_bytes,
+    pdf_bytes_or_dataset,
     model_list,
     parse_method,
     debug_able,
@@ -92,16 +92,21 @@ def do_parse(
     formula_enable=None,
     table_enable=None,
 ):
+    from magic_pdf.operators.models import InferenceResult
     if debug_able:
         logger.warning('debug mode is on')
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
         # f_draw_char_bbox = True
 
-    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
-        pdf_bytes, start_page_id, end_page_id
-    )
-
+    if isinstance(pdf_bytes_or_dataset, bytes):
+        pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+            pdf_bytes_or_dataset, start_page_id, end_page_id
+        )
+        ds = PymuDocDataset(pdf_bytes, lang=lang)
+    else:
+        ds = pdf_bytes_or_dataset
+    pdf_bytes = ds._raw_data
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
 
     image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
@@ -109,8 +114,6 @@ def do_parse(
     )
     image_dir = str(os.path.basename(local_image_dir))
 
-    ds = PymuDocDataset(pdf_bytes, lang=lang)
-
     if len(model_list) == 0:
         if model_config.__use_inside_model__:
             if parse_method == 'auto':
@@ -241,5 +244,79 @@ def do_parse(
 
     logger.info(f'local output dir is {local_md_dir}')
 
+def do_parse(
+    output_dir,
+    pdf_file_name,
+    pdf_bytes_or_dataset,
+    model_list,
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    parallel_count = 1
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    if parallel_count > 1:
+        if isinstance(pdf_bytes_or_dataset, bytes):
+            pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+                pdf_bytes_or_dataset, start_page_id, end_page_id
+            )
+            ds = PymuDocDataset(pdf_bytes, lang=lang)
+        else:
+            ds = pdf_bytes_or_dataset
+        batch_do_parse(output_dir, [pdf_file_name], [ds], parse_method, debug_able,
+                       f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
+                       f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
+                       f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
+                       f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
+                       f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+                       f_draw_char_bbox=f_draw_char_bbox)
+    else:
+        _do_parse(output_dir, pdf_file_name, pdf_bytes_or_dataset, model_list, parse_method,
+                  debug_able, start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                  layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable,
+                  f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
+                  f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
+                  f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
+                  f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
+                  f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+                  f_draw_char_bbox=f_draw_char_bbox)
+
+
+def batch_do_parse(
+    output_dir,
+    pdf_file_names: list[str],
+    pdf_bytes_or_datasets: list[bytes | Dataset],
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    dss = []
+    for v in pdf_bytes_or_datasets:
+        if isinstance(v, bytes):
+            dss.append(PymuDocDataset(v, lang=lang))
+        else:
+            dss.append(v)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model,
+                                      formula_enable=formula_enable, table_enable=table_enable,
+                                      one_shot=True)
+    for idx, infer_result in enumerate(infer_results):
+        _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(),
+                  parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox,
+                  f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md,
+                  f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json,
+                  f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list,
+                  f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox,
+                  f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
+
 
 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])
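A hedged sketch of the resulting do_parse dispatch (the file name is illustrative): with MINERU_PARALLEL_INFERENCE_COUNT greater than 1 the call is routed through batch_do_parse, otherwise it falls through to the original single-document _do_parse.

    import os
    from magic_pdf.tools.common import do_parse

    os.environ['MINERU_PARALLEL_INFERENCE_COUNT'] = '2'

    with open('a.pdf', 'rb') as f:
        do_parse('./output', 'a', f.read(), [], 'auto', False)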