
fix: enhance table prediction logic by incorporating table classification score and refining model selection criteria

myhloli, 3 months ago
parent commit c702302684

+ 3 - 2
mineru/backend/pipeline/batch_analyze.py

@@ -269,8 +269,9 @@ class BatchAnalyze:
                 table_cls_model = atom_model_manager.get_atom_model(
                     atom_model_name=AtomicModel.TableCls,
                 )
+                table_cls_score = 0.5
                 try:
-                    table_label = table_cls_model.predict(table_img)
+                    table_label, table_cls_score = table_cls_model.predict(table_img)
                 except Exception as e:
                     table_label = AtomicModel.WirelessTable
                     logger.warning(
@@ -289,7 +290,7 @@ class BatchAnalyze:
                     lang=_lang,
                 )
 
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img)
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img, table_cls_score)
                 # check whether a valid result was returned
                 if html_code:
                     # check whether html_code contains '<table>' and '</table>'
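
The caller-side contract is now: the classifier returns a (label, score) pair, and a usable score must survive even when classification throws. A minimal sketch of that pattern, keeping the 0.5 default from the hunk above (the helper name and label string are illustrative, not the project's API):

def classify_with_fallback(classify, table_img, default_score=0.5):
    """Return (label, score); keep a wireless default if the classifier raises."""
    label, score = "wireless_table", default_score   # assumed fallback label
    try:
        label, score = classify(table_img)           # classifier now returns two values
    except Exception:
        pass                                         # keep the defaults so model selection still gets a score
    return label, score

# e.g. label, score = classify_with_fallback(table_cls_model.predict, table_img)
#      html_code, cells, logic_points, elapse = table_model.predict(table_img, score)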

+ 7 - 3
mineru/model/ori_cls/paddle_ori_cls.py

@@ -4,6 +4,8 @@ import os
 import cv2
 import numpy as np
 import onnxruntime
+from loguru import logger
+
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
@@ -90,7 +92,7 @@ class PaddleOrientationClsModel:
                     # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
                     #     horizontal_count += 1
 
-                if vertical_count >= len(det_res) * 0.3 and vertical_count >= 3:
+                if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
                     is_rotated = True
                 # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
 
@@ -100,11 +102,13 @@ class PaddleOrientationClsModel:
                     x = self.preprocess(img)
                     (result,) = self.sess.run(None, {"x": x})
                     label = self.labels[np.argmax(result)]
-
+                    # logger.debug(f"Orientation classification result: {label}")
                     if label == "270":
                         rotation = cv2.ROTATE_90_CLOCKWISE
                         img = cv2.rotate(np.asarray(img), rotation)
-                    else:  # anything other than 270 degrees is treated as 90 degrees
+                    elif label == "90":
                         rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
                         img = cv2.rotate(np.asarray(img), rotation)
+                    else:
+                        pass
         return img
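
With the explicit elif, only the "90" and "270" labels trigger a rotation; "0" and "180" now leave the image untouched instead of being treated as 90 degrees. The mapping can be read as a small lookup (a sketch under that assumption; the helper name is illustrative):

import cv2
import numpy as np

ROTATIONS = {
    "270": cv2.ROTATE_90_CLOCKWISE,         # rotate clockwise to undo a 270-degree layout
    "90": cv2.ROTATE_90_COUNTERCLOCKWISE,   # rotate counterclockwise to undo a 90-degree layout
}

def correct_orientation(img, label):
    rotation = ROTATIONS.get(label)
    return img if rotation is None else cv2.rotate(np.asarray(img), rotation)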

+ 2 - 2
mineru/model/table/cls/paddle_table_cls.py

@@ -67,6 +67,6 @@ class PaddleTableClsModel:
         idx = np.argmax(result)
         conf = float(np.max(result))
         # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-        if idx == 0 and conf < 0.9:
+        if idx == 0 and conf < 0.85:
             idx = 1
-        return self.labels[idx]
+        return self.labels[idx], conf
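
Since predict now returns the confidence alongside the label, the 0.85 demotion rule can be exercised on its own. A small worked example (the label list and the probability vectors are made up for illustration; only the threshold comes from the hunk above):

import numpy as np

labels = ["wired_table", "wireless_table"]          # assumed ordering: index 0 is the wired class

def classify(probs, threshold=0.85):
    idx = int(np.argmax(probs))
    conf = float(np.max(probs))
    if idx == 0 and conf < threshold:               # low-confidence "wired" is demoted to wireless
        idx = 1
    return labels[idx], conf

print(classify(np.array([0.80, 0.20])))             # ('wireless_table', 0.8) -> demoted
print(classify(np.array([0.95, 0.05])))             # ('wired_table', 0.95)   -> kept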

+ 1 - 1
mineru/model/table/rec/rapid_table.py

@@ -22,7 +22,7 @@ class RapidTableModel(object):
         self.ocr_engine = ocr_engine
 
 
-    def predict(self, image):
+    def predict(self, image, table_cls_score):
         bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
         # Continue with OCR on potentially rotated image
         ocr_result = self.ocr_engine.ocr(bgr_image)[0]

+ 75 - 10
mineru/model/table/rec/unet_table/unet_table.py

@@ -8,6 +8,7 @@ from typing import List, Optional, Union, Dict, Any
 import cv2
 import numpy as np
 from loguru import logger
+from rapid_table import RapidTableInput, RapidTable
 
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -47,6 +48,7 @@ class UnetTableRecognition:
         self,
         img: InputType,
         ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
+        ocr_engine = None,
         **kwargs,
     ) -> UnetTableOutput:
         s = time.perf_counter()
@@ -84,7 +86,7 @@ class UnetTableRecognition:
                 )
             cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
             # if a detected cell box has no OCR result, fill it in directly via recognition
-            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
+            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map, ocr_engine)
             # convert to an intermediate format: correct box coordinates and merge the physical boxes, logic boxes and OCR boxes into a dict for later processing
             t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
             # sort the OCR results inside each cell and merge same-row items, so the output HTML fully preserves the line breaks of the text
@@ -150,14 +152,45 @@ class UnetTableRecognition:
         img: np.ndarray,
         sorted_polygons: np.ndarray,
         cell_box_map: Dict[int, List[str]],
+        ocr_engine
     ) -> Dict[int, List[Any]]:
         """找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
+        img_crop_info_list = []
+        img_crop_list = []
         for i in range(sorted_polygons.shape[0]):
             if cell_box_map.get(i):
                 continue
             box = sorted_polygons[i]
-            cell_box_map[i] = [[box, "", 1]]
+            if ocr_engine is None:
+                logger.warning(f"No OCR engine provided for box {i}: {box}")
+                continue
+            # crop the corresponding region from img
+            x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
+            if x1 >= x2 or y1 >= y2:
+                logger.warning(f"Invalid box coordinates: {box}")
+                continue
+            img_crop = img[int(y1):int(y2), int(x1):int(x2)]
+            img_crop_list.append(img_crop)
+            img_crop_info_list.append([i, box])
             continue
+
+        if len(img_crop_list) > 0:
+            # run OCR recognition on the batch of crops
+            ocr_res_list = ocr_engine.ocr(img_crop_list, det=False)[0]
+            assert len(ocr_res_list) == len(img_crop_list)
+            for j, ocr_res in enumerate(ocr_res_list):
+                img_crop_info_list[j].append(ocr_res)
+
+
+            for i, box, ocr_res in img_crop_info_list:
+                # process the OCR result
+                ocr_text, ocr_score = ocr_res
+                # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
+                if ocr_score < 0.9:
+                    # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
+                    continue
+                cell_box_map[i] = [[box, ocr_text, ocr_score]]
+
         return cell_box_map
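
The new path collects every still-empty cell crop, runs one recognition-only OCR call over the whole batch, and keeps only confident results. A condensed sketch of that pattern, assuming a PaddleOCR-style engine whose ocr(imgs, det=False) call returns a list of (text, score) pairs:

def fill_empty_cells(img, polygons, cell_box_map, ocr_engine, min_score=0.9):
    crops, info = [], []
    for i, box in enumerate(polygons):
        if cell_box_map.get(i):
            continue                                  # cell already has OCR text
        x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
        if x1 >= x2 or y1 >= y2:
            continue                                  # degenerate box, skip it
        crops.append(img[int(y1):int(y2), int(x1):int(x2)])
        info.append((i, box))
    if crops:
        for (i, box), (text, score) in zip(info, ocr_engine.ocr(crops, det=False)[0]):
            if score >= min_score:                    # drop low-confidence recognition results
                cell_box_map[i] = [[box, text, score]]
    return cell_box_map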
 
 
@@ -169,11 +202,14 @@ def escape_html(input_string):
 class UnetTableModel:
     def __init__(self, ocr_engine):
         model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
-        input_args = UnetTableInput(model_path=model_path)
-        self.table_model = UnetTableRecognition(input_args)
+        wired_input_args = UnetTableInput(model_path=model_path)
+        self.wired_table_model = UnetTableRecognition(wired_input_args)
+        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
+        wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
+        self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, img):
+    def predict(self, img, table_cls_score):
         bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
         ocr_result = self.ocr_engine.ocr(bgr_img)[0]
 
@@ -187,11 +223,40 @@ class UnetTableModel:
             ocr_result = None
         if ocr_result:
             try:
-                table_results = self.table_model(np.asarray(img), ocr_result)
-                html_code = table_results.pred_html
-                table_cell_bboxes = table_results.cell_bboxes
-                logic_points = table_results.logic_points
-                elapse = table_results.elapse
+                wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
+                wired_html_code = wired_table_results.pred_html
+                wired_table_cell_bboxes = wired_table_results.cell_bboxes
+                wired_logic_points = wired_table_results.logic_points
+                wired_elapse = wired_table_results.elapse
+
+                wireless_table_results = self.wireless_table_model(np.asarray(img), ocr_result)
+                wireless_html_code = wireless_table_results.pred_html
+                wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
+                wireless_logic_points = wireless_table_results.logic_points
+                wireless_elapse = wireless_table_results.elapse
+
+                wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
+                wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
+                # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
+                # difference between the cell counts detected by the two models
+                gap_of_len = wireless_len - wired_len
+                # decide whether to use the wireless table model's result
+                if (
+                    wired_len <= round(wireless_len * 0.5)  # wired model detected far too few cells (below 50% of the wireless count)
+                    or (wireless_len < wired_len < (2 * wireless_len) and table_cls_score <= 0.949)  # wired model detected more cells, but the wired classification is not confident
+                    or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # counts are close, yet the wired result is clearly smaller
+                    or (gap_of_len == 0 and wired_len <= 4)  # cell counts are identical and the table has at most 4 cells
+                ):
+                    # logger.debug("fall back to wireless table model")
+                    html_code = wireless_html_code
+                    table_cell_bboxes = wireless_table_cell_bboxes
+                    logic_points = wireless_logic_points
+                else:
+                    html_code = wired_html_code
+                    table_cell_bboxes = wired_table_cell_bboxes
+                    logic_points = wired_logic_points
+
+                elapse = wired_elapse + wireless_elapse
                 return html_code, table_cell_bboxes, logic_points, elapse
             except Exception as e:
                 logger.exception(e)
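
Taken together, the fallback decision is a pure function of the two cell counts and the classification score; the conditions below copy the thresholds from the hunk above into a standalone sketch (the function name is illustrative):

def prefer_wireless(wired_len, wireless_len, table_cls_score):
    """True when the wireless (slanet_plus) result should be used instead of the wired one."""
    gap = wireless_len - wired_len
    return (
        wired_len <= round(wireless_len * 0.5)                                          # wired found far fewer cells
        or (wireless_len < wired_len < 2 * wireless_len and table_cls_score <= 0.949)   # wired found more, but the wired classification is shaky
        or (0 <= gap <= 5 and wired_len <= round(wireless_len * 0.75))                  # counts are close but wired is still clearly smaller
        or (gap == 0 and wired_len <= 4)                                                # identical counts on a tiny table
    )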