ソースを参照

refactor(pdf_extract_kit): optimize image processing and table recognition logic

Refactor the image processing logic for OCR and table recognition to ensure
consistency and improve performance. Remove redundant initialization of PIL images,
unify image cropping logic, and streamline the handling of formula detection results.
Also, adjust the table recognition process to improve integration with the updated image
processing logic and enhance overall efficiency.

myhloli 1 年前
コミット
29e590a701
1 ファイル変更、61 行追加、67 行削除
  1. 61 67
      magic_pdf/model/pdf_extract_kit.py

+ 61 - 67
magic_pdf/model/pdf_extract_kit.py

@@ -27,7 +27,7 @@ except ImportError as e:
     logger.exception(e)
     logger.error(
         'Required dependency not installed, please install by \n'
-        '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
+        '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
     exit(1)
 
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
@@ -188,50 +188,56 @@ class CustomPEKModel:
             mfr_cost = round(time.time() - mfr_start, 2)
             logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
 
+        # Select regions for OCR / formula regions / table regions
+        ocr_res_list = []
+        table_res_list = []
+        single_page_mfdetrec_res = []
+        for res in layout_res:
+            if int(res['category_id']) in [13, 14]:
+                single_page_mfdetrec_res.append({
+                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
+                             int(res['poly'][4]), int(res['poly'][5])],
+                })
+            elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
+                ocr_res_list.append(res)
+            elif int(res['category_id']) in [5]:
+                table_res_list.append(res)
+
+        #  Unified crop img logic
+        def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+            # Create a white background with an additional width and height of 50
+            crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+            crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+            # Crop image
+            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+            cropped_img = input_pil_img.crop(crop_box)
+            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+            return return_image, return_list
+
+        pil_img = Image.fromarray(image)
+
         # ocr识别
         if self.apply_ocr:
             ocr_start = time.time()
-            pil_img = Image.fromarray(image)
-
-            # 筛选出需要OCR的区域和公式区域
-            ocr_res_list = []
-            single_page_mfdetrec_res = []
-            for res in layout_res:
-                if int(res['category_id']) in [13, 14]:
-                    single_page_mfdetrec_res.append({
-                        "bbox": [int(res['poly'][0]), int(res['poly'][1]),
-                                 int(res['poly'][4]), int(res['poly'][5])],
-                    })
-                elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
-                    ocr_res_list.append(res)
-
-            # 对每一个需OCR处理的区域进行处理
+            # Process each area that requires OCR processing
             for res in ocr_res_list:
-                xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-
-                paste_x = 50
-                paste_y = 50
-                # 创建一个宽高各多50的白色背景
-                new_width = xmax - xmin + paste_x * 2
-                new_height = ymax - ymin + paste_y * 2
-                new_image = Image.new('RGB', (new_width, new_height), 'white')
-
-                # 裁剪图像
-                crop_box = (xmin, ymin, xmax, ymax)
-                cropped_img = pil_img.crop(crop_box)
-                new_image.paste(cropped_img, (paste_x, paste_y))
-
-                # 调整公式区域坐标
+                new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+                # Adjust the coordinates of the formula area
                 adjusted_mfdetrec_res = []
                 for mf_res in single_page_mfdetrec_res:
                     mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
-                    # 将公式区域坐标调整为相对于裁剪区域的坐标
+                    # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
                     x0 = mf_xmin - xmin + paste_x
                     y0 = mf_ymin - ymin + paste_y
                     x1 = mf_xmax - xmin + paste_x
                     y1 = mf_ymax - ymin + paste_y
-                    # 过滤在图外的公式块
+                    # Filter formula blocks outside the graph
                     if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
                         continue
                     else:
@@ -239,17 +245,17 @@ class CustomPEKModel:
                             "bbox": [x0, y0, x1, y1],
                         })
 
-                # OCR识别
+                # OCR recognition
                 new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
                 ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
 
-                # 整合结果
+                # Integration results
                 if ocr_res:
                     for box_ocr_res in ocr_res:
                         p1, p2, p3, p4 = box_ocr_res[0]
                         text, score = box_ocr_res[1]
 
-                        # 将坐标转换回原图坐标系
+                        # Convert the coordinates back to the original coordinate system
                         p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
                         p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
                         p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
@@ -267,35 +273,23 @@ class CustomPEKModel:
 
         # 表格识别 table recognition
         if self.apply_table:
-            pil_img = Image.fromarray(image)
-            for layout in layout_res:
-                if layout.get("category_id", -1) == 5:
-                    poly = layout["poly"]
-                    xmin, ymin = int(poly[0]), int(poly[1])
-                    xmax, ymax = int(poly[4]), int(poly[5])
-
-                    paste_x = 50
-                    paste_y = 50
-                    # 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
-                    new_width = xmax - xmin + paste_x * 2
-                    new_height = ymax - ymin + paste_y * 2
-                    new_image = Image.new('RGB', (new_width, new_height), 'white')
-
-                    # 裁剪图像 crop image
-                    crop_box = (xmin, ymin, xmax, ymax)
-                    cropped_img = pil_img.crop(crop_box)
-                    new_image.paste(cropped_img, (paste_x, paste_y))
-                    start_time = time.time()
-                    logger.info("------------------table recognition processing begins-----------------")
+            table_start = time.time()
+            for res in table_res_list:
+                new_image, _ = crop_img(res, pil_img)
+                single_table_start_time = time.time()
+                logger.info("------------------table recognition processing begins-----------------")
+                with torch.no_grad():
                     latex_code = self.table_model.image2latex(new_image)[0]
-                    end_time = time.time()
-                    run_time = end_time - start_time
-                    logger.info(f"------------table recognition processing ends within {run_time}s-----")
-                    if run_time > self.table_max_time:
-                        logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
-                    # 判断是否返回正常
-                    if latex_code and latex_code.strip().endswith('end{tabular}'):
-                        layout["latex"] = latex_code
-                    else:
-                        logger.warning(f"------------table recognition processing fails----------")
+                run_time = time.time() - single_table_start_time
+                logger.info(f"------------table recognition processing ends within {run_time}s-----")
+                if run_time > self.table_max_time:
+                    logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
+                # 判断是否返回正常
+                if latex_code and latex_code.strip().endswith('end{tabular}'):
+                    res["latex"] = latex_code
+                else:
+                    logger.warning(f"------------table recognition processing fails----------")
+            table_cost = round(time.time() - table_start, 2)
+            logger.info(f"table cost: {table_cost}")
+
         return layout_res