浏览代码

feat(table): add orientation detection and rotation for portrait tables

- Implement table orientation detection to identify if a table is in portrait mode
- Add rotation logic to turn portrait tables 90 degrees clockwise before OCR
- Update OCR processing to work with potentially rotated images
- Improve text box analysis to determine if a table is rotated
myhloli 7 月之前
父节点
当前提交
ac893f325a
共有 1 个文件被更改,包括 53 次插入12 次删除
  1. 53 12
      magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

+ 53 - 12
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -35,26 +35,67 @@ class RapidTableModel(object):
         #     from rapidocr_onnxruntime import RapidOCR
         #     self.ocr_engine = RapidOCR()
 
-        self.ocr_model_name = "PaddleOCR"
+        # self.ocr_model_name = "PaddleOCR"
         self.ocr_engine = ocr_engine
 
 
     def predict(self, image):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
 
-        if self.ocr_model_name == "RapidOCR":
-            ocr_result, _ = self.ocr_engine(np.asarray(image))
-        elif self.ocr_model_name == "PaddleOCR":
-            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-            if ocr_result:
-                ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
-                          len(item) == 2 and isinstance(item[1], tuple)]
-            else:
-                ocr_result = None
+        # First check the overall image aspect ratio (height/width)
+        img_height, img_width = bgr_image.shape[:2]
+        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+        img_is_portrait = img_aspect_ratio > 1.2
+
+        if img_is_portrait:
+
+            det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
+            # Check if table is rotated by analyzing text box aspect ratios
+            is_rotated = False
+            if det_res:
+                aspect_ratios = []
+                vertical_count = 0
+
+                for box_ocr_res in det_res:
+                    p1, p2, p3, p4 = box_ocr_res
+
+                    # Calculate width and height
+                    width = max(np.linalg.norm(np.array(p1) - np.array(p2)),
+                                np.linalg.norm(np.array(p3) - np.array(p4)))
+                    height = max(np.linalg.norm(np.array(p1) - np.array(p4)),
+                                 np.linalg.norm(np.array(p2) - np.array(p3)))
+
+                    aspect_ratio = width / height if height > 0 else 1.0
+                    aspect_ratios.append(aspect_ratio)
+
+                    # Count vertical vs horizontal text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1
+                    # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
+                    #     horizontal_count += 1
+
+                # If we have more vertical text boxes than horizontal ones,
+                # and vertical ones are significant, table might be rotated
+                if vertical_count >= len(det_res) * 0.3:
+                    is_rotated = True
+
+                # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
+
+            # Rotate image if necessary
+            if is_rotated:
+                # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
+                image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
+                bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+        # Continue with OCR on potentially rotated image
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        if ocr_result:
+            ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                      len(item) == 2 and isinstance(item[1], tuple)]
         else:
-            logger.error("OCR model not supported")
             ocr_result = None
 
+
         if ocr_result:
             table_results = self.table_model(np.asarray(image), ocr_result)
             html_code = table_results.pred_html