فهرست منبع

feat: add parallel evalution

icecraft 8 ماه پیش
والد
کامیت
b50f742fd9

+ 33 - 32
magic_pdf/data/batch_build_dataset.py

@@ -1,22 +1,24 @@
-import os
+import concurrent.futures
 import glob
+import os
 import threading
-import concurrent.futures
+
 import fitz
-from magic_pdf.data.utils import fitz_doc_to_image  # PyMuPDF
+
 from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.utils import fitz_doc_to_image  # PyMuPDF
+
 
 def partition_array_greedy(arr, k):
-    """
-    Partition an array into k parts using a simple greedy approach.
-    
+    """Partition an array into k parts using a simple greedy approach.
+
     Parameters:
     -----------
     arr : list
         The input array of integers
     k : int
         Number of partitions to create
-        
+
     Returns:
     --------
     partitions : list of lists
@@ -24,37 +26,36 @@ def partition_array_greedy(arr, k):
     """
     # Handle edge cases
     if k <= 0:
-        raise ValueError("k must be a positive integer")
+        raise ValueError('k must be a positive integer')
     if k > len(arr):
         k = len(arr)  # Adjust k if it's too large
     if k == 1:
         return [list(range(len(arr)))]
     if k == len(arr):
         return [[i] for i in range(len(arr))]
-    
+
     # Sort the array in descending order
     sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
-    
+
     # Initialize k empty partitions
     partitions = [[] for _ in range(k)]
     partition_sums = [0] * k
-    
+
     # Assign each element to the partition with the smallest current sum
     for idx in sorted_indices:
         # Find the partition with the smallest sum
         min_sum_idx = partition_sums.index(min(partition_sums))
-        
+
         # Add the element to this partition
         partitions[min_sum_idx].append(idx)  # Store the original index
         partition_sums[min_sum_idx] += arr[idx][1]
-    
+
     return partitions
 
 
 def process_pdf_batch(pdf_jobs, idx):
-    """
-    Process a batch of PDF pages using multiple threads.
-    
+    """Process a batch of PDF pages using multiple threads.
+
     Parameters:
     -----------
     pdf_jobs : list of tuples
@@ -65,14 +66,14 @@ def process_pdf_batch(pdf_jobs, idx):
         Number of threads to use
     **kwargs :
         Additional arguments for process_pdf_page
-        
+
     Returns:
     --------
     images : list
         List of processed images
     """
     images = []
-    
+
     for pdf_path, _ in pdf_jobs:
         doc = fitz.open(pdf_path)
         tmp = []
@@ -83,9 +84,9 @@ def process_pdf_batch(pdf_jobs, idx):
     return (idx, images)
 
 def batch_build_dataset(pdf_paths, k, lang=None):
-    """
-    Process multiple PDFs by partitioning them into k balanced parts and processing each part in parallel.
-    
+    """Process multiple PDFs by partitioning them into k balanced parts and
+    processing each part in parallel.
+
     Parameters:
     -----------
     pdf_paths : list
@@ -98,7 +99,7 @@ def batch_build_dataset(pdf_paths, k, lang=None):
         Number of threads to use per worker
     **kwargs :
         Additional arguments for process_pdf_page
-        
+
     Returns:
     --------
     all_images : list
@@ -107,7 +108,7 @@ def batch_build_dataset(pdf_paths, k, lang=None):
     # Get page counts for each PDF
     pdf_info = []
     total_pages = 0
-    
+
     for pdf_path in pdf_paths:
         try:
             doc = fitz.open(pdf_path)
@@ -116,24 +117,24 @@ def batch_build_dataset(pdf_paths, k, lang=None):
             total_pages += num_pages
             doc.close()
         except Exception as e:
-            print(f"Error opening {pdf_path}: {e}")
-    
+            print(f'Error opening {pdf_path}: {e}')
+
     # Partition the jobs based on page countEach job has 1 page
     partitions = partition_array_greedy(pdf_info, k)
 
     for i, partition in enumerate(partitions):
-        print(f"Partition {i+1}: {len(partition)} pdfs")
-    
+        print(f'Partition {i+1}: {len(partition)} pdfs')
+
     # Process each partition in parallel
     all_images_h = {}
-    
+
     with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
         # Submit one task per partition
         futures = []
         for sn, partition in enumerate(partitions):
             # Get the jobs for this partition
             partition_jobs = [pdf_info[idx] for idx in partition]
