Forráskód Böngészése

refactor(model): optimize batch processing and inference

- Update batch processing logic for improved efficiency
- Refactor image analysis and inference methods
- Optimize dataset handling and image retrieval
- Improve error handling and logging in batch processes
myhloli 7 hónapja
szülő
commit
d2fc9dabf4

+ 58 - 10
magic_pdf/data/batch_build_dataset.py

@@ -103,17 +103,65 @@ def batch_build_dataset(pdf_paths, k, lang=None):
     all_images : list
         List of all processed images
     """
-    # Get page counts for each PDF
-    pdf_info = []
-    total_pages = 0
 
     results = []
     for pdf_path in pdf_paths:
-        try:
-            with open(pdf_path, 'rb') as f:
-                bits = f.read() 
-            results.append(PymuDocDataset(bits, lang))
-        except Exception as e:
-            print(f'Error opening {pdf_path}: {e}')
-
+        with open(pdf_path, 'rb') as f:
+            pdf_bytes = f.read()
+        dataset = PymuDocDataset(pdf_bytes, lang=lang)
+        results.append(dataset)
     return results
+
+
+    #
+    # # Get page counts for each PDF
+    # pdf_info = []
+    # total_pages = 0
+    #
+    # for pdf_path in pdf_paths:
+    #     try:
+    #         doc = fitz.open(pdf_path)
+    #         num_pages = len(doc)
+    #         pdf_info.append((pdf_path, num_pages))
+    #         total_pages += num_pages
+    #         doc.close()
+    #     except Exception as e:
+    #         print(f'Error opening {pdf_path}: {e}')
+    #
+    # # Partition the jobs based on page count. Each job has 1 page
+    # partitions = partition_array_greedy(pdf_info, k)
+    #
+    # # Process each partition in parallel
+    # all_images_h = {}
+    #
+    # with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
+    #     # Submit one task per partition
+    #     futures = []
+    #     for sn, partition in enumerate(partitions):
+    #         # Get the jobs for this partition
+    #         partition_jobs = [pdf_info[idx] for idx in partition]
+    #
+    #         # Submit the task
+    #         future = executor.submit(
+    #             process_pdf_batch,
+    #             partition_jobs,
+    #             sn
+    #         )
+    #         futures.append(future)
+    #     # Process results as they complete
+    #     for i, future in enumerate(concurrent.futures.as_completed(futures)):
+    #         try:
+    #             idx, images = future.result()
+    #             all_images_h[idx] = images
+    #         except Exception as e:
+    #             print(f'Error processing partition: {e}')
+    # results = [None] * len(pdf_paths)
+    # for i in range(len(partitions)):
+    #     partition = partitions[i]
+    #     for j in range(len(partition)):
+    #         with open(pdf_info[partition[j]][0], 'rb') as f:
+    #             pdf_bytes = f.read()
+    #         dataset = PymuDocDataset(pdf_bytes, lang=lang)
+    #         dataset.set_images(all_images_h[i][j])
+    #         results[partition[j]] = dataset
+    # return results

+ 12 - 3
magic_pdf/data/dataset.py

@@ -150,7 +150,7 @@ class PymuDocDataset(Dataset):
         elif lang == 'auto':
             from magic_pdf.model.sub_modules.language_detection.utils import \
                 auto_detect_lang
-            self._lang = auto_detect_lang(bits)
+            self._lang = auto_detect_lang(self._data_bits)
             logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
@@ -342,8 +342,17 @@ class Doc(PageableData):
                 height: int
             }
         """
-        return fitz_doc_to_image(self._doc)
+        if self._img is None:
+            self._img = fitz_doc_to_image(self._doc)
+        return self._img
 
+    def set_image(self, img):
+        """
+        Args:
+            img (np.ndarray): the image
+        """
+        if self._img is None:
+            self._img = img
 
     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.
