Эх сурвалжийг харах

Merge pull request #374 from myhloli/master

fix&refactor(pdf-extract-kit):  table recognition and ocr
Xiaomeng Zhao 1 жил өмнө
parent
commit
2502db13af

+ 62 - 67
magic_pdf/model/pdf_extract_kit.py

@@ -27,7 +27,7 @@ except ImportError as e:
     logger.exception(e)
     logger.error(
         'Required dependency not installed, please install by \n'
-        '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
+        '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
     exit(1)
 
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
@@ -188,50 +188,56 @@ class CustomPEKModel:
             mfr_cost = round(time.time() - mfr_start, 2)
             logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
 
+        # Select regions for OCR / formula regions / table regions
+        ocr_res_list = []
+        table_res_list = []
+        single_page_mfdetrec_res = []
+        for res in layout_res:
+            if int(res['category_id']) in [13, 14]:
+                single_page_mfdetrec_res.append({
+                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
+                             int(res['poly'][4]), int(res['poly'][5])],
+                })
+            elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
+                ocr_res_list.append(res)
+            elif int(res['category_id']) in [5]:
+                table_res_list.append(res)
+
+        #  Unified crop img logic
+        def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+            # Create a white background with an additional width and height of 50
+            crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+            crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+            # Crop image
+            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+            cropped_img = input_pil_img.crop(crop_box)
+            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+            return return_image, return_list
+
+        pil_img = Image.fromarray(image)
+
         # ocr识别
         if self.apply_ocr:
             ocr_start = time.time()
-            pil_img = Image.fromarray(image)
-
-            # 筛选出需要OCR的区域和公式区域
-            ocr_res_list = []
-            single_page_mfdetrec_res = []
-            for res in layout_res:
-                if int(res['category_id']) in [13, 14]:
-                    single_page_mfdetrec_res.append({
-                        "bbox": [int(res['poly'][0]), int(res['poly'][1]),
-                                 int(res['poly'][4]), int(res['poly'][5])],
-                    })
-                elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
-                    ocr_res_list.append(res)
-
-            # 对每一个需OCR处理的区域进行处理
+            # Process each area that requires OCR processing
             for res in ocr_res_list:
-                xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-
-                paste_x = 50
-                paste_y = 50
-                # 创建一个宽高各多50的白色背景
-                new_width = xmax - xmin + paste_x * 2
-                new_height = ymax - ymin + paste_y * 2
-                new_image = Image.new('RGB', (new_width, new_height), 'white')
-
-                # 裁剪图像
-                crop_box = (xmin, ymin, xmax, ymax)
-                cropped_img = pil_img.crop(crop_box)
-                new_image.paste(cropped_img, (paste_x, paste_y))
-
-                # 调整公式区域坐标
+                new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+                # Adjust the coordinates of the formula area
                 adjusted_mfdetrec_res = []
                 for mf_res in single_page_mfdetrec_res:
                     mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
-                    # 将公式区域坐标调整为相对于裁剪区域的坐标
+                    # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
                     x0 = mf_xmin - xmin + paste_x
                     y0 = mf_ymin - ymin + paste_y
                     x1 = mf_xmax - xmin + paste_x
                     y1 = mf_ymax - ymin + paste_y
-                    # 过滤在图外的公式块
+                    # Filter formula blocks outside the graph
                     if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
                         continue
                     else:
@@ -239,17 +245,17 @@ class CustomPEKModel:
                             "bbox": [x0, y0, x1, y1],
                         })
 
-                # OCR识别
+                # OCR recognition
                 new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
                 ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
 
-                # 整合结果
+                # Integration results
                 if ocr_res:
                     for box_ocr_res in ocr_res:
                         p1, p2, p3, p4 = box_ocr_res[0]
                         text, score = box_ocr_res[1]
 
-                        # 将坐标转换回原图坐标系
+                        # Convert the coordinates back to the original coordinate system
                         p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
                         p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
                         p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
@@ -267,35 +273,24 @@ class CustomPEKModel:
 
         # 表格识别 table recognition
         if self.apply_table:
-            pil_img = Image.fromarray(image)
-            for layout in layout_res:
-                if layout.get("category_id", -1) == 5:
-                    poly = layout["poly"]
-                    xmin, ymin = int(poly[0]), int(poly[1])
-                    xmax, ymax = int(poly[4]), int(poly[5])
-
-                    paste_x = 50
-                    paste_y = 50
-                    # 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
-                    new_width = xmax - xmin + paste_x * 2
-                    new_height = ymax - ymin + paste_y * 2
-                    new_image = Image.new('RGB', (new_width, new_height), 'white')
-
-                    # 裁剪图像 crop image
-                    crop_box = (xmin, ymin, xmax, ymax)
-                    cropped_img = pil_img.crop(crop_box)
-                    new_image.paste(cropped_img, (paste_x, paste_y))
-                    start_time = time.time()
-                    logger.info("------------------table recognition processing begins-----------------")
+            table_start = time.time()
+            for res in table_res_list:
+                new_image, _ = crop_img(res, pil_img)
+                single_table_start_time = time.time()
+                logger.info("------------------table recognition processing begins-----------------")
+                with torch.no_grad():
                     latex_code = self.table_model.image2latex(new_image)[0]
-                    end_time = time.time()
-                    run_time = end_time - start_time
-                    logger.info(f"------------table recognition processing ends within {run_time}s-----")
-                    if run_time > self.table_max_time:
-                        logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
-                    # 判断是否返回正常
-                    if latex_code and latex_code.strip().endswith('end{tabular}'):
-                        layout["latex"] = latex_code
-                    else:
-                        logger.warning(f"------------table recognition processing fails----------")
+                run_time = time.time() - single_table_start_time
+                logger.info(f"------------table recognition processing ends within {run_time}s-----")
+                if run_time > self.table_max_time:
+                    logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
+                # 判断是否返回正常
+                expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
+                if latex_code and expected_ending:
+                    res["latex"] = latex_code
+                else:
+                    logger.warning(f"------------table recognition processing fails----------")
+            table_cost = round(time.time() - table_start, 2)
+            logger.info(f"table cost: {table_cost}")
+
         return layout_res