-            
+
             # Submit the task
             future = executor.submit(
                 process_pdf_batch,
@@ -145,15 +146,15 @@ def batch_build_dataset(pdf_paths, k, lang=None):
         for i, future in enumerate(concurrent.futures.as_completed(futures)):
             try:
                 idx, images = future.result()
-                print(f"Partition {i+1} completed: processed {len(images)} images")
+                print(f'Partition {i+1} completed: processed {len(images)} images')
                 all_images_h[idx] = images
             except Exception as e:
-                print(f"Error processing partition: {e}")
+                print(f'Error processing partition: {e}')
     results = [None] * len(pdf_paths)
     for i in range(len(partitions)):
         partition = partitions[i]
         for j in range(len(partition)):
-            with open(pdf_info[partition[j]][0], "rb") as f:
+            with open(pdf_info[partition[j]][0], 'rb') as f:
                 pdf_bytes = f.read()
             dataset = PymuDocDataset(pdf_bytes, lang=lang)
             dataset.set_images(all_images_h[i][j])

+ 21 - 23
magic_pdf/data/dataset.py

@@ -97,10 +97,10 @@ class Dataset(ABC):
 
     @abstractmethod
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
         pass
 
@@ -119,7 +119,7 @@ class Dataset(ABC):
 
     @abstractmethod
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -128,8 +128,7 @@ class Dataset(ABC):
 
     @abstractmethod
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         pass
 
 
@@ -148,12 +147,13 @@ class PymuDocDataset(Dataset):
         if lang == '':
             self._lang = None
         elif lang == 'auto':
-            from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
+            from magic_pdf.model.sub_modules.language_detection.utils import \
+                auto_detect_lang
             self._lang = auto_detect_lang(bits)
-            logger.info(f"lang: {lang}, detect_lang: {self._lang}")
+            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
-            logger.info(f"lang: {lang}")
+            logger.info(f'lang: {lang}')
 
     def __len__(self) -> int:
         """The page number of the pdf."""
@@ -187,12 +187,12 @@ class PymuDocDataset(Dataset):
         return self._records[page_id]
 
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
-        
+
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
             os.makedirs(dir_name, exist_ok=True)
@@ -213,7 +213,7 @@ class PymuDocDataset(Dataset):
         return proc(self, *args, **kwargs)
 
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -221,8 +221,7 @@ class PymuDocDataset(Dataset):
         return classify(self._data_bits)
 
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return PymuDocDataset(self._raw_data)
 
     def set_images(self, images):
@@ -274,10 +273,10 @@ class ImageDataset(Dataset):
         return self._records[page_id]
 
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
@@ -297,7 +296,7 @@ class ImageDataset(Dataset):
         return proc(self, *args, **kwargs)
 
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -305,10 +304,9 @@ class ImageDataset(Dataset):
         return SupportedPdfParseMethod.OCR
 
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return ImageDataset(self._raw_data)
-    
+
     def set_images(self, images):
         for i in range(len(self._records)):
             self._records[i].set_image(images[i])

+ 26 - 28
magic_pdf/data/utils.py

@@ -1,12 +1,14 @@
 
 import multiprocessing as mp
 import threading
+from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
+                                as_completed)
+
 import fitz
 import numpy as np
 from loguru import logger
 
 from magic_pdf.utils.annotations import ImportPIL
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 
 
 @ImportPIL
@@ -69,17 +71,17 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
             images.append(img_dict)
     return images
 
-    
+
 def convert_page(bytes_page):
     pdfs = fitz.open('pdf', bytes_page)
     page = pdfs[0]
     return fitz_doc_to_image(page)
-    
+
 def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
-    """Process PDF pages in parallel with serialization-safe approach"""
+    """Process PDF pages in parallel with serialization-safe approach."""
     if num_workers is None:
         num_workers = mp.cpu_count()
-    
+
 
     # Process the extracted page data in parallel
     with ProcessPoolExecutor(max_workers=num_workers) as executor:
@@ -87,14 +89,13 @@ def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
         results = list(
             executor.map(convert_page, pages)
         )
-    
+
     return results
 
 
 def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
-    """
-    Process all pages of a PDF using multiple threads
-    
+    """Process all pages of a PDF using multiple threads.
+
     Parameters:
     -----------
     pdf_path : str
@@ -103,7 +104,7 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
         Number of threads to use
     **kwargs :
         Additional arguments for fitz_doc_to_image
-        
+
     Returns:
     --------
     images : list
@@ -112,10 +113,10 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
     # Open the PDF
     doc = fitz.open(pdf_path)
     num_pages = len(doc)