@@ -396,4 +405,4 @@ class Doc(PageableData):
             fontsize (int): font size of the text
             color (list[float] | None):  three-element tuple which describes the RGB of the board line, None will use the default font color!
         """
-        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
+        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)

+ 8 - 2
magic_pdf/model/batch_analyze.py

@@ -30,8 +30,14 @@ class BatchAnalyze:
     
         images_layout_res = []
         layout_start_time = time.time()
-        _, fst_ocr, fst_lang = images_with_extra_info[0]
-        self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
+        self.model = self.model_manager.get_model(
+            ocr=True,
+            show_log=self.show_log,
+            lang = None,
+            layout_model = self.layout_model,
+            formula_enable = self.formula_enable,
+            table_enable = self.table_enable,
+        )
 
         images = [image for image, _, _ in images_with_extra_info]
 

+ 33 - 39
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -138,31 +138,27 @@ def doc_analyze(
     )
 
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
-    batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
-    images_with_extra_info = []
-    results = []
     for index in range(len(dataset)):
         if start_page_id <= index <= end_page_id:
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-            if lang is None or lang == 'auto':
-                images_with_extra_info.append((images[index], ocr, dataset._lang))
-            else:
-                images_with_extra_info.append((images[index], ocr, lang))
-                
-            if len(images_with_extra_info) == batch_size:
-                _, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
-                results.extend(result)
-                images_with_extra_info = [] 
-
-    if len(images_with_extra_info) > 0:
-        _, result = may_batch_image_analyze(images_with_extra_info, 0, ocr, show_log, layout_model, formula_enable, table_enable)
+
+    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
+
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
+    else:
+        batch_images = [images_with_extra_info]
+
+    results = []
+    for batch_image in batch_images:
+        result = may_batch_image_analyze(batch_image, ocr, show_log,layout_model, formula_enable, table_enable)
         results.extend(result)
-        images_with_extra_info = [] 
 
     model_json = []
     for index in range(len(dataset)):
@@ -183,7 +179,7 @@ def doc_analyze(
 
 def batch_doc_analyze(
     datasets: list[Dataset],
-    parse_method: str,
+    parse_method: str = 'auto',
     show_log: bool = False,
     lang=None,
     layout_model=None,
@@ -192,36 +188,35 @@ def batch_doc_analyze(
 ):
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     batch_size = MIN_BATCH_INFERENCE_SIZE
-    images = []
     page_wh_list = []
-    results = []
 
     images_with_extra_info = []
     for dataset in datasets:
-        for index in range(len(dataset)):
-            if lang is None or lang == 'auto':
-                _lang = dataset._lang
-            else:
-                _lang = lang
 
+        ocr = False
+        if parse_method == 'auto':
+            if dataset.classify() == SupportedPdfParseMethod.TXT:
+                ocr = False
+            elif dataset.classify() == SupportedPdfParseMethod.OCR:
+                ocr = True
+        elif parse_method == 'ocr':
+            ocr = True
+        elif parse_method == 'txt':
+            ocr = False
+
+        _lang = dataset._lang
+
+        for index in range(len(dataset)):
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
-            images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-            if parse_method == 'auto':
-                images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
-            else:
-                images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
+            images_with_extra_info.append((img_dict['img'], ocr, _lang))
 
-            if len(images_with_extra_info) == batch_size:
-                _, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
-                results.extend(result)
-                images_with_extra_info = [] 
-
-    if len(images_with_extra_info) > 0:
-        _, result = may_batch_image_analyze(images_with_extra_info, 0, True, show_log, layout_model, formula_enable, table_enable)
+    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
+    results = []
+    for batch_image in batch_images:
+        result = may_batch_image_analyze(batch_image, True, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)
-        images_with_extra_info = [] 
 
     infer_results = []
     from magic_pdf.operators.models import InferenceResult
@@ -240,7 +235,6 @@ def batch_doc_analyze(
 
 def may_batch_image_analyze(
         images_with_extra_info: list[(np.ndarray, bool, str)],
-        idx: int,
         ocr: bool,
         show_log: bool = False,
         layout_model=None,
@@ -298,4 +292,4 @@ def may_batch_image_analyze(
     #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
     #     f' speed: {doc_analyze_speed} pages/second'
     # )
-    return idx, results
+    return results