Parcourir la source

Merge pull request #1821 from myhloli/dev

perf(mfr): improve Math Formula Recognition by sorting images by area
Xiaomeng Zhao il y a 8 mois
Parent
commit
a01bd7eda7

+ 54 - 0
magic_pdf/libs/performance_stats.py

@@ -0,0 +1,54 @@
+import time
+import functools
+from collections import defaultdict
+from typing import Dict, List
+
+
+class PerformanceStats:
+    """性能统计类,用于收集和展示方法执行时间"""
+
+    _stats: Dict[str, List[float]] = defaultdict(list)
+
+    @classmethod
+    def add_execution_time(cls, func_name: str, execution_time: float):
+        """添加执行时间记录"""
+        cls._stats[func_name].append(execution_time)
+
+    @classmethod
+    def get_stats(cls) -> Dict[str, dict]:
+        """获取统计结果"""
+        results = {}
+        for func_name, times in cls._stats.items():
+            results[func_name] = {
+                'count': len(times),
+                'total_time': sum(times),
+                'avg_time': sum(times) / len(times),
+                'min_time': min(times),
+                'max_time': max(times)
+            }
+        return results
+
+    @classmethod
+    def print_stats(cls):
+        """打印统计结果"""
+        stats = cls.get_stats()
+        print("\n性能统计结果:")
+        print("-" * 80)
+        print(f"{'方法名':<40} {'调用次数':>8} {'总时间(s)':>12} {'平均时间(s)':>12}")
+        print("-" * 80)
+        for func_name, data in stats.items():
+            print(f"{func_name:<40} {data['count']:8d} {data['total_time']:12.6f} {data['avg_time']:12.6f}")
+
+
+def measure_time(func):
+    """测量方法执行时间的装饰器"""
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        execution_time = time.time() - start_time
+        PerformanceStats.add_execution_time(func.__name__, execution_time)
+        return result
+
+    return wrapper

+ 1 - 7
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -170,13 +170,7 @@ def doc_analyze(
         gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
         if gpu_memory is not None and gpu_memory >= 8:
 
-            if gpu_memory >= 40:
-                batch_ratio = 32
-            elif gpu_memory >=20:
-                batch_ratio = 16
-            elif gpu_memory >= 16:
-                batch_ratio = 8
-            elif gpu_memory >= 10:
+            if gpu_memory >= 10:
                 batch_ratio = 4
             else:
                 batch_ratio = 2

+ 74 - 9
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -100,20 +100,61 @@ class UnimernetModel(object):
             res["latex"] = latex_rm_whitespace(latex)
         return formula_list
 
-    def batch_predict(
-        self, images_mfd_res: list, images: list, batch_size: int = 64
-    ) -> list:
+    # def batch_predict(
+    #     self, images_mfd_res: list, images: list, batch_size: int = 64
+    # ) -> list:
+    #     images_formula_list = []
+    #     mf_image_list = []
+    #     backfill_list = []
+    #     for image_index in range(len(images_mfd_res)):
+    #         mfd_res = images_mfd_res[image_index]
+    #         pil_img = Image.fromarray(images[image_index])
+    #         formula_list = []
+    #
+    #         for xyxy, conf, cla in zip(
+    #             mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+    #         ):
+    #             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+    #             new_item = {
+    #                 "category_id": 13 + int(cla.item()),
+    #                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+    #                 "score": round(float(conf.item()), 2),
+    #                 "latex": "",
+    #             }
+    #             formula_list.append(new_item)
+    #             bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+    #             mf_image_list.append(bbox_img)
+    #
+    #         images_formula_list.append(formula_list)
+    #         backfill_list += formula_list
+    #
+    #     dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+    #     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+    #     mfr_res = []
+    #     for mf_img in dataloader:
+    #         mf_img = mf_img.to(self.device)
+    #         with torch.no_grad():
+    #             output = self.model.generate({"image": mf_img})
+    #         mfr_res.extend(output["pred_str"])
+    #     for res, latex in zip(backfill_list, mfr_res):
+    #         res["latex"] = latex_rm_whitespace(latex)
+    #     return images_formula_list
+
+    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
         backfill_list = []
+        image_info = []  # Store (area, original_index, image) tuples
+
+        # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
             pil_img = Image.fromarray(images[image_index])
             formula_list = []
 
-            for xyxy, conf, cla in zip(
-                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
-            ):
+            for idx, (xyxy, conf, cla) in enumerate(zip(
+                    mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+            )):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                 new_item = {
                     "category_id": 13 + int(cla.item()),
@@ -123,19 +164,43 @@ class UnimernetModel(object):
                 }
                 formula_list.append(new_item)
                 bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                area = (xmax - xmin) * (ymax - ymin)
+
+                curr_idx = len(mf_image_list)
+                image_info.append((area, curr_idx, bbox_img))
                 mf_image_list.append(bbox_img)
 
             images_formula_list.append(formula_list)
             backfill_list += formula_list
 
-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        # Stable sort by area
+        image_info.sort(key=lambda x: x[0])  # sort by area
+        sorted_indices = [x[1] for x in image_info]
+        sorted_images = [x[2] for x in image_info]
+
+        # Create mapping for results
+        index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
+
+        # Create dataset with sorted images
+        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+
+        # Process batches and store results
         mfr_res = []
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
             mfr_res.extend(output["pred_str"])
-        for res, latex in zip(backfill_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+
+        # Restore original order
+        unsorted_results = [""] * len(mfr_res)
+        for new_idx, latex in enumerate(mfr_res):
+            original_idx = index_mapping[new_idx]
+            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+
+        # Fill results back
+        for res, latex in zip(backfill_list, unsorted_results):
+            res["latex"] = latex
+
         return images_formula_list

+ 7 - 3
magic_pdf/pdf_parse_union_core_v2.py

@@ -21,9 +21,12 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
+from magic_pdf.libs.performance_stats import measure_time, PerformanceStats
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
 
+from concurrent.futures import ThreadPoolExecutor
+
 try:
     import torchtext
 
@@ -215,7 +218,7 @@ def calculate_contrast(img, img_mode) -> float:
     # logger.info(f"contrast: {contrast}")
     return round(contrast, 2)
 
-
+# @measure_time
 def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
     # cid用0xfffd表示,连字符拆开
     # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
@@ -489,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     else:
         return [[x0, y0, x1, y1]]
 
-
+# @measure_time
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
 
@@ -923,7 +926,6 @@ def pdf_parse_union(
     magic_model = MagicModel(model_list, dataset)
 
     """根据输入的起始范围解析pdf"""
-    # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
@@ -960,6 +962,8 @@ def pdf_parse_union(
             )
         pdf_info_dict[f'page_{page_id}'] = page_info
 
+    # PerformanceStats.print_stats()
+
     """分段"""
     para_split(pdf_info_dict)