-    
+
     # Create a list to store results in the correct order
     results = [None] * num_pages
-    
+
     # Create a thread pool
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
         # Submit all tasks
@@ -130,27 +131,27 @@ def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
             try:
                 results[page_num] = future.result()
             except Exception as e:
-                print(f"Error processing page {page_num}: {e}")
+                print(f'Error processing page {page_num}: {e}')
                 results[page_num] = None
-    
+
     # Close the document
     doc.close()
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     pdf = fitz.open('/tmp/[MS-DOC].pdf')
-    
-    
+
+
     pdf_page = [fitz.open() for i in range(pdf.page_count)]
     [pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i) for i in range(pdf.page_count)]
- 
+
     pdf_page = [v.tobytes() for v in pdf_page]
     results = parallel_process_pdf_safe(pdf_page, num_workers=16)
-    
+
     # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
 
     """ benchmark results of multi-threaded processing (fitz page to image)
-    total page nums: 578 
-    thread nums,    time cost 
+    total page nums: 578
+    thread nums,    time cost
     1               7.351 sec
     2               6.334 sec
     4               5.968 sec
@@ -159,14 +160,11 @@ if __name__ == "__main__":
     """
 
     """ benchmark results of multi-processor processing (fitz page to image)
-    total page nums: 578 
-    processor nums,    time cost 
+    total page nums: 578
+    processor nums,    time cost
     1                  17.170 sec
-    2                  10.170 sec 
-    4                  7.841 sec 
+    2                  10.170 sec
+    4                  7.841 sec
     8                  7.900 sec
     16                 7.984 sec
     """
-
-
-

+ 22 - 20
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,9 +1,11 @@
+import concurrent.futures as fut
+import multiprocessing as mp
 import os
 import time
-import torch
+
 import numpy as np
-import multiprocessing as mp
-import concurrent.futures as fut
+import torch
+
 os.environ['FLAGS_npu_jit_compile'] = '0'  # 关闭paddle的jit编译
 os.environ['FLAGS_use_stride_kernel'] = '0'
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
@@ -29,6 +31,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
+
 # from magic_pdf.operators.models import InferenceResult
 
 MIN_BATCH_INFERENCE_SIZE = 100
