瀏覽代碼

Merge pull request #2172 from myhloli/dev

fix(ocr): handle NaN values in recognition scores, feat(table): add orientation detection and rotation for portrait tables
Xiaomeng Zhao 7 月之前
父節點
當前提交
1db6f89dcd

+ 6 - 0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py

@@ -437,4 +437,10 @@ class TextRecognizer(BaseOCRV20):
                 index += 1
                 pbar.update(current_batch_size)
 
+        # Fix NaN values in recognition results
+        for i in range(len(rec_res)):
+            text, score = rec_res[i]
+            if isinstance(score, float) and math.isnan(score):
+                rec_res[i] = (text, 0.0)
+
         return rec_res, elapse

+ 49 - 12
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -35,26 +35,63 @@ class RapidTableModel(object):
         #     from rapidocr_onnxruntime import RapidOCR
         #     self.ocr_engine = RapidOCR()
 
-        self.ocr_model_name = "PaddleOCR"
+        # self.ocr_model_name = "PaddleOCR"
         self.ocr_engine = ocr_engine
 
 
     def predict(self, image):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
 
-        if self.ocr_model_name == "RapidOCR":
-            ocr_result, _ = self.ocr_engine(np.asarray(image))
-        elif self.ocr_model_name == "PaddleOCR":
-            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-            if ocr_result:
-                ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
-                          len(item) == 2 and isinstance(item[1], tuple)]
-            else:
-                ocr_result = None
+        # First check the overall image aspect ratio (height/width)
+        img_height, img_width = bgr_image.shape[:2]
+        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+        img_is_portrait = img_aspect_ratio > 1.2
+
+        if img_is_portrait:
+
+            det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
+            # Check if table is rotated by analyzing text box aspect ratios
+            is_rotated = False
+            if det_res:
+                vertical_count = 0
+
+                for box_ocr_res in det_res:
+                    p1, p2, p3, p4 = box_ocr_res
+
+                    # Calculate width and height
+                    width = p3[0] - p1[0]
+                    height = p3[1] - p1[1]
+
+                    aspect_ratio = width / height if height > 0 else 1.0
+
+                    # Count vertical vs horizontal text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1
+                    # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
+                    #     horizontal_count += 1
+
+                # If we have more vertical text boxes than horizontal ones,
+                # and vertical ones are significant, table might be rotated
+                if vertical_count >= len(det_res) * 0.3:
+                    is_rotated = True
+
+                # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
+
+            # Rotate image if necessary
+            if is_rotated:
+                # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
+                image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
+                bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+        # Continue with OCR on potentially rotated image
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        if ocr_result:
+            ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                      len(item) == 2 and isinstance(item[1], tuple)]
         else:
-            logger.error("OCR model not supported")
             ocr_result = None
 
+
         if ocr_result:
             table_results = self.table_model(np.asarray(image), ocr_result)
             html_code = table_results.pred_html