
Merge branch 'cxz-dev' into dev

# Conflicts:
#	mineru/backend/pipeline/batch_analyze.py
#	mineru/model/ori_cls/paddle_ori_cls.py
#	mineru/model/table/cls/paddle_table_cls.py
Sidney233 · 3 months ago
commit efa4a5b7f1

+ 18 - 8
mineru/backend/pipeline/batch_analyze.py

@@ -9,6 +9,7 @@ from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.pdf_image_tools import get_crop_img

 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -41,9 +42,7 @@ class BatchAnalyze:
         images = [image for image, _, _ in images_with_extra_info]

         # doclayout_yolo
-        layout_images = []
-        for image_index, image in enumerate(images):
-            layout_images.append(image)
+        layout_images = images.copy()

         images_layout_res += self.model.layout_model.batch_predict(
             layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
@@ -89,7 +88,14 @@ class BatchAnalyze:
                                           })

             for table_res in table_res_list:
-                table_img, _ = crop_img(table_res, pil_img)
+                # table_img, _ = crop_img(table_res, pil_img)
+                # bbox = (241, 208, 1475, 2019)
+                scale = 10/3
+                crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
+                crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
+                bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
+                table_img = get_crop_img(bbox, pil_img, scale=scale)
+
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
                                                 'table_img':table_img,
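Review note: the table crop no longer goes through `crop_img` on the scaled page image; the layout `poly` (which lives in the 10/3-scaled render) is mapped back down and re-cropped via `get_crop_img`. A minimal sketch of what this call site appears to assume `get_crop_img` does (hypothetical re-implementation, not the actual `mineru/utils/pdf_image_tools.py` code):

```python
from PIL import Image

def get_crop_img_sketch(bbox, pil_img: Image.Image, scale: float = 10 / 3) -> Image.Image:
    """Crop `bbox` (given in unscaled page coordinates) and upsample the crop by `scale`."""
    x0, y0, x1, y1 = bbox
    crop = pil_img.crop((x0, y0, x1, y1))
    new_size = (round(crop.width * scale), round(crop.height * scale))
    return crop.resize(new_size, Image.LANCZOS)
```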
@@ -140,14 +146,17 @@ class BatchAnalyze:
                 )

                 # Group by resolution and do the padding at the same time
+                # RESOLUTION_GROUP_STRIDE = 32
+                RESOLUTION_GROUP_STRIDE = 64  # stride used for resolution grouping
+
                 resolution_groups = defaultdict(list)
                 for crop_info in lang_crop_list:
                     cropped_img = crop_info[0]
                     h, w = cropped_img.shape[:2]
                     # Use a larger grouping tolerance to reduce the number of groups
                     # Normalize dimensions to a multiple of 32
-                    normalized_h = ((h + 32) // 32) * 32  # round up to a multiple of 32
-                    normalized_w = ((w + 32) // 32) * 32
+                    normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # round up to a multiple of the stride
+                    normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
                     group_key = (normalized_h, normalized_w)
                     resolution_groups[group_key].append(crop_info)

@@ -157,8 +166,8 @@ class BatchAnalyze:
                     # Compute the target size (max within the group, rounded up to a multiple of 32)
                     max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
                     max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
-                    target_h = ((max_h + 32 - 1) // 32) * 32
-                    target_w = ((max_w + 32 - 1) // 32) * 32
+                    target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+                    target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE

                     # Pad all images in the group to the same size
                     batch_images = []
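Note that the two hunks use different rounding idioms with the new stride: the group key uses `((v + S) // S) * S`, which bumps an already-aligned size to the *next* multiple, while the padding target uses the true ceiling `((v + S - 1) // S) * S`. Since the key over-rounds, every crop in a group still fits inside the padded target. A standalone illustration:

```python
S = 64  # RESOLUTION_GROUP_STRIDE

def next_multiple(v: int, s: int = S) -> int:
    # Grouping key: always moves up, even when v is already a multiple of s.
    return ((v + s) // s) * s

def ceil_multiple(v: int, s: int = S) -> int:
    # Padding target: smallest multiple of s that is >= v.
    return ((v + s - 1) // s) * s

assert next_multiple(128) == 192   # already aligned -> still bumped up
assert ceil_multiple(128) == 128   # already aligned -> unchanged
assert next_multiple(100) == ceil_multiple(100) == 128
```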
@@ -287,6 +296,7 @@ class BatchAnalyze:
                     raise ValueError(
                         "Table classification failed, please check the model"
                     )
+
                 # Choose the wired or wireless table recognition model based on the classification result
                 table_model = atom_model_manager.get_atom_model(
                     atom_model_name=table_label,

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -11,7 +11,7 @@ from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
 from ...model.table.rec.rapid_table import RapidTableModel
-from ...model.table.rec.unet_table.unet_table import UnetTableModel
+from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path


+ 7 - 4
mineru/backend/pipeline/pipeline_analyze.py

@@ -1,7 +1,7 @@
 import os
 import time
 from typing import List, Tuple
-import PIL.Image
+from PIL import Image
 from loguru import logger

 from .model_init import MineruPipelineModel
@@ -148,10 +148,9 @@ def doc_analyze(


 def batch_image_analyze(
-        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
+        images_with_extra_info: List[Tuple[Image.Image, bool, str]],
         formula_enable=True,
         table_enable=True):
-    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)

     from .batch_analyze import BatchAnalyze

@@ -191,10 +190,14 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')

-    if str(device).startswith('mps'):
+    # Check the torch version
+    import torch
+    from packaging import version
+    if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
         enable_ocr_det_batch = False
     else:
         enable_ocr_det_batch = True
+
     batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
     results = batch_model(images_with_extra_info)


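Review note: batched OCR detection is now disabled on torch >= 2.8.0 in addition to MPS devices. Using `packaging.version` rather than plain string comparison matters for real torch builds; a small illustration (the `+cu128` local tag is just an example build string):

```python
from packaging import version

# PEP 440 local build tags compare >= their base release, so CUDA wheels
# such as "2.8.0+cu128" also trip the gate:
assert version.parse("2.8.0+cu128") >= version.parse("2.8.0")

# A lexicographic string comparison would mis-order multi-digit components:
assert not ("2.10.0" >= "2.8.0")  # string compare: '1' < '8'
assert version.parse("2.10.0") >= version.parse("2.8.0")
```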
+ 1 - 1
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -89,7 +89,7 @@ class PytorchPaddleOCR(TextSystem):
         kwargs['det_model_path'] = det_model_path
         kwargs['rec_model_path'] = rec_model_path
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
-        kwargs['rec_batch_num'] = 16
+        kwargs['rec_batch_num'] = 8

         kwargs['device'] = device


+ 19 - 25
mineru/model/ori_cls/paddle_ori_cls.py

@@ -1,12 +1,13 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import os
+
+from PIL import Image
 from collections import defaultdict
 from typing import List, Dict
 from tqdm import tqdm
 import cv2
 import numpy as np
 import onnxruntime
-from PIL import Image

 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -15,12 +16,7 @@ from mineru.utils.models_download_utils import auto_download_and_get_model_root_
 class PaddleOrientationClsModel:
     def __init__(self, ocr_engine):
         self.sess = onnxruntime.InferenceSession(
-            os.path.join(
-                auto_download_and_get_model_root_path(
-                    ModelPath.paddle_orientation_classification
-                ),
-                ModelPath.paddle_orientation_classification,
-            )
+            os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
         )
         self.ocr_engine = ocr_engine
         self.less_length = 256
@@ -30,15 +26,13 @@ class PaddleOrientationClsModel:
         self.mean = [0.485, 0.456, 0.406]
         self.labels = ["0", "90", "180", "270"]

-    def preprocess(self, img):
-        # Convert the PIL image to cv2 format
-        img = np.array(img)
+    def preprocess(self, input_img):
         # Upscale the image so its shorter side is 256
-        h, w = img.shape[:2]
+        h, w = input_img.shape[:2]
         scale = 256 / min(h, w)
         h_resize = round(h * scale)
         w_resize = round(w * scale)
-        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
         # Adjust to a 224x224 square
         h, w = img.shape[:2]
         cw, ch = 224, 224
@@ -69,8 +63,15 @@ class PaddleOrientationClsModel:
         x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x

-    def predict(self, img):
-        bgr_image = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+    def predict(self, input_img):
+        rotate_label = "0"  # Default to 0 if no rotation detected or not portrait
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
         # First check the overall image aspect ratio (height/width)
         img_height, img_width = bgr_image.shape[:2]
         img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
@@ -106,19 +107,12 @@ class PaddleOrientationClsModel:
                 # If we have more vertical text boxes than horizontal ones,
                 # and vertical ones are significant, table might be rotated
                 if is_rotated:
-                    x = self.preprocess(img)
+                    x = self.preprocess(np_img)
                     (result,) = self.sess.run(None, {"x": x})
-                    label = self.labels[np.argmax(result)]
+                    rotate_label = self.labels[np.argmax(result)]
                     # logger.debug(f"Orientation classification result: {label}")
-                    if label == "270":
-                        img = cv2.rotate(np.asarray(img), cv2.ROTATE_90_CLOCKWISE)
-                    elif label == "90":
-                        img = cv2.rotate(
-                            np.asarray(img), cv2.ROTATE_90_COUNTERCLOCKWISE
-                        )
-                    else:
-                        pass
-        return img
+
+        return rotate_label

     def list_2_batch(self, img_list, batch_size=16):
         """

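With this change `predict` returns the orientation label ("0", "90", "180" or "270") instead of a rotated image, so applying the rotation is now the caller's job. A minimal sketch of a caller-side helper (hypothetical; it simply mirrors the branch removed above):

```python
import cv2
import numpy as np

def apply_rotate_label(np_img: np.ndarray, rotate_label: str) -> np.ndarray:
    # Mirrors the rotation logic deleted from predict(); note the removed code
    # only handled "90" and "270" ("0" and "180" fell through unrotated).
    if rotate_label == "270":
        return cv2.rotate(np_img, cv2.ROTATE_90_CLOCKWISE)
    if rotate_label == "90":
        return cv2.rotate(np_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
    return np_img
```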
+ 20 - 16
mineru/model/table/cls/paddle_table_cls.py

@@ -1,6 +1,7 @@
 import os
 from pathlib import Path

+from PIL import Image
 import cv2
 import numpy as np
 import onnxruntime
@@ -23,15 +24,13 @@ class PaddleTableClsModel:
         self.mean = [0.485, 0.456, 0.406]
         self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]

-    def preprocess(self, img):
-        # Convert the PIL image to cv2 format
-        img = np.array(img)
+    def preprocess(self, input_img):
         # Upscale the image so its shorter side is 256
-        h, w = img.shape[:2]
+        h, w = input_img.shape[:2]
         scale = 256 / min(h, w)
         h_resize = round(h * scale)
         w_resize = round(w * scale)
-        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
         # Adjust to a 224x224 square
         h, w = img.shape[:2]
         cw, ch = 224, 224
@@ -62,6 +61,22 @@ class PaddleTableClsModel:
         x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x

+    def predict(self, input_img):
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        x = self.preprocess(np_img)
+        result = self.sess.run(None, {"x": x})
+        idx = np.argmax(result)
+        conf = float(np.max(result))
+        # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
+        if idx == 0 and conf < 0.8:
+            idx = 1
+        return self.labels[idx], conf
+
     def list_2_batch(self, img_list, batch_size=16):
     def list_2_batch(self, img_list, batch_size=16):
         """
         Split a list of arbitrary length into batches of the specified batch size
@@ -120,17 +135,6 @@ class PaddleTableClsModel:
             res_imgs.append(img)
         x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
-
-    def predict(self, img):
-        x = self.preprocess(img)
-        result = self.sess.run(None, {"x": x})
-        idx = np.argmax(result)
-        conf = float(np.max(result))
-        # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-        if idx == 0 and conf < 0.8:
-            idx = 1
-        return self.labels[idx], conf
-
     def batch_predict(self, img_info_list, batch_size=16):
         imgs = [item["table_img"] for item in img_info_list]
         imgs = self.list_2_batch(imgs, batch_size=batch_size)

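The relocated `predict` keeps the confidence fallback: a weak "wired" prediction is downgraded to wireless. A standalone illustration of that branch with invented scores:

```python
import numpy as np

# Scores for [WiredTable, WirelessTable]; values are made up for illustration.
result = np.array([[0.55, 0.45]])
idx = int(np.argmax(result))   # 0 -> wired wins the raw argmax
conf = float(np.max(result))   # 0.55
if idx == 0 and conf < 0.8:    # a wired call below the 0.8 threshold...
    idx = 1                    # ...is treated as wireless instead
assert idx == 1
```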
+ 0 - 1
mineru/model/table/rec/unet_table/__init__.py

@@ -1 +0,0 @@
-# Copyright (c) Opendatalab. All rights reserved.

+ 93 - 40
mineru/model/table/rec/unet_table/unet_table.py → mineru/model/table/rec/unet_table/main.py

@@ -1,21 +1,24 @@
 import html
+import logging
 import os
 import time
 import traceback
 from dataclasses import dataclass, asdict
-from typing import List, Optional, Union, Dict, Any

-import cv2
+from typing import List, Optional, Union, Dict, Any
 import numpy as np
+import cv2
+from PIL import Image
 from loguru import logger
 from rapid_table import RapidTableInput, RapidTable

+from .table_structure_unet import TSRUnet
+
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-from .table_structure_unet import TSRUnet
 from .table_recover import TableRecover
-from .wired_table_rec_utils import InputType, LoadImage
-from .table_recover_utils import (
+from .utils import InputType, LoadImage, VisTable
+from .utils_table_recover import (
     match_ocr_cell,
     plot_html_table,
     box_4_2_poly_to_box_4_1,
@@ -25,32 +28,32 @@ from .table_recover_utils import (


 @dataclass
-class UnetTableInput:
+class WiredTableInput:
     model_path: str
     device: str = "cpu"


 @dataclass
-class UnetTableOutput:
+class WiredTableOutput:
     pred_html: Optional[str] = None
     cell_bboxes: Optional[np.ndarray] = None
     logic_points: Optional[np.ndarray] = None
     elapse: Optional[float] = None


-class UnetTableRecognition:
-    def __init__(self, config: UnetTableInput):
+class WiredTableRecognition:
+    def __init__(self, config: WiredTableInput, ocr_engine=None):
         self.table_structure = TSRUnet(asdict(config))
         self.load_img = LoadImage()
         self.table_recover = TableRecover()
+        self.ocr_engine = ocr_engine

     def __call__(
         self,
         img: InputType,
         ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
-        ocr_engine = None,
         **kwargs,
-    ) -> UnetTableOutput:
+    ) -> WiredTableOutput:
         s = time.perf_counter()
         need_ocr = True
         col_threshold = 15
@@ -62,8 +65,8 @@ class UnetTableRecognition:
         img = self.load_img(img)
         polygons, rotated_polygons = self.table_structure(img, **kwargs)
         if polygons is None:
-            # logger.warning("polygons is None.")
-            return UnetTableOutput("", None, None, 0.0)
+            logging.warning("polygons is None.")
+            return WiredTableOutput("", None, None, 0.0)

         try:
             table_res, logi_points = self.table_recover(
@@ -78,7 +81,7 @@ class UnetTableRecognition:
                 sorted_polygons, idx_list = sorted_ocr_boxes(
                     [box_4_2_poly_to_box_4_1(box) for box in polygons]
                 )
-                return UnetTableOutput(
+                return WiredTableOutput(
                     "",
                     sorted_polygons,
                     logi_points[idx_list],
@@ -86,12 +89,12 @@ class UnetTableRecognition:
                 )
             cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
             # If a detected cell has no OCR result, run recognition on it directly to fill the gap
-            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map, ocr_engine)
+            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
             # Convert to an intermediate format: fix box coordinates and merge the physical boxes, logical boxes and OCR boxes into a dict for later processing
             t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
             # Sort the OCR results in each cell and merge same-row entries, so the output HTML fully preserves the text's line breaks
             t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
-            # cell_box_map =
+
             logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
             cell_box_det_map = {
                 i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
@@ -103,9 +106,9 @@ class UnetTableRecognition:
             elapse = time.perf_counter() - s

         except Exception:
-            logger.warning(traceback.format_exc())
-            return UnetTableOutput("", None, None, 0.0)
-        return UnetTableOutput(pred_html, polygons, logi_points, elapse)
+            logging.warning(traceback.format_exc())
+            return WiredTableOutput("", None, None, 0.0)
+        return WiredTableOutput(pred_html, polygons, logi_points, elapse)

     def transform_res(
         self,
@@ -139,44 +142,61 @@ class UnetTableRecognition:
     def sort_and_gather_ocr_res(self, res):
         for i, dict_res in enumerate(res):
             _, sorted_idx = sorted_ocr_boxes(
-                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threshold=0.3
+                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
             )
             dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
             dict_res["t_ocr_res"] = gather_ocr_list_by_row(
-                dict_res["t_ocr_res"], threshold=0.3
+                dict_res["t_ocr_res"], threhold=0.3
             )
         return res

+    # def fill_blank_rec(
+    #     self,
+    #     img: np.ndarray,
+    #     sorted_polygons: np.ndarray,
+    #     cell_box_map: Dict[int, List[str]],
+    # ) -> Dict[int, List[Any]]:
+    #     """Find cells whose poly has no OCR result and feed the poly box straight to recognition"""
+    #     for i in range(sorted_polygons.shape[0]):
+    #         if cell_box_map.get(i):
+    #             continue
+    #         box = sorted_polygons[i]
+    #         cell_box_map[i] = [[box, "", 1]]
+    #         continue
+    #     return cell_box_map
     def fill_blank_rec(
         self,
         img: np.ndarray,
         sorted_polygons: np.ndarray,
         cell_box_map: Dict[int, List[str]],
-        ocr_engine
     ) -> Dict[int, List[Any]]:
         """Find cells whose poly has no OCR result and feed the poly box straight to recognition"""
+        bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
         img_crop_info_list = []
         img_crop_list = []
         for i in range(sorted_polygons.shape[0]):
             if cell_box_map.get(i):
                 continue
             box = sorted_polygons[i]
-            if ocr_engine is None:
+            if self.ocr_engine is None:
                 logger.warning(f"No OCR engine provided for box {i}: {box}")
                 continue
             # Crop the corresponding region from the image
-            x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
+            x1, y1, x2, y2 = int(box[0][0]), int(box[0][1]), int(box[2][0]), int(box[2][1])
             if x1 >= x2 or y1 >= y2:
                 logger.warning(f"Invalid box coordinates: {box}")
                 continue
-            img_crop = img[int(y1):int(y2), int(x1):int(x2)]
+            # Check the aspect ratio
+            if (x2 - x1) / (y2 - y1) > 20 or (y2 - y1) / (x2 - x1) > 20:
+                logger.warning(f"Box {i} has invalid aspect ratio: {x1, y1, x2, y2}")
+                continue
+            img_crop = bgr_img[int(y1):int(y2), int(x1):int(x2)]
             img_crop_list.append(img_crop)
             img_crop_info_list.append([i, box])
-            continue

         if len(img_crop_list) > 0:
             # Run OCR recognition
-            ocr_result = ocr_engine.ocr(img_crop_list, det=False)
+            ocr_result = self.ocr_engine.ocr(img_crop_list, det=False)
             if not ocr_result or not isinstance(ocr_result, list) or len(ocr_result) == 0:
                 logger.warning("OCR engine returned no results or invalid result for image crops.")
                 return cell_box_map
             for j, ocr_res in enumerate(ocr_res_list):
             for j, ocr_res in enumerate(ocr_res_list):
                 img_crop_info_list[j].append(ocr_res)

-
             for i, box, ocr_res in img_crop_info_list:
                 # Process the OCR result
                 ocr_text, ocr_score = ocr_res
                 # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
-                if ocr_score < 0.9:
+                if ocr_score < 0.9 or ocr_text in ['1']:
                     # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
+                    box = sorted_polygons[i]
+                    cell_box_map[i] = [[box, "", 0.5]]
                     continue
                 cell_box_map[i] = [[box, ocr_text, ocr_score]]
                 cell_box_map[i] = [[box, ocr_text, ocr_score]]
 
 
@@ -205,20 +226,37 @@ def escape_html(input_string):
     return html.escape(input_string)


+def count_table_cells_physical(html_code):
+    """Count the table's physical cells (a merged cell counts as one)"""
+    if not html_code:
+        return 0
+
+    # Simply count the td and th tags
+    html_lower = html_code.lower()
+    td_count = html_lower.count('<td')
+    th_count = html_lower.count('<th')
+    return td_count + th_count
+
+
 class UnetTableModel:
     def __init__(self, ocr_engine):
         model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
-        wired_input_args = UnetTableInput(model_path=model_path)
-        self.wired_table_model = UnetTableRecognition(wired_input_args)
+        wired_input_args = WiredTableInput(model_path=model_path)
+        self.wired_table_model = WiredTableRecognition(wired_input_args, ocr_engine)
         slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
         wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
         self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine

-    def predict(self, img, table_cls_score):
-        bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+    def predict(self, input_img, table_cls_score):
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        bgr_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
         ocr_result = self.ocr_engine.ocr(bgr_img)[0]
-
         if ocr_result:
             ocr_result = [
                 [item[0], escape_html(item[1][0]), item[1][1]]
@@ -229,27 +267,42 @@ class UnetTableModel:
             ocr_result = None
         if ocr_result:
             try:
-                wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
+                wired_table_results = self.wired_table_model(np_img, ocr_result)
+
+                # viser = VisTable()
+                # save_html_path = f"outputs/output.html"
+                # save_drawed_path = f"outputs/output_table_vis.jpg"
+                # save_logic_path = (
+                #     f"outputs/output_table_vis_logic.jpg"
+                # )
+                # vis_imged = viser(
+                #     np_img, wired_table_results, save_html_path, save_drawed_path, save_logic_path
+                # )
+
                 wired_html_code = wired_table_results.pred_html
                 wired_table_cell_bboxes = wired_table_results.cell_bboxes
                 wired_logic_points = wired_table_results.logic_points
                 wired_elapse = wired_table_results.elapse

-                wireless_table_results = self.wireless_table_model(np.asarray(img), ocr_result)
+                wireless_table_results = self.wireless_table_model(np_img, ocr_result)
                 wireless_html_code = wireless_table_results.pred_html
                 wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
                 wireless_logic_points = wireless_table_results.logic_points
                 wireless_elapse = wireless_table_results.elapse

-                wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
-                wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
+                # wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
+                # wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
+
+                wired_len = count_table_cells_physical(wired_html_code)
+                wireless_len = count_table_cells_physical(wireless_html_code)
+
                 # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
                 # Compute the difference in cell counts between the two models
                 gap_of_len = wireless_len - wired_len
                 # Decide whether to use the wireless model's result
                 if (
-                    wired_len <= round(wireless_len * 0.5)  # the wired model found too few cells (below 50% of the wireless count)
-                    or ((wireless_len < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.949)  # the wired model found more cells instead
+                    wired_len <= int(wireless_len * 0.55)+1  # the wired model found too few cells (at or below ~55% of the wireless count)
+                    # or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # the wired model found more cells instead
                     or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # counts are close but the wired model found fewer
                     or (gap_of_len == 0 and wired_len <= 4)  # counts are exactly equal and the total is at most 4
                 ):

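Review note: wired-vs-wireless selection now compares physical cell counts parsed from each model's HTML instead of raw cell-bbox counts, so a merged cell counts once on both sides. A quick standalone check of the counting rule (mirrors `count_table_cells_physical` above):

```python
def count_cells(html_code: str) -> int:
    """Count <td>/<th> occurrences; a merged (colspan/rowspan) cell counts once."""
    if not html_code:
        return 0
    html_lower = html_code.lower()
    return html_lower.count('<td') + html_lower.count('<th')

# A 2x2 grid whose first row is one merged cell has 3 physical cells:
html = "<table><tr><td colspan=2>a</td></tr><tr><td>b</td><td>c</td></tr></table>"
assert count_cells(html) == 3
```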
+ 1 - 0
mineru/model/table/rec/unet_table/table_recover.py

@@ -1,4 +1,5 @@
 from typing import Dict, List, Tuple
+
 import numpy as np



+ 65 - 8
mineru/model/table/rec/unet_table/table_structure_unet.py

@@ -5,15 +5,15 @@ from typing import Optional, Dict, Any, Tuple
 import cv2
 import numpy as np
 from skimage import measure
-from .wired_table_rec_utils import OrtInferSession, resize_img
-from .table_line_rec_utils import (
+from .utils import OrtInferSession, resize_img
+from .utils_table_line_rec import (
     get_table_line,
     final_adjust_lines,
     min_area_rect_box,
     draw_lines,
     adjust_lines,
 )
-from .table_recover_utils import (
+from .utils_table_recover import (
     sorted_ocr_boxes,
     box_4_2_poly_to_box_4_1,
 )
@@ -50,7 +50,7 @@ class TSRUnet:
         )
         _, idx = sorted_ocr_boxes(
             [box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
-            threshold=0.4,
+            threhold=0.4,
         )
         polygons = polygons[idx]
         rotated_polygons = rotated_polygons[idx]
@@ -94,7 +94,8 @@ class TSRUnet:
         extend_line = (
             kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
         )  # whether to extend line segments so that endpoints connect
-
+        # whether to apply rotation correction
+        rotated_fix = kwargs.get("rotated_fix") if kwargs else True
         ori_shape = img.shape
         pred = np.uint8(pred)
         hpred = copy.deepcopy(pred)  # horizontal lines
@@ -130,9 +131,16 @@ class TSRUnet:
             rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
         line_img = np.zeros(img.shape[:2], dtype="uint8")
         line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
-
-        polygons = self.cal_region_boxes(line_img)
-        rotated_polygons = polygons.copy()
+        rotated_angle = self.cal_rotate_angle(line_img)
+        if rotated_fix and abs(rotated_angle) > 0.3:
+            rotated_line_img = self.rotate_image(line_img, rotated_angle)
+            rotated_polygons = self.cal_region_boxes(rotated_line_img)
+            polygons = self.unrotate_polygons(
+                rotated_polygons, rotated_angle, line_img.shape
+            )
+        else:
+            polygons = self.cal_region_boxes(line_img)
+            rotated_polygons = polygons.copy()
         return polygons, rotated_polygons

     def cal_region_boxes(self, tmp):
@@ -147,3 +155,52 @@ class TSRUnet:
             adjust_box=False,
         )  # last argument changed to False
         return np.array(ceilboxes)
+
+    def cal_rotate_angle(self, tmp):
+        # Compute the outermost rotated bounding box
+        contours, _ = cv2.findContours(tmp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        if not contours:
+            return 0
+        largest_contour = max(contours, key=cv2.contourArea)
+        rect = cv2.minAreaRect(largest_contour)
+        # Compute the rotation angle
+        angle = rect[2]
+        if angle < -45:
+            angle += 90
+        elif angle > 45:
+            angle -= 90
+        return angle
+
+    def rotate_image(self, image, angle):
+        # Get the image center
+        (h, w) = image.shape[:2]
+        center = (w // 2, h // 2)
+
+        # Compute the rotation matrix
+        M = cv2.getRotationMatrix2D(center, angle, 1.0)
+
+        # Apply the rotation
+        rotated_image = cv2.warpAffine(
+            image, M, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE
+        )
+
+        return rotated_image
+
+    def unrotate_polygons(
+        self, polygons: np.ndarray, angle: float, img_shape: tuple
+    ) -> np.ndarray:
+        # Rotate the polygons back to their original positions
+        (h, w) = img_shape
+        center = (w // 2, h // 2)
+        M_inv = cv2.getRotationMatrix2D(center, -angle, 1.0)
+
+        # Reshape (N, 8) to (N, 4, 2)
+        polygons_reshaped = polygons.reshape(-1, 4, 2)
+
+        # Batch inverse rotation
+        unrotated_polygons = cv2.transform(polygons_reshaped, M_inv)
+
+        # Reshape (N, 4, 2) back to (N, 8)
+        unrotated_polygons = unrotated_polygons.reshape(-1, 8)
+
+        return unrotated_polygons

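Review note on the new deskew path: `cv2.minAreaRect` reports angles in a convention that varies by OpenCV version (commonly (0, 90] in 4.5+), so `cal_rotate_angle` folds the value into (-45, 45] to avoid treating an upright table as a quarter-turn rotation. The folding in isolation:

```python
def fold_minarearect_angle(angle: float) -> float:
    # Same normalization as cal_rotate_angle: a near-upright table yields a
    # small correction instead of a ~90 degree turn.
    if angle < -45:
        angle += 90
    elif angle > 45:
        angle -= 90
    return angle

assert fold_minarearect_angle(89.0) == -1.0   # reported 89 deg -> tilt of -1 deg
assert fold_minarearect_angle(-88.5) == 1.5
assert fold_minarearect_angle(2.0) == 2.0
```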
+ 154 - 52
mineru/model/table/rec/unet_table/wired_table_rec_utils.py → mineru/model/table/rec/unet_table/utils.py

@@ -3,9 +3,10 @@ import traceback
 from enum import Enum
 from io import BytesIO
 from pathlib import Path
-from typing import List, Union, Dict, Any, Tuple
+from typing import List, Union, Dict, Any, Tuple, Optional

 import cv2
+import loguru
 import numpy as np
 from onnxruntime import (
     GraphOptimizationLevel,
@@ -15,17 +16,20 @@ from onnxruntime import (
 )
 from PIL import Image, UnidentifiedImageError

+
 root_dir = Path(__file__).resolve().parent
 InputType = Union[str, np.ndarray, bytes, Path]

+
 class EP(Enum):
     CPU_EP = "CPUExecutionProvider"

+
 class OrtInferSession:
     def __init__(self, config: Dict[str, Any]):
+        self.logger = loguru.logger

         model_path = config.get("model_path", None)
-        self._verify_model(model_path)

         self.had_providers: List[str] = get_available_providers()
         EP_list = self._get_ep_list()
@@ -55,18 +59,15 @@ class OrtInferSession:

         return sess_opt

-    def get_metadata(self, key: str = "character") -> list:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        content_list = meta_dict[key].splitlines()
-        return content_list
-
     def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
         cpu_provider_opts = {
             "arena_extend_strategy": "kSameAsRequested",
         }
         EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
         return EP_list

+
     def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
         input_dict = dict(zip(self.get_input_names(), input_content))
         try:
@@ -78,54 +79,23 @@ class OrtInferSession:
     def get_input_names(self) -> List[str]:
         return [v.name for v in self.session.get_inputs()]

-    def get_output_names(self) -> List[str]:
-        return [v.name for v in self.session.get_outputs()]
-
-    def get_character_list(self, key: str = "character") -> List[str]:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        return meta_dict[key].splitlines()
-
-    def have_key(self, key: str = "character") -> bool:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        if key in meta_dict.keys():
-            return True
-        return False
-
-    @staticmethod
-    def _verify_model(model_path: Union[str, Path, None]):
-        if model_path is None:
-            raise ValueError("model_path is None!")
-
-        model_path = Path(model_path)
-        if not model_path.exists():
-            raise FileNotFoundError(f"{model_path} does not exists.")
-
-        if not model_path.is_file():
-            raise FileExistsError(f"{model_path} is not a file.")


 class ONNXRuntimeError(Exception):
     pass


 class LoadImage:
-    """
-    Utility class for loading and converting images from various input types to a numpy ndarray.
-
-    Supported input types:
-        - str or pathlib.Path: Path to an image file.
-        - bytes: Image data in bytes format.
-        - numpy.ndarray: Already loaded image array.
-
-    The class attempts to load the image and convert it to a numpy ndarray in BGR format.
-    Raises LoadImageError for unsupported types or if the image cannot be loaded.
-    """
     def __init__(
         self,
     ):
         pass

     def __call__(self, img: InputType) -> np.ndarray:
+        if not isinstance(img, InputType.__args__):
+            raise LoadImageError(
+                f"The img type {type(img)} does not in {InputType.__args__}"
+            )
+
         img = self.load_img(img)
         img = self.convert_img(img)
         return img
@@ -139,18 +109,14 @@ class LoadImage:
                 raise LoadImageError(f"cannot identify image file {img}") from e
             return img

-        elif isinstance(img, bytes):
-            try:
-                img = np.array(Image.open(BytesIO(img)))
-            except UnidentifiedImageError as e:
-                raise LoadImageError(f"cannot identify image from bytes data") from e
+        if isinstance(img, bytes):
+            img = np.array(Image.open(BytesIO(img)))
             return img

-        elif isinstance(img, np.ndarray):
+        if isinstance(img, np.ndarray):
             return img

-        else:
-            raise LoadImageError(f"{type(img)} is not supported!")
+        raise LoadImageError(f"{type(img)} is not supported!")

     def convert_img(self, img: np.ndarray):
         if img.ndim == 2:
@@ -388,4 +354,140 @@ def _scale_size(size, scale):
     if isinstance(scale, (float, int)):
         scale = (scale, scale)
     w, h = size
-    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+class VisTable:
+    def __init__(self):
+        self.load_img = LoadImage()
+
+    def __call__(
+        self,
+        img_path: Union[str, Path],
+        table_results,
+        save_html_path: Optional[Union[str, Path]] = None,
+        save_drawed_path: Optional[Union[str, Path]] = None,
+        save_logic_path: Optional[Union[str, Path]] = None,
+    ):
+        if save_html_path:
+            html_with_border = self.insert_border_style(table_results.pred_html)
+            self.save_html(save_html_path, html_with_border)
+
+        table_cell_bboxes = table_results.cell_bboxes
+        table_logic_points = table_results.logic_points
+        if table_cell_bboxes is None:
+            return None
+
+        img = self.load_img(img_path)
+
+        dims_bboxes = table_cell_bboxes.shape[1]
+        if dims_bboxes == 4:
+            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
+        elif dims_bboxes == 8:
+            drawed_img = self.draw_polylines(img, table_cell_bboxes)
+        else:
+            raise ValueError("Shape of table bounding boxes must be 4 or 8.")
+
+        if save_drawed_path:
+            self.save_img(save_drawed_path, drawed_img)
+
+        if save_logic_path:
+            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+            self.plot_rec_box_with_logic_info(
+                img_path, save_logic_path, table_logic_points, polygons
+            )
+        return drawed_img
+
+    def insert_border_style(self, table_html_str: str):
+        style_res = """<meta charset="UTF-8"><style>
+        table {
+            border-collapse: collapse;
+            width: 100%;
+        }
+        th, td {
+            border: 1px solid black;
+            padding: 8px;
+            text-align: center;
+        }
+        th {
+            background-color: #f2f2f2;
+        }
+                    </style>"""
+
+        prefix_table, suffix_table = table_html_str.split("<body>")
+        html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
+        return html_with_border
+
+    def plot_rec_box_with_logic_info(
+        self, img_path, output_path, logic_points, sorted_polygons
+    ):
+        """
+        :param img_path
+        :param output_path
+        :param logic_points: [row_start,row_end,col_start,col_end]
+        :param sorted_polygons: [xmin,ymin,xmax,ymax]
+        :return:
+        """
+        # Read the source image
+        img = cv2.imread(img_path)
+        img = cv2.copyMakeBorder(
+            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+        )
+        # Draw rectangles for the polygons
+        for idx, polygon in enumerate(sorted_polygons):
+            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
+            x0 = round(x0)
+            y0 = round(y0)
+            x1 = round(x1)
+            y1 = round(y1)
+            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
+            # Increase the font size and line width
+            font_scale = 0.9  # was 0.5
+            thickness = 1  # was 1
+            logic_point = logic_points[idx]
+            cv2.putText(
+                img,
+                f"row: {logic_point[0]}-{logic_point[1]}",
+                (x0 + 3, y0 + 8),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+            cv2.putText(
+                img,
+                f"col: {logic_point[2]}-{logic_point[3]}",
+                (x0 + 3, y0 + 18),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            # Save the annotated image
+            self.save_img(output_path, img)
+
+    @staticmethod
+    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
+        img_copy = img.copy()
+        for box in boxes.astype(int):
+            x1, y1, x2, y2 = box
+            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
+        img_copy = img.copy()
+        for point in points.astype(int):
+            point = point.reshape(4, 2)
+            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def save_img(save_path: Union[str, Path], img: np.ndarray):
+        cv2.imwrite(str(save_path), img)
+
+    @staticmethod
+    def save_html(save_path: Union[str, Path], html: str):
+        with open(save_path, "w", encoding="utf-8") as f:
+            f.write(html)

+ 96 - 5
mineru/model/table/rec/unet_table/table_line_rec_utils.py → mineru/model/table/rec/unet_table/utils_table_line_rec.py

@@ -6,6 +6,68 @@ from scipy.spatial import distance as dist
 from skimage import measure


+def transform_preds(coords, center, scale, output_size, rot=0):
+    target_coords = np.zeros(coords.shape)
+    trans = get_affine_transform(center, scale, rot, output_size, inv=1)
+    for p in range(coords.shape[0]):
+        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
+    return target_coords
+
+
+def get_affine_transform(
+    center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
+):
+    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
+        scale = np.array([scale, scale], dtype=np.float32)
+
+    scale_tmp = scale
+    src_w = scale_tmp[0]
+    dst_w = output_size[0]
+    dst_h = output_size[1]
+
+    rot_rad = np.pi * rot / 180
+    src_dir = get_dir([0, src_w * -0.5], rot_rad)
+    dst_dir = np.array([0, dst_w * -0.5], np.float32)
+
+    src = np.zeros((3, 2), dtype=np.float32)
+    dst = np.zeros((3, 2), dtype=np.float32)
+    src[0, :] = center + scale_tmp * shift
+    src[1, :] = center + src_dir + scale_tmp * shift
+    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
+
+    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
+    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
+
+    if inv:
+        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+    else:
+        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+    return trans
+
+
+def affine_transform(pt, t):
+    new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T
+    new_pt = np.dot(t, new_pt)
+    return new_pt[:2]
+
+
+def get_dir(src_point, rot_rad):
+    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+
+    src_result = [0, 0]
+    src_result[0] = src_point[0] * cs - src_point[1] * sn
+    src_result[1] = src_point[0] * sn + src_point[1] * cs
+
+    return src_result
+
+
+def get_3rd_point(a, b):
+    direct = a - b
+    return b + np.array([-direct[1], direct[0]], dtype=np.float32)
+
+
 def get_table_line(binimg, axis=0, lineW=10):
     ## Get table lines
     ## axis=0: horizontal lines
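The newly added helpers (`get_affine_transform`, `affine_transform`, `transform_preds`) are the usual CenterNet-style mapping between a crop in image coordinates and the network's output grid. A round-trip sketch, assuming the helpers are imported from this module (all values invented):

```python
import numpy as np
# from .utils_table_line_rec import get_affine_transform, affine_transform, transform_preds

center = np.array([320.0, 240.0], dtype=np.float32)  # crop center in image coords
scale = 480.0                                        # crop size; broadcast to (480, 480)
output_size = (128, 128)                             # model output resolution

trans = get_affine_transform(center, scale, 0, output_size)
pt = affine_transform(np.array([320.0, 240.0]), trans)
# pt ~= [64., 64.]: the crop center lands on the output center.

back = transform_preds(pt[None, :], center, scale, output_size)
# back ~= [[320., 240.]]: transform_preds applies the inverse (inv=1) mapping.
```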
@@ -38,7 +100,7 @@ def min_area_rect(coords):
     box = image_location_sort_box(box)

     x1, y1, x2, y2, x3, y3, x4, y4 = box
-    w, h = calculate_center_rotate_angle(box)
+    degree, w, h, cx, cy = calculate_center_rotate_angle(box)
     if w < h:
         xmin = (x1 + x2) / 2
         xmax = (x3 + x4) / 2
@@ -50,6 +112,9 @@ def min_area_rect(coords):
         xmax = (x2 + x3) / 2
         ymin = (y1 + y4) / 2
         ymax = (y2 + y3) / 2
+    # degree,w,h,cx,cy = solve(box)
+    # x1,y1,x2,y2,x3,y3,x4,y4 = box
+    # return {'degree':degree,'w':w,'h':h,'cx':cx,'cy':cy}
     return [xmin, ymin, xmax, ymax]
 
 
 
 
@@ -62,8 +127,21 @@ def image_location_sort_box(box):
 
 
 
 
 def calculate_center_rotate_angle(box):
+    """
+    Recover the rotation angle of a w x h box rotated by `angle` around its
+    centre (cx, cy). This partly compensates for skew inside the image, though
+    a robust upstream model remains the safer bet. The unrotated top-left
+    corner is (cx - w/2, cy - h/2); after rotation:
+        x1 - cx = -w/2*cos(angle) + h/2*sin(angle)
+        y1 - cy = -w/2*sin(angle) - h/2*cos(angle)
+    Multiply the first line by h, the second by w, and subtract:
+        h*(x1 - cx) - w*(y1 - cy) = (h*h + w*w)/2 * sin(angle)
+    """
     x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
+    cx = (x1 + x3 + x2 + x4) / 4.0
+    cy = (y1 + y3 + y4 + y2) / 4.0
     w = (
         np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
         + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2)
@@ -72,8 +150,11 @@ def calculate_center_rotate_angle(box):
         np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
         + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2)
     ) / 2
-
-    return w, h
+    # sin(angle) via the identity derived in the docstring
+    sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2
+    angle = np.arcsin(sinA)
+    return angle, w, h, cx, cy
 
 
 
 
 def _order_points(pts):
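A quick numeric check of the sin(angle) identity used in calculate_center_rotate_angle above (an assumed test, not from the repo): rotate a 100x40 box by 5 degrees about its centre and confirm the angle is recovered:

    import numpy as np

    theta = np.deg2rad(5)
    w, h, cx, cy = 100.0, 40.0, 300.0, 200.0
    corners = np.array([[-w / 2, -h / 2], [w / 2, -h / 2],
                        [w / 2, h / 2], [-w / 2, h / 2]])
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    box = (corners @ rot.T + [cx, cy]).reshape(-1)  # x1,y1,...,x4,y4
    angle, w_est, h_est, cx_est, cy_est = calculate_center_rotate_angle(box)
    assert np.isclose(angle, theta)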
@@ -222,8 +303,18 @@ def min_area_rect_box(
         box = box.reshape((8,)).tolist()
         box = image_location_sort_box(box)
         x1, y1, x2, y2, x3, y3, x4, y4 = box
-        w, h = calculate_center_rotate_angle(box)
-
+        angle, w, h, cx, cy = calculate_center_rotate_angle(box)
+        # if adjustBox:
+        #     x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w + 5, h + 5, angle=0, degree=None)
+        #     x1, x4 = max(x1, 0), max(x4, 0)
+        #     y1, y2 = max(y1, 0), max(y2, 0)
+
+        # if w > 32 and h > 32 and flag:
+        #     if abs(angle / np.pi * 180) < 20:
+        #         if filtersmall and (w < 10 or h < 10):
+        #             continue
+        #         boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
+        # else:
         if w * h < 0.5 * W * H:
             if filtersmall and (
                 w < 15 or h < 15

+ 24 - 133
mineru/model/table/rec/unet_table/table_recover_utils.py → mineru/model/table/rec/unet_table/utils_table_recover.py

@@ -1,33 +1,6 @@
 from typing import Any, Dict, List, Union, Tuple

 import numpy as np
-import shapely
-from shapely.geometry import MultiPoint, Polygon
-
-
-def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray:
-    """
-    Sort text boxes in order from top to bottom, left to right
-    args:
-        dt_boxes(array):detected text boxes with shape (N, 4, 2)
-    return:
-        sorted boxes(array) with shape (N, 4, 2)
-    """
-    num_boxes = dt_boxes.shape[0]
-    dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
-    _boxes = list(dt_boxes)
-
-    # fix adjacent boxes: a later box with a slightly smaller y must not be sorted ahead
-    for i in range(num_boxes - 1):
-        for j in range(i, -1, -1):
-            if (
-                abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
-                and _boxes[j + 1][0][0] < _boxes[j][0][0]
-            ):
-                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
-            else:
-                break
-    return np.array(_boxes)
 
 
 
 
 def calculate_iou(
@@ -63,6 +36,7 @@ def calculate_iou(
     return iou
 
 
 
 
+
 def is_box_contained(
     box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2
 ) -> Union[int, None]:
@@ -111,7 +85,7 @@ def is_single_axis_contained(
     box1: Union[np.ndarray, List],
     box2: Union[np.ndarray, List],
     axis="x",
     threshold: float = 0.2,
 ) -> Union[int, None]:
     """
     :param box1: Iterable [xmin,ymin,xmax,ymax]
@@ -136,15 +110,15 @@ def is_single_axis_contained(
 
 
     ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
     ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
     if ratio_b1 < threshold:
         return 1
     if ratio_b2 < threshold:
         return 2
     return None
 
 
 
 
 def sorted_ocr_boxes(
     dt_boxes: Union[np.ndarray, list], threshold: float = 0.2
 ) -> Tuple[Union[np.ndarray, list], List[int]]:
     """
     Sort text boxes in order from top to bottom, left to right
@@ -161,18 +135,19 @@ def sorted_ocr_boxes(
     _boxes, indices = zip(*sorted_boxes_with_idx)
     indices = list(indices)
     _boxes = [dt_boxes[i] for i in indices]
+    y_offset_px = 20  # max vertical offset in px for same-row swaps
     # keep the output type consistent with the input, as the contract promises
     if isinstance(dt_boxes, np.ndarray):
         _boxes = np.array(_boxes)
     for i in range(num_boxes - 1):
         for j in range(i, -1, -1):
             c_idx = is_single_axis_contained(
                 _boxes[j], _boxes[j + 1], axis="y", threshold=threshold
             )
             if (
                 c_idx is not None
                 and _boxes[j + 1][0] < _boxes[j][0]
-                and abs(_boxes[j][1] - _boxes[j + 1][1]) < 20
+                and abs(_boxes[j][1] - _boxes[j + 1][1]) < y_offset_px
             ):
                 _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy()
                 indices[j], indices[j + 1] = indices[j + 1], indices[j]
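A tiny worked example of the same-row swap this loop performs (boxes invented for illustration; flat [xmin, ymin, xmax, ymax] format):

    import numpy as np

    boxes = np.array([
        [120.0, 51.0, 200.0, 70.0],  # right-hand box, 1 px higher, sorts first
        [10.0, 52.0, 100.0, 71.0],   # left-hand box on the same row
    ])
    ordered, idx = sorted_ocr_boxes(boxes)
    # the y-ranges overlap and the offset is under 20 px, so the boxes swap:
    # ordered is [left, right] and idx == [1, 0]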
@@ -181,6 +156,11 @@ def sorted_ocr_boxes(
     return _boxes, indices
 
 
 
 
+def box_4_1_poly_to_box_4_2(poly_box: Union[list, np.ndarray]) -> List[List[float]]:
+    xmin, ymin, xmax, ymax = tuple(poly_box)
+    return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+
+
 def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]:
     """
     Convert poly_box (4x2 corner points) to box_4_1 ([xmin, ymin, xmax, ymax])
@@ -221,18 +201,12 @@ def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.nd
     return matched, not_match_orc_boxes
 
 
 
 
 def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[Any]:
     """
-        Groups OCR results by row based on the vertical (y-axis) overlap of their bounding boxes.
-    Args:
-        ocr_list (List[Any]): A list of OCR results, where each item is a list containing a bounding box
-            in the format [xmin, ymin, xmax, ymax] and the recognized text.
-        threshold (float, optional): The threshold for determining if two boxes are in the same row,
-            based on their y-axis overlap. Default is 0.2.
-    Returns:
-        List[Any]: A new list of OCR results where texts in the same row are merged, and their bounding
-            boxes are updated to encompass the merged text.
+    :param ocr_list: [[[xmin,ymin,xmax,ymax], text]]
+    :return:
     """
     """
+    space_px = 10  # one space per 10 px of horizontal gap between fragments
     for i in range(len(ocr_list)):
         if not ocr_list[i]:
             continue
@@ -245,11 +219,11 @@ def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[
             cur_box = cur[0]
             next_box = next[0]
             c_idx = is_single_axis_contained(
                 cur[0], next[0], axis="y", threshold=threshold
             )
             if c_idx:
                 dis = max(next_box[0] - cur_box[2], 0)
-                blank_str = int(dis / 10) * " "
+                blank_str = int(dis / space_px) * " "
                 cur[1] = cur[1] + blank_str + next[1]
                 xmin = min(cur_box[0], next_box[0])
                 xmax = max(cur_box[2], next_box[2])
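For intuition, a sketch of the row merge (fragments invented; the elided tail of the function drops the absorbed entry and re-filters the list):

    ocr = [
        [[10, 100, 60, 120], "Name"],
        [[110, 102, 180, 122], "Value"],  # same row, 50 px gap -> 5 spaces
    ]
    merged = gather_ocr_list_by_row(ocr)
    # the first entry becomes [[10, 100, 180, 122], "Name     Value"]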
@@ -264,93 +238,6 @@ def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[
     return ocr_list
 
 
 
 
-def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float:
-    """计算两个多边形的IOU
-
-    Args:
-        poly1 (np.ndarray): (4, 2)
-        poly2 (np.ndarray): (4, 2)
-
-    Returns:
-        float: iou
-    """
-    poly1 = Polygon(a).convex_hull
-    poly2 = Polygon(b).convex_hull
-
-    union_poly = np.concatenate((a, b))
-
-    if not poly1.intersects(poly2):
-        return 0.0
-
-    try:
-        inter_area = poly1.intersection(poly2).area
-        union_area = MultiPoint(union_poly).convex_hull.area
-    except shapely.geos.TopologicalError:
-        print("shapely.geos.TopologicalError occured, iou set to 0")
-        return 0.0
-
-    if union_area == 0:
-        return 0.0
-
-    return float(inter_area) / union_area
-
-
-def merge_adjacent_polys(polygons: np.ndarray) -> np.ndarray:
-    """合并相邻iou大于阈值的框"""
-    combine_iou_thresh = 0.1
-    pair_polygons = list(zip(polygons, polygons[1:, ...]))
-    pair_ious = np.array([compute_poly_iou(p1, p2) for p1, p2 in pair_polygons])
-    idxs = np.argwhere(pair_ious >= combine_iou_thresh)
-
-    if idxs.size <= 0:
-        return polygons
-
-    polygons = combine_two_poly(polygons, idxs)
-
-    # note: recursive call
-    polygons = merge_adjacent_polys(polygons)
-    return polygons
-
-
-def combine_two_poly(polygons: np.ndarray, idxs: np.ndarray) -> np.ndarray:
-    del_idxs, insert_boxes = [], []
-    idxs = idxs.squeeze(-1)
-    for idx in idxs:
-        # idx and idx + 1 overlap too much
-        # merge them, taking the extremes of each corner
-        new_poly = []
-        pre_poly, pos_poly = polygons[idx], polygons[idx + 1]
-
-        # compare the four corner points one by one
-        new_poly.append(np.minimum(pre_poly[0], pos_poly[0]))
-
-        x_2 = min(pre_poly[1][0], pos_poly[1][0])
-        y_2 = max(pre_poly[1][1], pos_poly[1][1])
-        new_poly.append([x_2, y_2])
-
-        # 3rd point
-        new_poly.append(np.maximum(pre_poly[2], pos_poly[2]))
-
-        # 4th point
-        x_4 = max(pre_poly[3][0], pos_poly[3][0])
-        y_4 = min(pre_poly[3][1], pos_poly[3][1])
-        new_poly.append([x_4, y_4])
-
-        new_poly = np.array(new_poly)
-
-        # drop the two merged boxes and insert the new one
-        del_idxs.extend([idx, idx + 1])
-        insert_boxes.append(new_poly)
-
-    # consolidate the merged boxes
-    polygons = np.delete(polygons, del_idxs, axis=0)
-
-    insert_boxes = np.array(insert_boxes)
-    polygons = np.append(polygons, insert_boxes, axis=0)
-    polygons = sorted_boxes(polygons)
-    return polygons
-
-
 def plot_html_table(
     logi_points: Union[Union[np.ndarray, List]], cell_box_map: Dict[int, List[str]]
 ) -> str:
@@ -405,7 +292,8 @@ def plot_html_table(
                     continue
                 if row == row_start and col == col_start:
                     ocr_rec_text = cell_box_map.get(i)
-                    text = "<br>".join(ocr_rec_text)
+                    text = "".join(ocr_rec_text)
                     # if this is the start cell of a merged span
                     row_span = row_end - row_start + 1
                     col_span = col_end - col_start + 1
@@ -418,3 +306,6 @@ def plot_html_table(
 
 
     table_html += "</table></body></html>"
     return table_html

+ 25 - 7
mineru/utils/model_utils.py

@@ -324,7 +324,7 @@ def remove_overlaps_low_confidence_blocks(combined_res_list, overlap_threshold=0
                 marked_indices.add(i)  # mark the current index as handled
     return blocks_to_remove
 
 
-
+# TODO: this function needs a refactor
 def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
     """Extract OCR, table and other regions from layout results."""
     ocr_res_list = []
@@ -358,14 +358,31 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
     filtered_table_res_list = filter_nested_tables(
         table_res_list, overlap_threshold, area_threshold)
 
 
+    for table_res in filtered_table_res_list:
+        table_res['bbox'] = [int(table_res['poly'][0]), int(table_res['poly'][1]), int(table_res['poly'][4]), int(table_res['poly'][5])]
+
+    filtered_table_res_list, table_need_remove = remove_overlaps_min_blocks(filtered_table_res_list)
+
+    for res in filtered_table_res_list:
+        # rebuild the poly of res from its bbox
+        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
+                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
+        # drop the temporary bbox from res
+        del res['bbox']
+
+    if len(table_need_remove) > 0:
+        for res in table_need_remove:
+            del res['bbox']
+            if res in layout_res:
+                layout_res.remove(res)
+
     # Remove filtered out tables from layout_res
     if len(filtered_table_res_list) < len(table_res_list):
         kept_tables = set(id(table) for table in filtered_table_res_list)
-        to_remove = [table_indices[i] for i, table in enumerate(table_res_list)
-                     if id(table) not in kept_tables]
-
-        for idx in sorted(to_remove, reverse=True):
-            del layout_res[idx]
+        tables_to_remove = [table for table in table_res_list if id(table) not in kept_tables]
+        for table in tables_to_remove:
+            if table in layout_res:
+                layout_res.remove(table)
 
 
     # Remove overlaps in OCR and text regions
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
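The bbox juggling above relies on a fixed corner convention for poly: eight values, four (x, y) corners clockwise from the top-left, so indices 0/1 and 4/5 hold the two opposite corners. The two conversions used inline, as standalone sketches (helper names are mine):

    def poly_to_bbox(poly):
        # [x0, y0, x1, y1, x2, y2, x3, y3] -> [xmin, ymin, xmax, ymax]
        return [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])]

    def bbox_to_poly(bbox):
        xmin, ymin, xmax, ymax = bbox
        return [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]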
@@ -381,7 +398,8 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
     if len(need_remove) > 0:
         for res in need_remove:
             del res['bbox']
-            layout_res.remove(res)
+            if res in layout_res:
+                layout_res.remove(res)
 
 
     # check whether a large block contains multiple small blocks; merge the ocr and table lists for this check
     combined_res_list = ocr_res_list + filtered_table_res_list
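The new "if res in layout_res" guards are load-bearing: these results are plain dicts, and list.remove matches by equality, raising ValueError when no equal element is present, for example:

    items = [{"poly": [0, 0, 1, 0, 1, 1, 0, 1]}]
    items.remove({"poly": []})  # ValueError: list.remove(x): x not in list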

+ 1 - 18
pyproject.toml

@@ -68,7 +68,7 @@ pipeline = [
     "shapely>=2.0.7,<3",
     "shapely>=2.0.7,<3",
     "pyclipper>=1.3.0,<2",
     "pyclipper>=1.3.0,<2",
     "omegaconf>=2.3.0,<3",
     "omegaconf>=2.3.0,<3",
-    "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
+    "torch>=2.6.0,<3",
     "torchvision",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
 ]
 ]
@@ -91,23 +91,6 @@ all = [
     "mineru[core]",
     "mineru[core]",
     "mineru[sglang]",
     "mineru[sglang]",
 ]
 ]
-pipeline_old_linux = [
-    "matplotlib>=3.10,<=3.10.1",
-    "ultralytics>=8.3.48,<=8.3.104",
-    "doclayout_yolo==0.0.4",
-    "dill==0.3.8",
-    "PyYAML==6.0.2",
-    "ftfy==6.3.1",
-    "openai==1.71.0",
-    "shapely==2.1.0",
-    "pyclipper==1.3.0.post6",
-    "omegaconf==2.3.0",
-    "albumentations==1.4.20",
-    "rapid_table==1.0.3",
-    "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
-    "torchvision",
-    "transformers>=4.49.0,!=4.51.0,<5.0.0",
-]
 
 
 [project.urls]
 homepage = "https://mineru.net/"