@@ -170,7 +173,7 @@ def doc_analyze(
         else:
             batch_images = [images]
         results = []
-        parallel_count = len(batch_images) # adjust to real parallel count 
+        parallel_count = len(batch_images) # adjust to real parallel count
         # using concurrent.futures to analyze
         """
         with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
@@ -192,8 +195,8 @@ def doc_analyze(
         _, results = may_batch_image_analyze(
             images,
             0,
-            ocr, 
-            show_log, 
+            ocr,
+            show_log,
             lang, layout_model, formula_enable, table_enable)
 
     model_json = []
@@ -234,7 +237,7 @@ def batch_doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-    
+
     if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
         if parallel_count is None:
             parallel_count = 2 # should check the gpu memory firstly !
@@ -245,7 +248,7 @@ def batch_doc_analyze(
         else:
             batch_images = [images]
         results = []
-        parallel_count = len(batch_images) # adjust to real parallel count 
+        parallel_count = len(batch_images) # adjust to real parallel count
         # using concurrent.futures to analyze
         """
         with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
@@ -266,8 +269,8 @@ def batch_doc_analyze(
         _, results = may_batch_image_analyze(
             images,
             0,
-            ocr, 
-            show_log, 
+            ocr,
+            show_log,
             lang, layout_model, formula_enable, table_enable)
     infer_results = []
 
@@ -286,20 +289,20 @@ def batch_doc_analyze(
 
 
 def may_batch_image_analyze(
-        images: list[np.ndarray], 
+        images: list[np.ndarray],
         idx: int,
-        ocr: bool = False, 
-        show_log: bool = False, 
-        lang=None, 
-        layout_model=None, 
-        formula_enable=None, 
+        ocr: bool = False,
+        show_log: bool = False,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
         table_enable=None):
     # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
     # 关闭paddle的信号处理
     import paddle
     paddle.disable_signal_handler()
     from magic_pdf.model.batch_analyze import BatchAnalyze
-    
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
@@ -310,14 +313,14 @@ def may_batch_image_analyze(
     device = get_device()
 
     npu_support = False
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         import torch_npu
         if torch_npu.npu.is_available():
             npu_support = True
             torch.npu.set_compile_mode(jit_compile=False)
 
     if torch.cuda.is_available() and device != 'cpu' or npu_support:
-        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
         if gpu_memory is not None and gpu_memory >= 8:
             if gpu_memory >= 20:
                 batch_ratio = 16
@@ -398,4 +401,3 @@ def may_batch_image_analyze(
         f' speed: {doc_analyze_speed} pages/second'
     )
     return (idx, results)
-

+ 28 - 17
magic_pdf/model/sub_modules/model_init.py

@@ -1,19 +1,27 @@
-import os 
+import os
+
 import torch
 from loguru import logger
 
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
-from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
-from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import \
+    YOLOv11LangDetModel
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
+    DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
+    Layoutlmv3_Predictor
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
 
 try:
-    from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
-    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
-    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
+    from magic_pdf_ascend_plugin.libs.license_verifier import (
+        LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
+        load_license)
+    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import \
+        ModifiedPaddleOCR
+    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import \
+        RapidTableModel
     license_key = load_license()
     logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
                 f' License expired at {license_key["payload"]["date"]["end_date"]}')
@@ -21,21 +29,24 @@ except Exception as e:
     if isinstance(e, ImportError):
         pass
     elif isinstance(e, LicenseFormatError):
-        logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
+        logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
     elif isinstance(e, LicenseSignatureError):
-        logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
+        logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
     elif isinstance(e, LicenseExpiredError):
-        logger.error("Ascend Plugin: License has expired. Please renew your license.")
+        logger.error('Ascend Plugin: License has expired. Please renew your license.')
     elif isinstance(e, FileNotFoundError):
-        logger.error("Ascend Plugin: Not found License file.")
+        logger.error('Ascend Plugin: Not found License file.')
     else:
-        logger.error(f"Ascend Plugin: {e}")
+        logger.error(f'Ascend Plugin: {e}')
     from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
     # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
     from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
 
-from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
-from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
+    StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
+    TableMasterPaddleModel
+
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
@@ -56,7 +67,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
 
 
 def mfd_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     mfd_model = YOLOv8MFDModel(weight, device)
     return mfd_model
@@ -73,14 +84,14 @@ def layout_model_init(weight, config_file, device):
 
 
 def doclayout_yolo_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = DocLayoutYOLOModel(weight, device)
     return model
 
 
 def langdetect_model_init(langdetect_model_weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = YOLOv11LangDetModel(langdetect_model_weight, device)
     return model

+ 8 - 8
magic_pdf/tools/cli.py

@@ -1,20 +1,20 @@
 import os
 import shutil
 import tempfile
+from pathlib import Path
+
 import click
 import fitz
 from loguru import logger
-from pathlib import Path
 
 import magic_pdf.model as model_config
-from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.data.batch_build_dataset import batch_build_dataset
+from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.version import __version__
-from magic_pdf.tools.common import do_parse, parse_pdf_methods, batch_do_parse
+from magic_pdf.tools.common import batch_do_parse, do_parse, parse_pdf_methods
 from magic_pdf.utils.office_to_pdf import convert_file_to_pdf
 
-
 pdf_suffixes = ['.pdf']
 ms_office_suffixes = ['.ppt', '.pptx', '.doc', '.docx']
 image_suffixes = ['.png', '.jpeg', '.jpg']
@@ -97,19 +97,19 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     def read_fn(path: Path):
         if path.suffix in ms_office_suffixes:
             convert_file_to_pdf(str(path), temp_dir)
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
         elif path.suffix in image_suffixes:
             with open(str(path), 'rb') as f:
                 bits = f.read()
             pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
             with open(fn, 'wb') as f:
                 f.write(pdf_bytes)
         elif path.suffix in pdf_suffixes:
             fn = str(path)
         else:
-            raise Exception(f"Unknown file suffix: {path.suffix}")
-        
+            raise Exception(f'Unknown file suffix: {path.suffix}')
+
         disk_rw = FileBasedDataReader(os.path.dirname(fn))
         return disk_rw.read(os.path.basename(fn))
 

+ 3 - 3
magic_pdf/tools/common.py

@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
-from magic_pdf.data.dataset import PymuDocDataset, Dataset
+from magic_pdf.data.dataset import Dataset, PymuDocDataset
 from magic_pdf.libs.draw_bbox import draw_char_bbox
-from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze, batch_doc_analyze
-
+from magic_pdf.model.doc_analyze_by_custom_model import (batch_doc_analyze,
+                                                         doc_analyze)
 
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter