Ver código fonte

fix: adjust OCR box coordinates and confidence threshold for improved accuracy

myhloli 2 meses atrás
pai
commit
33f4a21ae8

+ 2 - 3
mineru/model/table/cls/paddle_table_cls.py

@@ -1,5 +1,4 @@
 import os
-from pathlib import Path
 
 from PIL import Image
 import cv2
@@ -146,8 +145,8 @@ class PaddleTableClsModel:
                     idx = np.argmax(img_res)
                     conf = float(np.max(img_res))
                     # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-                    if idx == 0 and conf < 0.8:
-                        idx = 1
+                    # if idx == 0 and conf < 0.8:
+                    #     idx = 1
                     label_res.append((self.labels[idx],conf))
                 pbar.update(len(img_batch))
             for img_info, (label, conf) in zip(img_info_list, label_res):

+ 12 - 4
mineru/model/table/rec/unet_table/main.py

@@ -181,7 +181,7 @@ class WiredTableRecognition:
                 logger.warning(f"No OCR engine provided for box {i}: {box}")
                 continue
             # 从img中截取对应的区域
-            x1, y1, x2, y2 = int(box[0][0]), int(box[0][1]), int(box[2][0]), int(box[2][1])
+            x1, y1, x2, y2 = int(box[0][0])+1, int(box[0][1])+1, int(box[2][0])-1, int(box[2][1])-1
             if x1 >= x2 or y1 >= y2:
                 logger.warning(f"Invalid box coordinates: {box}")
                 continue
@@ -196,6 +196,14 @@ class WiredTableRecognition:
         if len(img_crop_list) > 0:
             # 进行ocr识别
             ocr_result = self.ocr_engine.ocr(img_crop_list, det=False)
+            # ocr_result = [[]]
+            # for crop_img in img_crop_list:
+            #     tmp_ocr_result = self.ocr_engine.ocr(crop_img)
+            #     if tmp_ocr_result[0] and len(tmp_ocr_result[0]) > 0 and isinstance(tmp_ocr_result[0], list) and len(tmp_ocr_result[0][0]) == 2:
+            #         ocr_result[0].append(tmp_ocr_result[0][0][1])
+            #     else:
+            #         ocr_result[0].append(("", 0.0))
+
             if not ocr_result or not isinstance(ocr_result, list) or len(ocr_result) == 0:
                 logger.warning("OCR engine returned no results or invalid result for image crops.")
                 return cell_box_map
@@ -210,10 +218,10 @@ class WiredTableRecognition:
                 # 处理ocr结果
                 ocr_text, ocr_score = ocr_res
                 # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
-                if ocr_score < 0.9 or ocr_text in ['1']:
+                if ocr_score < 0.6 or ocr_text in ['1','口','■','(204号', '(20', '(2', '(2号', '(20号']:
                     # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
                     box = sorted_polygons[i]
-                    cell_box_map[i] = [[box, "", 0.5]]
+                    cell_box_map[i] = [[box, "", 0.1]]
                     continue
                 cell_box_map[i] = [[box, ocr_text, ocr_score]]
 
@@ -284,7 +292,7 @@ class UnetTableModel:
             gap_of_len = wireless_len - wired_len
             # 判断是否使用无线表格模型的结果
             if (
-                wired_len <= int(wireless_len * 0.55)+1  # 有线模型检测到的单元格数太少(低于无线模型的50%)
+                int(wireless_len * 0.1) <= wired_len <= int(wireless_len * 0.62)+1  # wired model found too few cells (between roughly 10% and 62% of the wireless model's cell count)
                 or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
                 or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
             ):