Browse source

Merge pull request #3321 from myhloli/dev

Dev
Xiaomeng Zhao 3 months ago
Parent commit 0137913fd2

+ 29 - 14
mineru/backend/pipeline/batch_analyze.py

@@ -9,6 +9,7 @@ from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.pdf_image_tools import get_crop_img
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -40,10 +41,7 @@ class BatchAnalyze:
         images = [image for image, _, _ in images_with_extra_info]
 
         # doclayout_yolo
-        layout_images = []
-        for image_index, image in enumerate(images):
-            layout_images.append(image)
-
+        layout_images = images.copy()
 
         images_layout_res += self.model.layout_model.batch_predict(
             layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
@@ -89,7 +87,14 @@ class BatchAnalyze:
                                           })
 
             for table_res in table_res_list:
-                table_img, _ = crop_img(table_res, pil_img)
+                # table_img, _ = crop_img(table_res, pil_img)
+                # bbox = (241, 208, 1475, 2019)
+                scale = 10/3
+                crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
+                crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
+                bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
+                table_img = get_crop_img(bbox, pil_img, scale=scale)
+
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
                                                 'table_img':table_img,
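
A note on this hunk: the detection `poly` is expressed at the layout pipeline's 10/3 render scale, so the code divides it back down to base-page coordinates before delegating the crop to `get_crop_img` from `mineru/utils/pdf_image_tools.py`. Below is a minimal sketch of what a `get_crop_img`-style helper plausibly does, assuming it scales the bbox back up before cropping; the body is a guess for illustration, not the actual source:

```python
from PIL import Image

def get_crop_img_sketch(bbox, pil_img: Image.Image, scale: float = 1.0) -> Image.Image:
    """Hypothetical stand-in: crop `bbox` (given in 1x page coordinates)
    out of a page image rendered at `scale`x resolution."""
    x0, y0, x1, y1 = (int(v * scale) for v in bbox)
    return pil_img.crop((x0, y0, x1, y1))
```
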
@@ -140,14 +145,17 @@ class BatchAnalyze:
                 )
 
                # Group crops by resolution and pad them in the same pass
+                # RESOLUTION_GROUP_STRIDE = 32
+                RESOLUTION_GROUP_STRIDE = 64  # stride used for resolution grouping
+
                 resolution_groups = defaultdict(list)
                 for crop_info in lang_crop_list:
                     cropped_img = crop_info[0]
                     h, w = cropped_img.shape[:2]
                    # Use a larger grouping tolerance to reduce the number of groups
                    # Normalize sizes to a multiple of the stride
-                    normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
-                    normalized_w = ((w + 32) // 32) * 32
+                    normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # bump to the next stride multiple
+                    normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
                     group_key = (normalized_h, normalized_w)
                     resolution_groups[group_key].append(crop_info)
 
@@ -157,8 +165,8 @@ class BatchAnalyze:
                    # Compute the target size (the group's max size, rounded up to a stride multiple)
                     max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
                     max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
-                    target_h = ((max_h + 32 - 1) // 32) * 32
-                    target_w = ((max_w + 32 - 1) // 32) * 32
+                    target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+                    target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
 
                    # Pad all images in the group to the uniform size
                     batch_images = []
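
The two rounding formulas in this hunk are deliberately different, which is easy to miss: the group key uses `((v + S) // S) * S`, which always moves to the *next* stride multiple (64 maps to 128), while the padding target uses the true ceiling `((v + S - 1) // S) * S` (64 stays 64). As a result the padding target never exceeds the group key. A small self-contained check:

```python
S = 64  # RESOLUTION_GROUP_STRIDE

def next_multiple(v: int, s: int = S) -> int:
    # group key: strictly greater multiple, even for exact multiples
    return ((v + s) // s) * s

def ceil_multiple(v: int, s: int = S) -> int:
    # padding target: true ceiling
    return ((v + s - 1) // s) * s

assert next_multiple(64) == 128 and ceil_multiple(64) == 64
assert next_multiple(65) == 128 and ceil_multiple(65) == 128
```
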
@@ -256,14 +264,22 @@ class BatchAnalyze:
                     atom_model_name=AtomicModel.ImgOrientationCls,
                 )
                 try:
-                    table_img = img_orientation_cls_model.predict(
+                    rotate_label = img_orientation_cls_model.predict(
                         table_res_dict["table_img"]
                     )
                 except Exception as e:
                     logger.warning(
                         f"Image orientation classification failed: {e}, using original image"
                     )
-                    table_img = table_res_dict["table_img"]
+                    rotate_label = "0"
+
+                np_table_img = np.asarray(table_res_dict["table_img"])
+                if rotate_label == "270":
+                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
+                elif rotate_label == "90":
+                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                else:
+                    pass
 
                # Wired/wireless table classification
                 table_cls_model = atom_model_manager.get_atom_model(
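
With this change the orientation classifier only returns a label and the caller applies the rotation. A sketch of the label-to-rotation mapping implied by the hunk above (note the diff itself only undoes "90" and "270"; handling "180" via `cv2.ROTATE_180` would be an extension, not what the commit does):

```python
import cv2
import numpy as np

ROTATIONS = {
    "90": cv2.ROTATE_90_COUNTERCLOCKWISE,
    "270": cv2.ROTATE_90_CLOCKWISE,
    # "180": cv2.ROTATE_180,  # not applied in this commit
}

def apply_orientation(img: np.ndarray, rotate_label: str) -> np.ndarray:
    rot = ROTATIONS.get(rotate_label)
    return cv2.rotate(img, rot) if rot is not None else img
```
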
@@ -271,7 +287,7 @@ class BatchAnalyze:
                 )
                 table_cls_score = 0.5
                 try:
-                    table_label, table_cls_score = table_cls_model.predict(table_img)
+                    table_label, table_cls_score = table_cls_model.predict(np_table_img)
                 except Exception as e:
                     table_label = AtomicModel.WirelessTable
                     logger.warning(
@@ -289,8 +305,7 @@ class BatchAnalyze:
                     atom_model_name=table_label,
                     lang=_lang,
                 )
-
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img, table_cls_score)
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(np_table_img, table_cls_score)
                # Check whether the model returned a usable result
                 if html_code:
                    # Check whether html_code contains '<table>' and '</table>'

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -11,7 +11,7 @@ from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
 from ...model.table.rec.rapid_table import RapidTableModel
-from ...model.table.rec.unet_table.unet_table import UnetTableModel
+from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 

+ 7 - 4
mineru/backend/pipeline/pipeline_analyze.py

@@ -1,7 +1,7 @@
 import os
 import time
 from typing import List, Tuple
-import PIL.Image
+from PIL import Image
 from loguru import logger
 
 from .model_init import MineruPipelineModel
@@ -148,10 +148,9 @@ def doc_analyze(
 
 
 def batch_image_analyze(
-        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
+        images_with_extra_info: List[Tuple[Image.Image, bool, str]],
         formula_enable=True,
         table_enable=True):
-    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
 
     from .batch_analyze import BatchAnalyze
 
@@ -191,10 +190,14 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
-    if str(device).startswith('mps'):
+    # Check the installed torch version
+    import torch
+    from packaging import version
+    if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
         enable_ocr_det_batch = False
     else:
         enable_ocr_det_batch = True
+
     batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
     results = batch_model(images_with_extra_info)
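
The new gate disables batched OCR detection on torch 2.8.0 and newer, in addition to the existing MPS check. `packaging.version` compares PEP 440 versions properly, including torch builds that carry a local tag such as `2.8.0+cu128`, which a naive string comparison would misorder:

```python
from packaging import version

assert version.parse("2.8.0+cu128") >= version.parse("2.8.0")
assert version.parse("2.7.1") < version.parse("2.8.0")
assert not ("2.10.0" >= "2.8.0")  # plain string comparison gets this wrong
```
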
 

+ 1 - 1
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -89,7 +89,7 @@ class PytorchPaddleOCR(TextSystem):
         kwargs['det_model_path'] = det_model_path
         kwargs['rec_model_path'] = rec_model_path
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
-        kwargs['rec_batch_num'] = 16
+        kwargs['rec_batch_num'] = 8
 
         kwargs['device'] = device
 

+ 17 - 18
mineru/model/ori_cls/paddle_ori_cls.py

@@ -1,6 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import os
 
+from PIL import Image
 import cv2
 import numpy as np
 import onnxruntime
@@ -23,15 +24,13 @@ class PaddleOrientationClsModel:
         self.mean = [0.485, 0.456, 0.406]
         self.labels = ["0", "90", "180", "270"]
 
-    def preprocess(self, img):
-        # PIL图像转cv2
-        img = np.array(img)
+    def preprocess(self, input_img):
         # Upscale the image so that its shortest side is 256
-        h, w = img.shape[:2]
+        h, w = input_img.shape[:2]
         scale = 256 / min(h, w)
         h_resize = round(h * scale)
         w_resize = round(w * scale)
-        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
         # Fit the image into a 224x224 square
         h, w = img.shape[:2]
         cw, ch = 224, 224
@@ -62,8 +61,15 @@ class PaddleOrientationClsModel:
         x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
 
-    def predict(self, img):
-        bgr_image = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+    def predict(self, input_img):
+        rotate_label = "0"  # Default to 0 if no rotation detected or not portrait
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
         # First check the overall image aspect ratio (height/width)
         img_height, img_width = bgr_image.shape[:2]
         img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
@@ -99,16 +105,9 @@ class PaddleOrientationClsModel:
                 # If we have more vertical text boxes than horizontal ones,
                 # and vertical ones are significant, table might be rotated
                 if is_rotated:
-                    x = self.preprocess(img)
+                    x = self.preprocess(np_img)
                     (result,) = self.sess.run(None, {"x": x})
-                    label = self.labels[np.argmax(result)]
+                    rotate_label = self.labels[np.argmax(result)]
                     # logger.debug(f"Orientation classification result: {rotate_label}")
-                    if label == "270":
-                        rotation = cv2.ROTATE_90_CLOCKWISE
-                        img = cv2.rotate(np.asarray(img), rotation)
-                    elif label == "90":
-                        rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
-                        img = cv2.rotate(np.asarray(img), rotation)
-                    else:
-                        pass
-        return img
+
+        return rotate_label

+ 12 - 7
mineru/model/table/cls/paddle_table_cls.py

@@ -1,5 +1,6 @@
 import os
 
+from PIL import Image
 import cv2
 import numpy as np
 import onnxruntime
@@ -22,15 +23,13 @@ class PaddleTableClsModel:
         self.mean = [0.485, 0.456, 0.406]
         self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]
 
-    def preprocess(self, img):
-        # PIL图像转cv2
-        img = np.array(img)
+    def preprocess(self, input_img):
         # Upscale the image so that its shortest side is 256
-        h, w = img.shape[:2]
+        h, w = input_img.shape[:2]
         scale = 256 / min(h, w)
         h_resize = round(h * scale)
         w_resize = round(w * scale)
-        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
         # Fit the image into a 224x224 square
         h, w = img.shape[:2]
         cw, ch = 224, 224
@@ -61,8 +60,14 @@ class PaddleTableClsModel:
         x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
 
-    def predict(self, img):
-        x = self.preprocess(img)
+    def predict(self, input_img):
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        x = self.preprocess(np_img)
         result = self.sess.run(None, {"x": x})
         idx = np.argmax(result)
         conf = float(np.max(result))
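
The same PIL-or-ndarray normalization now appears in `PaddleOrientationClsModel.predict`, `PaddleTableClsModel.predict` and `UnetTableModel.predict`. It could be factored into one helper; a sketch (hypothetical, not part of this commit):

```python
from typing import Union
import numpy as np
from PIL import Image

def to_numpy_img(input_img: Union[Image.Image, np.ndarray]) -> np.ndarray:
    """Normalize a PIL Image or an ndarray input to an ndarray."""
    if isinstance(input_img, Image.Image):
        return np.asarray(input_img)
    if isinstance(input_img, np.ndarray):
        return input_img
    raise ValueError("Input must be a pillow object or a numpy array.")
```
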

+ 0 - 1
mineru/model/table/rec/unet_table/__init__.py

@@ -1 +0,0 @@
-# Copyright (c) Opendatalab. All rights reserved.

+ 88 - 39
mineru/model/table/rec/unet_table/unet_table.py → mineru/model/table/rec/unet_table/main.py

@@ -1,21 +1,24 @@
 import html
+import logging
 import os
 import time
 import traceback
 from dataclasses import dataclass, asdict
-from typing import List, Optional, Union, Dict, Any
 
-import cv2
+from typing import List, Optional, Union, Dict, Any
 import numpy as np
+import cv2
+from PIL import Image
 from loguru import logger
 from rapid_table import RapidTableInput, RapidTable
 
+from .table_structure_unet import TSRUnet
+
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-from .table_structure_unet import TSRUnet
 from .table_recover import TableRecover
-from .wired_table_rec_utils import InputType, LoadImage
-from .table_recover_utils import (
+from .utils import InputType, LoadImage, VisTable
+from .utils_table_recover import (
     match_ocr_cell,
     plot_html_table,
     box_4_2_poly_to_box_4_1,
@@ -25,32 +28,32 @@ from .table_recover_utils import (
 
 
 @dataclass
-class UnetTableInput:
+class WiredTableInput:
     model_path: str
     device: str = "cpu"
 
 
 @dataclass
-class UnetTableOutput:
+class WiredTableOutput:
     pred_html: Optional[str] = None
     cell_bboxes: Optional[np.ndarray] = None
     logic_points: Optional[np.ndarray] = None
     elapse: Optional[float] = None
 
 
-class UnetTableRecognition:
-    def __init__(self, config: UnetTableInput):
+class WiredTableRecognition:
+    def __init__(self, config: WiredTableInput, ocr_engine=None):
         self.table_structure = TSRUnet(asdict(config))
         self.load_img = LoadImage()
         self.table_recover = TableRecover()
+        self.ocr_engine = ocr_engine
 
     def __call__(
         self,
         img: InputType,
         ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
-        ocr_engine = None,
         **kwargs,
-    ) -> UnetTableOutput:
+    ) -> WiredTableOutput:
         s = time.perf_counter()
         need_ocr = True
         col_threshold = 15
@@ -62,8 +65,8 @@ class UnetTableRecognition:
         img = self.load_img(img)
         polygons, rotated_polygons = self.table_structure(img, **kwargs)
         if polygons is None:
-            # logger.warning("polygons is None.")
-            return UnetTableOutput("", None, None, 0.0)
+            logging.warning("polygons is None.")
+            return WiredTableOutput("", None, None, 0.0)
 
         try:
             table_res, logi_points = self.table_recover(
@@ -78,7 +81,7 @@ class UnetTableRecognition:
                 sorted_polygons, idx_list = sorted_ocr_boxes(
                     [box_4_2_poly_to_box_4_1(box) for box in polygons]
                 )
-                return UnetTableOutput(
+                return WiredTableOutput(
                     "",
                     sorted_polygons,
                     logi_points[idx_list],
@@ -86,12 +89,12 @@ class UnetTableRecognition:
                 )
             cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
             # If a detected cell box has no OCR result, run recognition directly to fill it in
-            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map, ocr_engine)
+            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
             # Convert to an intermediate format: fix recognition-box coordinates and merge the physical boxes, logical boxes and OCR boxes into one dict for later processing
             t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
             # Sort the OCR results within each cell and merge same-row text, so the output HTML keeps the original line breaks
             t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
-            # cell_box_map =
+
             logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
             cell_box_det_map = {
                 i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
@@ -103,9 +106,9 @@ class UnetTableRecognition:
             elapse = time.perf_counter() - s
 
         except Exception:
-            logger.warning(traceback.format_exc())
-            return UnetTableOutput("", None, None, 0.0)
-        return UnetTableOutput(pred_html, polygons, logi_points, elapse)
+            logging.warning(traceback.format_exc())
+            return WiredTableOutput("", None, None, 0.0)
+        return WiredTableOutput(pred_html, polygons, logi_points, elapse)
 
     def transform_res(
         self,
@@ -139,29 +142,43 @@ class UnetTableRecognition:
     def sort_and_gather_ocr_res(self, res):
         for i, dict_res in enumerate(res):
             _, sorted_idx = sorted_ocr_boxes(
-                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threshold=0.3
+                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
             )
             dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
             dict_res["t_ocr_res"] = gather_ocr_list_by_row(
-                dict_res["t_ocr_res"], threshold=0.3
+                dict_res["t_ocr_res"], threhold=0.3
             )
         return res
 
+    # def fill_blank_rec(
+    #     self,
+    #     img: np.ndarray,
+    #     sorted_polygons: np.ndarray,
+    #     cell_box_map: Dict[int, List[str]],
+    # ) -> Dict[int, List[Any]]:
+    #     """Find cells whose poly has no OCR text and send the poly crop directly to recognition"""
+    #     for i in range(sorted_polygons.shape[0]):
+    #         if cell_box_map.get(i):
+    #             continue
+    #         box = sorted_polygons[i]
+    #         cell_box_map[i] = [[box, "", 1]]
+    #         continue
+    #     return cell_box_map
     def fill_blank_rec(
         self,
         img: np.ndarray,
         sorted_polygons: np.ndarray,
         cell_box_map: Dict[int, List[str]],
-        ocr_engine
     ) -> Dict[int, List[Any]]:
         """找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
+        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
         img_crop_info_list = []
         img_crop_list = []
         for i in range(sorted_polygons.shape[0]):
             if cell_box_map.get(i):
                 continue
             box = sorted_polygons[i]
-            if ocr_engine is None:
+            if self.ocr_engine is None:
                 logger.warning(f"No OCR engine provided for box {i}: {box}")
                 continue
             # Crop the corresponding region from the image
@@ -169,14 +186,13 @@ class UnetTableRecognition:
             if x1 >= x2 or y1 >= y2:
                 logger.warning(f"Invalid box coordinates: {box}")
                 continue
-            img_crop = img[int(y1):int(y2), int(x1):int(x2)]
+            img_crop = bgr_img[int(y1):int(y2), int(x1):int(x2)]
             img_crop_list.append(img_crop)
             img_crop_info_list.append([i, box])
-            continue
 
         if len(img_crop_list) > 0:
             # Run OCR recognition on the crops
-            ocr_result = ocr_engine.ocr(img_crop_list, det=False)
+            ocr_result = self.ocr_engine.ocr(img_crop_list, det=False)
             if not ocr_result or not isinstance(ocr_result, list) or len(ocr_result) == 0:
                 logger.warning("OCR engine returned no results or invalid result for image crops.")
                 return cell_box_map
@@ -187,13 +203,14 @@ class UnetTableRecognition:
             for j, ocr_res in enumerate(ocr_res_list):
                 img_crop_info_list[j].append(ocr_res)
 
-
             for i, box, ocr_res in img_crop_info_list:
                 # Process the OCR result
                 ocr_text, ocr_score = ocr_res
                 # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
-                if ocr_score < 0.9:
+                if ocr_score < 0.9 or ocr_text in ['1']:
                     # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
+                    box = sorted_polygons[i]
+                    cell_box_map[i] = [[box, "", 0.5]]
                     continue
                 cell_box_map[i] = [[box, ocr_text, ocr_score]]
 
@@ -205,20 +222,37 @@ def escape_html(input_string):
     return html.escape(input_string)
 
 
+def count_table_cells_physical(html_code):
+    """Count the table's physical cells (a merged cell counts as one)"""
+    if not html_code:
+        return 0
+
+    # Simply count the number of td and th tags
+    html_lower = html_code.lower()
+    td_count = html_lower.count('<td')
+    th_count = html_lower.count('<th')
+    return td_count + th_count
+
+
 class UnetTableModel:
     def __init__(self, ocr_engine):
         model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
-        wired_input_args = UnetTableInput(model_path=model_path)
-        self.wired_table_model = UnetTableRecognition(wired_input_args)
+        wired_input_args = WiredTableInput(model_path=model_path)
+        self.wired_table_model = WiredTableRecognition(wired_input_args, ocr_engine)
         slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
         wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
         self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, img, table_cls_score):
-        bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+    def predict(self, input_img, table_cls_score):
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        bgr_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
         ocr_result = self.ocr_engine.ocr(bgr_img)[0]
-
         if ocr_result:
             ocr_result = [
                 [item[0], escape_html(item[1][0]), item[1][1]]
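
`count_table_cells_physical` counts `<td`/`<th` open tags, so a merged cell contributes one regardless of its span; this makes the wired/wireless comparison below less sensitive to colspan/rowspan than counting logical grid cells. A quick check of that behavior:

```python
html_code = (
    '<table>'
    '<tr><th colspan="2">header spanning two columns</th></tr>'
    '<tr><td>a</td><td>b</td></tr>'
    '</table>'
)
# 1 <th> + 2 <td> -> 3 physical cells; a logical grid count would give 4.
assert count_table_cells_physical(html_code) == 3
```
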
@@ -229,27 +263,42 @@ class UnetTableModel:
             ocr_result = None
         if ocr_result:
             try:
-                wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
+                wired_table_results = self.wired_table_model(np_img, ocr_result)
+
+                # viser = VisTable()
+                # save_html_path = f"outputs/output.html"
+                # save_drawed_path = f"outputs/output_table_vis.jpg"
+                # save_logic_path = (
+                #     f"outputs/output_table_vis_logic.jpg"
+                # )
+                # vis_imged = viser(
+                #     np_img, wired_table_results, save_html_path, save_drawed_path, save_logic_path
+                # )
+
                 wired_html_code = wired_table_results.pred_html
                 wired_table_cell_bboxes = wired_table_results.cell_bboxes
                 wired_logic_points = wired_table_results.logic_points
                 wired_elapse = wired_table_results.elapse
 
-                wireless_table_results = self.wireless_table_model(np.asarray(img), ocr_result)
+                wireless_table_results = self.wireless_table_model(np_img, ocr_result)
                 wireless_html_code = wireless_table_results.pred_html
                 wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
                 wireless_logic_points = wireless_table_results.logic_points
                 wireless_elapse = wireless_table_results.elapse
 
-                wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
-                wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
+                # wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
+                # wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
+
+                wired_len = count_table_cells_physical(wired_html_code)
+                wireless_len = count_table_cells_physical(wireless_html_code)
+
                 # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
                 # Difference between the cell counts detected by the two models
                 gap_of_len = wireless_len - wired_len
                 # Decide whether to use the wireless table model's result
                 if (
-                    wired_len <= round(wireless_len * 0.5)  # the wired model found too few cells (below 50% of the wireless count)
-                    or ((wireless_len < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.949)  # the wired model found more cells instead
+                    wired_len <= int(wireless_len * 0.55) + 1  # the wired model found too few cells (at most ~55% of the wireless count, plus one)
+                    # or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # the wired model found more cells instead
                     or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # the counts are close but the wired result has noticeably fewer
                     or (gap_of_len == 0 and wired_len <= 4)  # the counts are exactly equal and the total is at most 4
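
Plugging sample counts into the new heuristic makes the thresholds concrete (the standalone function below just restates the condition from the hunk for illustration):

```python
def prefer_wireless(wired_len: int, wireless_len: int) -> bool:
    gap_of_len = wireless_len - wired_len
    return (
        wired_len <= int(wireless_len * 0.55) + 1
        or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))
        or (gap_of_len == 0 and wired_len <= 4)
    )

assert prefer_wireless(wired_len=5, wireless_len=10)       # 5 <= int(5.5) + 1
assert not prefer_wireless(wired_len=18, wireless_len=20)  # gap is 2 but 18 > 15
assert prefer_wireless(wired_len=4, wireless_len=4)        # tiny, equal-sized tables
```
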
                 ):

+ 1 - 0
mineru/model/table/rec/unet_table/table_recover.py

@@ -1,4 +1,5 @@
 from typing import Dict, List, Tuple
+
 import numpy as np
 
 

+ 65 - 8
mineru/model/table/rec/unet_table/table_structure_unet.py

@@ -5,15 +5,15 @@ from typing import Optional, Dict, Any, Tuple
 import cv2
 import numpy as np
 from skimage import measure
-from .wired_table_rec_utils import OrtInferSession, resize_img
-from .table_line_rec_utils import (
+from .utils import OrtInferSession, resize_img
+from .utils_table_line_rec import (
     get_table_line,
     final_adjust_lines,
     min_area_rect_box,
     draw_lines,
     adjust_lines,
 )
-from .table_recover_utils import (
+from .utils_table_recover import (
     sorted_ocr_boxes,
     box_4_2_poly_to_box_4_1,
 )
@@ -50,7 +50,7 @@ class TSRUnet:
         )
         _, idx = sorted_ocr_boxes(
             [box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
-            threshold=0.4,
+            threhold=0.4,
         )
         polygons = polygons[idx]
         rotated_polygons = rotated_polygons[idx]
@@ -94,7 +94,8 @@ class TSRUnet:
         extend_line = (
             kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
         )  # whether to extend line segments so that their endpoints connect
-
+        # Whether to apply rotation correction (defaults to True when unspecified)
+        rotated_fix = kwargs.get("rotated_fix", True) if kwargs else True
         ori_shape = img.shape
         pred = np.uint8(pred)
         hpred = copy.deepcopy(pred)  # horizontal lines
@@ -130,9 +131,16 @@ class TSRUnet:
             rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
         line_img = np.zeros(img.shape[:2], dtype="uint8")
         line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
-
-        polygons = self.cal_region_boxes(line_img)
-        rotated_polygons = polygons.copy()
+        rotated_angle = self.cal_rotate_angle(line_img)
+        if rotated_fix and abs(rotated_angle) > 0.3:
+            rotated_line_img = self.rotate_image(line_img, rotated_angle)
+            rotated_polygons = self.cal_region_boxes(rotated_line_img)
+            polygons = self.unrotate_polygons(
+                rotated_polygons, rotated_angle, line_img.shape
+            )
+        else:
+            polygons = self.cal_region_boxes(line_img)
+            rotated_polygons = polygons.copy()
         return polygons, rotated_polygons
 
     def cal_region_boxes(self, tmp):
@@ -147,3 +155,52 @@ class TSRUnet:
             adjust_box=False,
         )  # 最后一个参数改为False
         return np.array(ceilboxes)
+
+    def cal_rotate_angle(self, tmp):
+        # Find the outermost rotated bounding box
+        contours, _ = cv2.findContours(tmp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        if not contours:
+            return 0
+        largest_contour = max(contours, key=cv2.contourArea)
+        rect = cv2.minAreaRect(largest_contour)
+        # Compute the rotation angle
+        angle = rect[2]
+        if angle < -45:
+            angle += 90
+        elif angle > 45:
+            angle -= 90
+        return angle
+
+    def rotate_image(self, image, angle):
+        # Get the image center point
+        (h, w) = image.shape[:2]
+        center = (w // 2, h // 2)
+
+        # Build the rotation matrix
+        M = cv2.getRotationMatrix2D(center, angle, 1.0)
+
+        # Apply the rotation
+        rotated_image = cv2.warpAffine(
+            image, M, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE
+        )
+
+        return rotated_image
+
+    def unrotate_polygons(
+        self, polygons: np.ndarray, angle: float, img_shape: tuple
+    ) -> np.ndarray:
+        # Rotate the polygons back to their original positions
+        (h, w) = img_shape
+        center = (w // 2, h // 2)
+        M_inv = cv2.getRotationMatrix2D(center, -angle, 1.0)
+
+        # Reshape (N, 8) to (N, 4, 2)
+        polygons_reshaped = polygons.reshape(-1, 4, 2)
+
+        # Apply the inverse rotation in one batch
+        unrotated_polygons = cv2.transform(polygons_reshaped, M_inv)
+
+        # Reshape (N, 4, 2) back to (N, 8)
+        unrotated_polygons = unrotated_polygons.reshape(-1, 8)
+
+        return unrotated_polygons
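
Because `unrotate_polygons` builds its matrix with `-angle` about the same center, it exactly undoes the point mapping of `rotate_image` (a rotation about a fixed center composed with its inverse is the identity). A round-trip check using the same `cv2.transform` pattern as the code above:

```python
import cv2
import numpy as np

h, w = 200, 300
center = (w // 2, h // 2)
angle = 5.0

poly = np.array([[50, 50, 250, 50, 250, 150, 50, 150]], dtype=np.float32)
M = cv2.getRotationMatrix2D(center, angle, 1.0)
M_inv = cv2.getRotationMatrix2D(center, -angle, 1.0)

rotated = cv2.transform(poly.reshape(-1, 4, 2), M)
restored = cv2.transform(rotated, M_inv).reshape(-1, 8)
assert np.allclose(restored, poly, atol=1e-3)
```

Note that `warpAffine` at the fixed `(w, h)` size clips content that rotates out of frame, which is acceptable for the small skew angles this correction path targets.
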

+ 154 - 52
mineru/model/table/rec/unet_table/wired_table_rec_utils.py → mineru/model/table/rec/unet_table/utils.py

@@ -3,9 +3,10 @@ import traceback
 from enum import Enum
 from io import BytesIO
 from pathlib import Path
-from typing import List, Union, Dict, Any, Tuple
+from typing import List, Union, Dict, Any, Tuple, Optional
 
 import cv2
+import loguru
 import numpy as np
 from onnxruntime import (
     GraphOptimizationLevel,
@@ -15,17 +16,20 @@ from onnxruntime import (
 )
 from PIL import Image, UnidentifiedImageError
 
+
 root_dir = Path(__file__).resolve().parent
 InputType = Union[str, np.ndarray, bytes, Path]
 
+
 class EP(Enum):
     CPU_EP = "CPUExecutionProvider"
 
+
 class OrtInferSession:
     def __init__(self, config: Dict[str, Any]):
+        self.logger = loguru.logger
 
         model_path = config.get("model_path", None)
-        self._verify_model(model_path)
 
         self.had_providers: List[str] = get_available_providers()
         EP_list = self._get_ep_list()
@@ -55,18 +59,15 @@ class OrtInferSession:
 
         return sess_opt
 
-    def get_metadata(self, key: str = "character") -> list:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        content_list = meta_dict[key].splitlines()
-        return content_list
-
     def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
         cpu_provider_opts = {
             "arena_extend_strategy": "kSameAsRequested",
         }
         EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
         return EP_list
 
+
     def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
         input_dict = dict(zip(self.get_input_names(), input_content))
         try:
@@ -78,54 +79,23 @@ class OrtInferSession:
     def get_input_names(self) -> List[str]:
         return [v.name for v in self.session.get_inputs()]
 
-    def get_output_names(self) -> List[str]:
-        return [v.name for v in self.session.get_outputs()]
-
-    def get_character_list(self, key: str = "character") -> List[str]:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        return meta_dict[key].splitlines()
-
-    def have_key(self, key: str = "character") -> bool:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        if key in meta_dict.keys():
-            return True
-        return False
-
-    @staticmethod
-    def _verify_model(model_path: Union[str, Path, None]):
-        if model_path is None:
-            raise ValueError("model_path is None!")
-
-        model_path = Path(model_path)
-        if not model_path.exists():
-            raise FileNotFoundError(f"{model_path} does not exists.")
-
-        if not model_path.is_file():
-            raise FileExistsError(f"{model_path} is not a file.")
-
 
 class ONNXRuntimeError(Exception):
     pass
 
 
 class LoadImage:
-    """
-    Utility class for loading and converting images from various input types to a numpy ndarray.
-
-    Supported input types:
-        - str or pathlib.Path: Path to an image file.
-        - bytes: Image data in bytes format.
-        - numpy.ndarray: Already loaded image array.
-
-    The class attempts to load the image and convert it to a numpy ndarray in BGR format.
-    Raises LoadImageError for unsupported types or if the image cannot be loaded.
-    """
     def __init__(
         self,
     ):
         pass
 
     def __call__(self, img: InputType) -> np.ndarray:
+        if not isinstance(img, InputType.__args__):
+            raise LoadImageError(
+                f"The img type {type(img)} is not in {InputType.__args__}"
+            )
+
         img = self.load_img(img)
         img = self.convert_img(img)
         return img
@@ -139,18 +109,14 @@ class LoadImage:
                 raise LoadImageError(f"cannot identify image file {img}") from e
             return img
 
-        elif isinstance(img, bytes):
-            try:
-                img = np.array(Image.open(BytesIO(img)))
-            except UnidentifiedImageError as e:
-                raise LoadImageError(f"cannot identify image from bytes data") from e
+        if isinstance(img, bytes):
+            img = np.array(Image.open(BytesIO(img)))
             return img
 
-        elif isinstance(img, np.ndarray):
+        if isinstance(img, np.ndarray):
             return img
 
-        else:
-            raise LoadImageError(f"{type(img)} is not supported!")
+        raise LoadImageError(f"{type(img)} is not supported!")
 
     def convert_img(self, img: np.ndarray):
         if img.ndim == 2:
@@ -388,4 +354,140 @@ def _scale_size(size, scale):
     if isinstance(scale, (float, int)):
         scale = (scale, scale)
     w, h = size
-    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+class VisTable:
+    def __init__(self):
+        self.load_img = LoadImage()
+
+    def __call__(
+        self,
+        img_path: Union[str, Path],
+        table_results,
+        save_html_path: Optional[Union[str, Path]] = None,
+        save_drawed_path: Optional[Union[str, Path]] = None,
+        save_logic_path: Optional[Union[str, Path]] = None,
+    ):
+        if save_html_path:
+            html_with_border = self.insert_border_style(table_results.pred_html)
+            self.save_html(save_html_path, html_with_border)
+
+        table_cell_bboxes = table_results.cell_bboxes
+        table_logic_points = table_results.logic_points
+        if table_cell_bboxes is None:
+            return None
+
+        img = self.load_img(img_path)
+
+        dims_bboxes = table_cell_bboxes.shape[1]
+        if dims_bboxes == 4:
+            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
+        elif dims_bboxes == 8:
+            drawed_img = self.draw_polylines(img, table_cell_bboxes)
+        else:
+            raise ValueError("Shape of table bounding boxes is neither 4 nor 8.")
+
+        if save_drawed_path:
+            self.save_img(save_drawed_path, drawed_img)
+
+        if save_logic_path:
+            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+            self.plot_rec_box_with_logic_info(
+                img_path, save_logic_path, table_logic_points, polygons
+            )
+        return drawed_img
+
+    def insert_border_style(self, table_html_str: str):
+        style_res = """<meta charset="UTF-8"><style>
+        table {
+            border-collapse: collapse;
+            width: 100%;
+        }
+        th, td {
+            border: 1px solid black;
+            padding: 8px;
+            text-align: center;
+        }
+        th {
+            background-color: #f2f2f2;
+        }
+                    </style>"""
+
+        prefix_table, suffix_table = table_html_str.split("<body>")
+        html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
+        return html_with_border
+
+    def plot_rec_box_with_logic_info(
+        self, img_path, output_path, logic_points, sorted_polygons
+    ):
+        """
+        :param img_path
+        :param output_path
+        :param logic_points: [row_start,row_end,col_start,col_end]
+        :param sorted_polygons: [xmin,ymin,xmax,ymax]
+        :return:
+        """
+        # Read the original image
+        img = cv2.imread(img_path)
+        img = cv2.copyMakeBorder(
+            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+        )
+        # Draw the polygon rectangles
+        for idx, polygon in enumerate(sorted_polygons):
+            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
+            x0 = round(x0)
+            y0 = round(y0)
+            x1 = round(x1)
+            y1 = round(y1)
+            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
+            # Larger font size and line width
+            font_scale = 0.9  # previously 0.5
+            thickness = 1  # previously 1
+            logic_point = logic_points[idx]
+            cv2.putText(
+                img,
+                f"row: {logic_point[0]}-{logic_point[1]}",
+                (x0 + 3, y0 + 8),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+            cv2.putText(
+                img,
+                f"col: {logic_point[2]}-{logic_point[3]}",
+                (x0 + 3, y0 + 18),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        # Save the rendered image once, after all boxes are drawn
+        self.save_img(output_path, img)
+
+    @staticmethod
+    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
+        img_copy = img.copy()
+        for box in boxes.astype(int):
+            x1, y1, x2, y2 = box
+            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
+        img_copy = img.copy()
+        for point in points.astype(int):
+            point = point.reshape(4, 2)
+            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def save_img(save_path: Union[str, Path], img: np.ndarray):
+        cv2.imwrite(str(save_path), img)
+
+    @staticmethod
+    def save_html(save_path: Union[str, Path], html: str):
+        with open(save_path, "w", encoding="utf-8") as f:
+            f.write(html)

+ 96 - 5
mineru/model/table/rec/unet_table/table_line_rec_utils.py → mineru/model/table/rec/unet_table/utils_table_line_rec.py

@@ -6,6 +6,68 @@ from scipy.spatial import distance as dist
 from skimage import measure
 
 
+def transform_preds(coords, center, scale, output_size, rot=0):
+    target_coords = np.zeros(coords.shape)
+    trans = get_affine_transform(center, scale, rot, output_size, inv=1)
+    for p in range(coords.shape[0]):
+        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
+    return target_coords
+
+
+def get_affine_transform(
+    center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
+):
+    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
+        scale = np.array([scale, scale], dtype=np.float32)
+
+    scale_tmp = scale
+    src_w = scale_tmp[0]
+    dst_w = output_size[0]
+    dst_h = output_size[1]
+
+    rot_rad = np.pi * rot / 180
+    src_dir = get_dir([0, src_w * -0.5], rot_rad)
+    dst_dir = np.array([0, dst_w * -0.5], np.float32)
+
+    src = np.zeros((3, 2), dtype=np.float32)
+    dst = np.zeros((3, 2), dtype=np.float32)
+    src[0, :] = center + scale_tmp * shift
+    src[1, :] = center + src_dir + scale_tmp * shift
+    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
+
+    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
+    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
+
+    if inv:
+        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+    else:
+        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+    return trans
+
+
+def affine_transform(pt, t):
+    new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T
+    new_pt = np.dot(t, new_pt)
+    return new_pt[:2]
+
+
+def get_dir(src_point, rot_rad):
+    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+
+    src_result = [0, 0]
+    src_result[0] = src_point[0] * cs - src_point[1] * sn
+    src_result[1] = src_point[0] * sn + src_point[1] * cs
+
+    return src_result
+
+
+def get_3rd_point(a, b):
+    direct = a - b
+    return b + np.array([-direct[1], direct[0]], dtype=np.float32)
+
+
 def get_table_line(binimg, axis=0, lineW=10):
     ##Extract the table lines
     ##axis=0: horizontal lines
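
The helpers added above are the affine pre/post-processing pair familiar from CenterNet-style keypoint models: `get_affine_transform` builds a 2x3 matrix from a center/scale pair, and `transform_preds` applies the inverse matrix to map model-output coordinates back to the source image. A usage sketch under assumed sizes (the 1280x720 image and 256x256 output map are hypothetical):

```python
import numpy as np

# Map a point from a 256x256 model output back to a 1280x720 source image,
# using get_affine_transform/transform_preds as defined above.
center = np.array([640, 360], dtype=np.float32)  # source image center
scale = 1280                                     # longer source side
output_size = (256, 256)

coords = np.array([[128.0, 128.0]])  # center of the output map
restored = transform_preds(coords, center, scale, output_size)
assert np.allclose(restored, [[640.0, 360.0]], atol=1e-3)
```
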
@@ -38,7 +100,7 @@ def min_area_rect(coords):
     box = image_location_sort_box(box)
 
     x1, y1, x2, y2, x3, y3, x4, y4 = box
-    w, h = calculate_center_rotate_angle(box)
+    degree, w, h, cx, cy = calculate_center_rotate_angle(box)
     if w < h:
         xmin = (x1 + x2) / 2
         xmax = (x3 + x4) / 2
@@ -50,6 +112,9 @@ def min_area_rect(coords):
         xmax = (x2 + x3) / 2
         ymin = (y1 + y4) / 2
         ymax = (y2 + y3) / 2
+    # degree,w,h,cx,cy = solve(box)
+    # x1,y1,x2,y2,x3,y3,x4,y4 = box
+    # return {'degree':degree,'w':w,'h':h,'cx':cx,'cy':cy}
     return [xmin, ymin, xmax, ymax]
 
 
@@ -62,8 +127,21 @@ def image_location_sort_box(box):
 
 
 def calculate_center_rotate_angle(box):
+    """
+    Coordinates of a w x h box rotated by `angle` around (cx, cy). This can partially correct skew inside the image, but relying on the model remains the safer option.
+    x = cx-w/2
+    y = cy-h/2
+    x1-cx = -w/2*cos(angle) +h/2*sin(angle)
+    y1 -cy= -w/2*sin(angle) -h/2*cos(angle)
+
+    h(x1-cx) = -wh/2*cos(angle) +hh/2*sin(angle)
+    w(y1 -cy)= -ww/2*sin(angle) -hw/2*cos(angle)
+    (hh+ww)/2sin(angle) = h(x1-cx)-w(y1 -cy)
 
+    """
     x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
+    cx = (x1 + x3 + x2 + x4) / 4.0
+    cy = (y1 + y3 + y4 + y2) / 4.0
     w = (
         np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
         + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2)
@@ -72,8 +150,11 @@ def calculate_center_rotate_angle(box):
         np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
         + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2)
     ) / 2
-
-    return w, h
+    # x = cx-w/2
+    # y = cy-h/2
+    sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2
+    angle = np.arcsin(sinA)
+    return angle, w, h, cx, cy
 
 
 def _order_points(pts):
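
`calculate_center_rotate_angle` now returns `(angle, w, h, cx, cy)` instead of just `(w, h)`. For an axis-aligned box the angle term vanishes and the center and size come back exactly; a quick sanity check using the function as defined above:

```python
import numpy as np

# clockwise from top-left: x1,y1, x2,y2, x3,y3, x4,y4
box = [0, 0, 10, 0, 10, 4, 0, 4]
angle, w, h, cx, cy = calculate_center_rotate_angle(box)
assert (w, h, cx, cy) == (10.0, 4.0, 5.0, 2.0)
assert np.isclose(angle, 0.0)
```
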
@@ -222,8 +303,18 @@ def min_area_rect_box(
         box = box.reshape((8,)).tolist()
         box = image_location_sort_box(box)
         x1, y1, x2, y2, x3, y3, x4, y4 = box
-        w, h = calculate_center_rotate_angle(box)
-
+        angle, w, h, cx, cy = calculate_center_rotate_angle(box)
+        # if adjustBox:
+        #     x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w + 5, h + 5, angle=0, degree=None)
+        #     x1, x4 = max(x1, 0), max(x4, 0)
+        #     y1, y2 = max(y1, 0), max(y2, 0)
+
+        # if w > 32 and h > 32 and flag:
+        #     if abs(angle / np.pi * 180) < 20:
+        #         if filtersmall and (w < 10 or h < 10):
+        #             continue
+        #         boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
+        # else:
         if w * h < 0.5 * W * H:
             if filtersmall and (
                 w < 15 or h < 15

+ 24 - 133
mineru/model/table/rec/unet_table/table_recover_utils.py → mineru/model/table/rec/unet_table/utils_table_recover.py

@@ -1,33 +1,6 @@
 from typing import Any, Dict, List, Union, Tuple
 
 import numpy as np
-import shapely
-from shapely.geometry import MultiPoint, Polygon
-
-
-def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray:
-    """
-    Sort text boxes in order from top to bottom, left to right
-    args:
-        dt_boxes(array):detected text boxes with shape (N, 4, 2)
-    return:
-        sorted boxes(array) with shape (N, 4, 2)
-    """
-    num_boxes = dt_boxes.shape[0]
-    dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
-    _boxes = list(dt_boxes)
-
-    # Fix adjacent boxes: when a later box has a smaller y than the previous one, it would otherwise be sorted in front
-    for i in range(num_boxes - 1):
-        for j in range(i, -1, -1):
-            if (
-                abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
-                and _boxes[j + 1][0][0] < _boxes[j][0][0]
-            ):
-                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
-            else:
-                break
-    return np.array(_boxes)
 
 
 def calculate_iou(
@@ -63,6 +36,7 @@ def calculate_iou(
     return iou
 
 
+
 def is_box_contained(
     box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2
 ) -> Union[int, None]:
@@ -111,7 +85,7 @@ def is_single_axis_contained(
     box1: Union[np.ndarray, List],
     box2: Union[np.ndarray, List],
     axis="x",
-    threshold: float = 0.2,
+    threhold: float = 0.2,
 ) -> Union[int, None]:
     """
     :param box1: Iterable [xmin,ymin,xmax,ymax]
@@ -136,15 +110,15 @@ def is_single_axis_contained(
 
     ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
     ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
-    if ratio_b1 < threshold:
+    if ratio_b1 < threhold:
         return 1
-    if ratio_b2 < threshold:
+    if ratio_b2 < threhold:
         return 2
     return None
 
 
 def sorted_ocr_boxes(
-    dt_boxes: Union[np.ndarray, list], threshold: float = 0.2
+    dt_boxes: Union[np.ndarray, list], threhold: float = 0.2
 ) -> Tuple[Union[np.ndarray, list], List[int]]:
     """
     Sort text boxes in order from top to bottom, left to right
@@ -161,18 +135,19 @@ def sorted_ocr_boxes(
     _boxes, indices = zip(*sorted_boxes_with_idx)
     indices = list(indices)
     _boxes = [dt_boxes[i] for i in indices]
+    y_threshold = 20
     # Avoid the output format mismatching the input format, which would break the function's contract
     if isinstance(dt_boxes, np.ndarray):
         _boxes = np.array(_boxes)
     for i in range(num_boxes - 1):
         for j in range(i, -1, -1):
             c_idx = is_single_axis_contained(
-                _boxes[j], _boxes[j + 1], axis="y", threshold=threshold
+                _boxes[j], _boxes[j + 1], axis="y", threhold=threhold
             )
             if (
                 c_idx is not None
                 and _boxes[j + 1][0] < _boxes[j][0]
-                and abs(_boxes[j][1] - _boxes[j + 1][1]) < 20
+                and abs(_boxes[j][1] - _boxes[j + 1][1]) < y_threshold
             ):
                 _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy()
                 indices[j], indices[j + 1] = indices[j + 1], indices[j]
@@ -181,6 +156,11 @@ def sorted_ocr_boxes(
     return _boxes, indices
 
 
+def box_4_1_poly_to_box_4_2(poly_box: Union[list, np.ndarray]) -> List[List[float]]:
+    xmin, ymin, xmax, ymax = tuple(poly_box)
+    return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+
+
 def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]:
     """
     将poly_box转换为box_4_1
@@ -221,18 +201,12 @@ def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.nd
     return matched, not_match_orc_boxes
 
 
-def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[Any]:
+def gather_ocr_list_by_row(ocr_list: List[Any], threhold: float = 0.2) -> List[Any]:
     """
-        Groups OCR results by row based on the vertical (y-axis) overlap of their bounding boxes.
-    Args:
-        ocr_list (List[Any]): A list of OCR results, where each item is a list containing a bounding box
-            in the format [xmin, ymin, xmax, ymax] and the recognized text.
-        threshold (float, optional): The threshold for determining if two boxes are in the same row,
-            based on their y-axis overlap. Default is 0.2.
-    Returns:
-        List[Any]: A new list of OCR results where texts in the same row are merged, and their bounding
-            boxes are updated to encompass the merged text.
+    :param ocr_list: [[[xmin,ymin,xmax,ymax], text]]
+    :return:
     """
+    px_per_space = 10
     for i in range(len(ocr_list)):
         if not ocr_list[i]:
             continue
@@ -245,11 +219,11 @@ def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[
             cur_box = cur[0]
             next_box = next[0]
             c_idx = is_single_axis_contained(
-                cur[0], next[0], axis="y", threshold=threshold
+                cur[0], next[0], axis="y", threhold=threhold
             )
             if c_idx:
                 dis = max(next_box[0] - cur_box[2], 0)
-                blank_str = int(dis / 10) * " "
+                blank_str = int(dis / px_per_space) * " "
                 cur[1] = cur[1] + blank_str + next[1]
                 xmin = min(cur_box[0], next_box[0])
                 xmax = max(cur_box[2], next_box[2])
@@ -264,93 +238,6 @@ def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[
     return ocr_list
 
 
-def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float:
-    """计算两个多边形的IOU
-
-    Args:
-        poly1 (np.ndarray): (4, 2)
-        poly2 (np.ndarray): (4, 2)
-
-    Returns:
-        float: iou
-    """
-    poly1 = Polygon(a).convex_hull
-    poly2 = Polygon(b).convex_hull
-
-    union_poly = np.concatenate((a, b))
-
-    if not poly1.intersects(poly2):
-        return 0.0
-
-    try:
-        inter_area = poly1.intersection(poly2).area
-        union_area = MultiPoint(union_poly).convex_hull.area
-    except shapely.geos.TopologicalError:
-        print("shapely.geos.TopologicalError occured, iou set to 0")
-        return 0.0
-
-    if union_area == 0:
-        return 0.0
-
-    return float(inter_area) / union_area
-
-
-def merge_adjacent_polys(polygons: np.ndarray) -> np.ndarray:
-    """合并相邻iou大于阈值的框"""
-    combine_iou_thresh = 0.1
-    pair_polygons = list(zip(polygons, polygons[1:, ...]))
-    pair_ious = np.array([compute_poly_iou(p1, p2) for p1, p2 in pair_polygons])
-    idxs = np.argwhere(pair_ious >= combine_iou_thresh)
-
-    if idxs.size <= 0:
-        return polygons
-
-    polygons = combine_two_poly(polygons, idxs)
-
-    # 注意:递归调用
-    polygons = merge_adjacent_polys(polygons)
-    return polygons
-
-
-def combine_two_poly(polygons: np.ndarray, idxs: np.ndarray) -> np.ndarray:
-    del_idxs, insert_boxes = [], []
-    idxs = idxs.squeeze(-1)
-    for idx in idxs:
-        # idx 和 idx + 1 是重合度过高的
-        # 合并,取两者各个点的最大值
-        new_poly = []
-        pre_poly, pos_poly = polygons[idx], polygons[idx + 1]
-
-        # 四个点,每个点逐一比较
-        new_poly.append(np.minimum(pre_poly[0], pos_poly[0]))
-
-        x_2 = min(pre_poly[1][0], pos_poly[1][0])
-        y_2 = max(pre_poly[1][1], pos_poly[1][1])
-        new_poly.append([x_2, y_2])
-
-        # 第3个点
-        new_poly.append(np.maximum(pre_poly[2], pos_poly[2]))
-
-        # 第4个点
-        x_4 = max(pre_poly[3][0], pos_poly[3][0])
-        y_4 = min(pre_poly[3][1], pos_poly[3][1])
-        new_poly.append([x_4, y_4])
-
-        new_poly = np.array(new_poly)
-
-        # 删除已经合并的两个框,插入新的框
-        del_idxs.extend([idx, idx + 1])
-        insert_boxes.append(new_poly)
-
-    # 整合合并后的框
-    polygons = np.delete(polygons, del_idxs, axis=0)
-
-    insert_boxes = np.array(insert_boxes)
-    polygons = np.append(polygons, insert_boxes, axis=0)
-    polygons = sorted_boxes(polygons)
-    return polygons
-
-
 def plot_html_table(
     logi_points: Union[Union[np.ndarray, List]], cell_box_map: Dict[int, List[str]]
 ) -> str:
@@ -405,7 +292,8 @@ def plot_html_table(
                     continue
                 if row == row_start and col == col_start:
                     ocr_rec_text = cell_box_map.get(i)
-                    text = "<br>".join(ocr_rec_text)
+                    # text = "<br>".join(ocr_rec_text)
+                    text = "".join(ocr_rec_text)
                     # if this is the starting cell
                     row_span = row_end - row_start + 1
                     col_span = col_end - col_start + 1
@@ -418,3 +306,6 @@ def plot_html_table(
 
     table_html += "</table></body></html>"
     return table_html
+
+
+

+ 25 - 7
mineru/utils/model_utils.py

@@ -324,7 +324,7 @@ def remove_overlaps_low_confidence_blocks(combined_res_list, overlap_threshold=0
                 marked_indices.add(i)  # mark the current index as processed
     return blocks_to_remove
 
-
+# @todo This method needs to be refactored later
 def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
     """Extract OCR, table and other regions from layout results."""
     ocr_res_list = []
@@ -358,14 +358,31 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
     filtered_table_res_list = filter_nested_tables(
         table_res_list, overlap_threshold, area_threshold)
 
+    for table_res in filtered_table_res_list:
+        table_res['bbox'] = [int(table_res['poly'][0]), int(table_res['poly'][1]), int(table_res['poly'][4]), int(table_res['poly'][5])]
+
+    filtered_table_res_list, table_need_remove = remove_overlaps_min_blocks(filtered_table_res_list)
+
+    for res in filtered_table_res_list:
+        # Rebuild the poly of res from its bbox
+        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
+                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
+        # Remove the temporary bbox from res
+        del res['bbox']
+
+    if len(table_need_remove) > 0:
+        for res in table_need_remove:
+            del res['bbox']
+            if res in layout_res:
+                layout_res.remove(res)
+
     # Remove filtered out tables from layout_res
     if len(filtered_table_res_list) < len(table_res_list):
         kept_tables = set(id(table) for table in filtered_table_res_list)
-        to_remove = [table_indices[i] for i, table in enumerate(table_res_list)
-                     if id(table) not in kept_tables]
-
-        for idx in sorted(to_remove, reverse=True):
-            del layout_res[idx]
+        tables_to_remove = [table for table in table_res_list if id(table) not in kept_tables]
+        for table in tables_to_remove:
+            if table in layout_res:
+                layout_res.remove(table)
 
     # Remove overlaps in OCR and text regions
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
@@ -381,7 +398,8 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
     if len(need_remove) > 0:
         for res in need_remove:
             del res['bbox']
-            layout_res.remove(res)
+            if res in layout_res:
+                layout_res.remove(res)
 
     # Check whether a large block contains multiple smaller blocks; merge the ocr and table lists for this check
     combined_res_list = ocr_res_list + filtered_table_res_list
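
The removal logic in this file switches from deleting by precomputed index to removing the objects themselves, which stays correct even after earlier deletions shift list positions. A minimal illustration of the failure mode being avoided (note that `list.remove()` matches by equality, hence the `in` guard):

```python
layout_res = [{"n": 0}, {"n": 1}, {"n": 2}, {"n": 3}]

# Index-based removal goes stale once the list mutates:
# del layout_res[1]; del layout_res[3]  -> IndexError, the list shrank to 3 items

tables_to_remove = [layout_res[1], layout_res[3]]
for table in tables_to_remove:
    if table in layout_res:
        layout_res.remove(table)
assert layout_res == [{"n": 0}, {"n": 2}]
```
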

+ 1 - 18
pyproject.toml

@@ -68,7 +68,7 @@ pipeline = [
     "shapely>=2.0.7,<3",
     "pyclipper>=1.3.0,<2",
     "omegaconf>=2.3.0,<3",
-    "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
+    "torch>=2.6.0,<3",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
 ]
@@ -91,23 +91,6 @@ all = [
     "mineru[core]",
     "mineru[sglang]",
 ]
-pipeline_old_linux = [
-    "matplotlib>=3.10,<=3.10.1",
-    "ultralytics>=8.3.48,<=8.3.104",
-    "doclayout_yolo==0.0.4",
-    "dill==0.3.8",
-    "PyYAML==6.0.2",
-    "ftfy==6.3.1",
-    "openai==1.71.0",
-    "shapely==2.1.0",
-    "pyclipper==1.3.0.post6",
-    "omegaconf==2.3.0",
-    "albumentations==1.4.20",
-    "rapid_table==1.0.3",
-    "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
-    "torchvision",
-    "transformers>=4.49.0,!=4.51.0,<5.0.0",
-]
 
 [project.urls]
 homepage = "https://mineru.net/"