
Merge pull request #3242 from myhloli/dev

feat: Improve the parsing accuracy of wired tables
Xiaomeng Zhao 3 months ago
parent
commit
8b6d217efe

+ 40 - 5
mineru/backend/pipeline/batch_analyze.py

@@ -5,6 +5,7 @@ from collections import defaultdict
 import numpy as np
 
 from .model_init import AtomModelSingleton
+from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
@@ -133,7 +134,7 @@ class BatchAnalyze:
 
                 # Get the OCR model
                 ocr_model = atom_model_manager.get_atom_model(
-                    atom_model_name='ocr',
+                    atom_model_name=AtomicModel.OCR,
                     det_db_box_thresh=0.3,
                     lang=lang
                 )
@@ -219,7 +220,7 @@ class BatchAnalyze:
                 _lang = ocr_res_list_dict['lang']
                 # Get OCR results for this language's images
                 ocr_model = atom_model_manager.get_atom_model(
-                    atom_model_name='ocr',
+                    atom_model_name=AtomicModel.OCR,
                     ocr_show_log=False,
                     det_db_box_thresh=0.3,
                     lang=_lang
@@ -249,11 +250,45 @@ class BatchAnalyze:
         if self.table_enable:
             for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
                 _lang = table_res_dict['lang']
+
+                # Classify the table as wired or wireless
+                table_cls_model = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.TableCls,
+                )
+                try:
+                    table_label = table_cls_model.predict(table_res_dict["table_img"])
+                except Exception as e:
+                    table_label = AtomicModel.WirelessTable
+                    logger.warning(
+                        f"Table classification failed: {e}, using default model {table_label}"
+                    )
+                # logger.debug(f"Table classification result: {table_label}")
+                if table_label not in [AtomicModel.WirelessTable, AtomicModel.WiredTable]:
+                    raise ValueError(
+                        "Table classification failed, please check the model"
+                    )
+
+                # Choose the wired or wireless table recognition model based on the classification result
                 table_model = atom_model_manager.get_atom_model(
-                    atom_model_name='table',
+                    atom_model_name=table_label,
                     lang=_lang,
                 )
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
+
+                # Correct the image orientation
+                img_orientation_cls_model = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.ImgOrientationCls,
+                )
+                try:
+                    table_img = img_orientation_cls_model.predict(
+                        table_res_dict["table_img"]
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"Image orientation classification failed: {e}, using original image"
+                    )
+                    table_img = table_res_dict["table_img"]
+
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img)
                 # Check whether the model returned a valid result
                 if html_code:
                     # Check that html_code contains '<table>' and '</table>'
@@ -305,7 +340,7 @@ class BatchAnalyze:
                     # Get OCR results for this language's images
 
                     ocr_model = atom_model_manager.get_atom_model(
-                        atom_model_name='ocr',
+                        atom_model_name=AtomicModel.OCR,
                         det_db_box_thresh=0.3,
                         lang=lang
                     )

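Note on the batch_analyze.py hunk: each table crop now passes through a three-step flow: classify the table as wired or wireless, normalize its orientation, then dispatch to the matching recognizer. A minimal sketch of that flow, assuming the AtomicModel constants and atom_model_manager defined in this PR (the wrapper function itself is hypothetical):

    def recognize_table(table_img, lang, atom_model_manager):
        # 1) Wired/wireless classification, falling back to wireless on failure
        cls_model = atom_model_manager.get_atom_model(atom_model_name=AtomicModel.TableCls)
        try:
            label = cls_model.predict(table_img)
        except Exception:
            label = AtomicModel.WirelessTable
        # 2) Orientation correction, keeping the original image on failure
        ori_model = atom_model_manager.get_atom_model(atom_model_name=AtomicModel.ImgOrientationCls)
        try:
            table_img = ori_model.predict(table_img)
        except Exception:
            pass
        # 3) Dispatch: wired -> UnetTableModel, wireless -> RapidTableModel
        table_model = atom_model_manager.get_atom_model(atom_model_name=label, lang=lang)
        return table_model.predict(table_img)
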
+ 52 - 14
mineru/backend/pipeline/model_init.py

@@ -8,18 +8,40 @@ from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
-from ...model.table.rapid_table import RapidTableModel
+from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
+from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
+from ...model.table.rec.rapid_table import RapidTableModel
+from ...model.table.rec.unet_table.unet_table import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 
-def table_model_init(lang=None):
+def img_orientation_cls_model_init():
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
-        atom_model_name='ocr',
-        det_db_box_thresh=0.5,
-        det_db_unclip_ratio=1.6,
-        lang=lang
+        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang="ch_lite"
+    )
+    cls_model = PaddleOrientationClsModel(ocr_engine)
+    return cls_model
+
+
+def table_cls_model_init():
+    return PaddleTableClsModel()
+
+
+def wired_table_model_init(lang=None):
+    atom_model_manager = AtomModelSingleton()
+    ocr_engine = atom_model_manager.get_atom_model(
+        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
+    )
+    table_model = UnetTableModel(ocr_engine)
+    return table_model
+
+
+def wireless_table_model_init(lang=None):
+    atom_model_manager = AtomModelSingleton()
+    ocr_engine = atom_model_manager.get_atom_model(
+        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
     )
     table_model = RapidTableModel(ocr_engine)
     return table_model
@@ -76,12 +98,9 @@ class AtomModelSingleton:
     def get_atom_model(self, atom_model_name: str, **kwargs):
 
         lang = kwargs.get('lang', None)
-        table_model_name = kwargs.get('table_model_name', None)
 
-        if atom_model_name in [AtomicModel.OCR]:
+        if atom_model_name in [AtomicModel.OCR, AtomicModel.WiredTable, AtomicModel.WirelessTable]:
             key = (atom_model_name, lang)
-        elif atom_model_name in [AtomicModel.Table]:
-            key = (atom_model_name, table_model_name, lang)
         else:
             key = atom_model_name
 
@@ -111,10 +130,18 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('det_db_box_thresh'),
             kwargs.get('lang'),
         )
-    elif model_name == AtomicModel.Table:
-        atom_model = table_model_init(
+    elif model_name == AtomicModel.WirelessTable:
+        atom_model = wireless_table_model_init(
+            kwargs.get('lang'),
+        )
+    elif model_name == AtomicModel.WiredTable:
+        atom_model = wired_table_model_init(
             kwargs.get('lang'),
         )
+    elif model_name == AtomicModel.TableCls:
+        atom_model = table_cls_model_init()
+    elif model_name == AtomicModel.ImgOrientationCls:
+        atom_model = img_orientation_cls_model_init()
     else:
         logger.error('model name not allow')
         exit(1)
@@ -174,8 +201,19 @@ class MineruPipelineModel:
         )
         # init table model
         if self.apply_table:
-            self.table_model = atom_model_manager.get_atom_model(
-                atom_model_name=AtomicModel.Table,
+            self.wired_table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.WiredTable,
+                lang=self.lang,
+            )
+            self.wireless_table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.WirelessTable,
+                lang=self.lang,
+            )
+            self.table_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.TableCls,
+            )
+            self.img_orientation_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.ImgOrientationCls,
                 lang=self.lang,
             )
 

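The updated cache keys in AtomModelSingleton.get_atom_model are worth noting: OCR and the two table recognizers are cached per (model name, lang), while the classifiers are cached once per model name. A short sketch of the intended behavior, assuming the classes in this diff (the language codes are hypothetical):

    manager = AtomModelSingleton()
    wired_en = manager.get_atom_model(atom_model_name=AtomicModel.WiredTable, lang='en')
    wired_ch = manager.get_atom_model(atom_model_name=AtomicModel.WiredTable, lang='ch')
    assert wired_en is not wired_ch  # language-specific cache key: (name, lang)
    cls_a = manager.get_atom_model(atom_model_name=AtomicModel.TableCls)
    cls_b = manager.get_atom_model(atom_model_name=AtomicModel.TableCls)
    assert cls_a is cls_b  # a single shared instance, keyed by name alone
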
+ 5 - 1
mineru/backend/pipeline/model_list.py

@@ -3,4 +3,8 @@ class AtomicModel:
     MFD = "mfd"
     MFR = "mfr"
     OCR = "ocr"
-    Table = "table"
+    WirelessTable = "wireless_table"
+    WiredTable = "wired_table"
+    TableCls = "table_cls"
+    ImgOrientationCls = "img_ori_cls"
+

+ 5 - 1
mineru/backend/pipeline/pipeline_analyze.py

@@ -190,7 +190,11 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
-    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
+    if str(device).startswith('mps'):
+        enable_ocr_det_batch = False
+    else:
+        enable_ocr_det_batch = True
+    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
     results = batch_model(images_with_extra_info)
 
     clean_memory(get_device())

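The new branch only disables batched OCR detection on Apple MPS devices; an equivalent one-liner with the same semantics would be:

    enable_ocr_det_batch = not str(device).startswith('mps')
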
+ 1 - 0
mineru/model/ori_cls/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 115 - 0
mineru/model/ori_cls/paddle_ori_cls.py

@@ -0,0 +1,115 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import os
+
+import cv2
+import numpy as np
+import onnxruntime
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+
+class PaddleOrientationClsModel:
+    def __init__(self, ocr_engine):
+        self.sess = onnxruntime.InferenceSession(
+            os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
+        )
+        self.ocr_engine = ocr_engine
+        self.less_length = 256
+        self.cw, self.ch = 224, 224
+        self.std = [0.229, 0.224, 0.225]
+        self.scale = 0.00392156862745098
+        self.mean = [0.485, 0.456, 0.406]
+        self.labels = ["0", "90", "180", "270"]
+
+    def preprocess(self, img):
+        # Convert the PIL image to a cv2/NumPy array
+        img = np.array(img)
+        # Resize so that the shorter side is 256
+        h, w = img.shape[:2]
+        scale = 256 / min(h, w)
+        h_resize = round(h * scale)
+        w_resize = round(w * scale)
+        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        # Center-crop to a 224x224 square
+        h, w = img.shape[:2]
+        cw, ch = 224, 224
+        x1 = max(0, (w - cw) // 2)
+        y1 = max(0, (h - ch) // 2)
+        x2 = min(w, x1 + cw)
+        y2 = min(h, y1 + ch)
+        if w < cw or h < ch:
+            raise ValueError(
+                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+            )
+        img = img[y1:y2, x1:x2, ...]
+        # Normalize per channel
+        split_im = list(cv2.split(img))
+        std = [0.229, 0.224, 0.225]
+        scale = 0.00392156862745098
+        mean = [0.485, 0.456, 0.406]
+        alpha = [scale / std[i] for i in range(len(std))]
+        beta = [-mean[i] / std[i] for i in range(len(std))]
+        for c in range(img.shape[2]):
+            split_im[c] = split_im[c].astype(np.float32)
+            split_im[c] *= alpha[c]
+            split_im[c] += beta[c]
+        img = cv2.merge(split_im)
+        # Convert to CHW format
+        img = img.transpose((2, 0, 1))
+        imgs = [img]
+        x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
+    def predict(self, img):
+        bgr_image = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+        # First check the overall image aspect ratio (height/width)
+        img_height, img_width = bgr_image.shape[:2]
+        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+        img_is_portrait = img_aspect_ratio > 1.2
+
+        if img_is_portrait:
+
+            det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
+            # Check if table is rotated by analyzing text box aspect ratios
+            if det_res:
+                vertical_count = 0
+                is_rotated = False
+
+                for box_ocr_res in det_res:
+                    p1, p2, p3, p4 = box_ocr_res
+
+                    # Calculate width and height
+                    width = p3[0] - p1[0]
+                    height = p3[1] - p1[1]
+
+                    aspect_ratio = width / height if height > 0 else 1.0
+
+                    # Count vertical vs horizontal text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1
+                    # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
+                    #     horizontal_count += 1
+
+                if vertical_count >= len(det_res) * 0.3:
+                    is_rotated = True
+                # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
+
+                # If we have more vertical text boxes than horizontal ones,
+                # and vertical ones are significant, table might be rotated
+                if is_rotated:
+                    x = self.preprocess(img)
+                    (result,) = self.sess.run(None, {"x": x})
+                    label = self.labels[np.argmax(result)]
+
+                    if label == "90":
+                        rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
+                        img = cv2.rotate(np.asarray(img), rotation)
+                    elif label == "180":
+                        rotation = cv2.ROTATE_180
+                        img = cv2.rotate(np.asarray(img), rotation)
+                    elif label == "270":
+                        rotation = cv2.ROTATE_90_CLOCKWISE
+                        img = cv2.rotate(np.asarray(img), rotation)
+                    else:
+                        img = np.array(img)
+        return img

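The normalization in preprocess folds the 1/255 scale and the mean/std whitening into one per-channel multiply-add: with alpha = scale/std and beta = -mean/std, x * alpha + beta equals (x * scale - mean) / std. A self-contained NumPy check of that identity:

    import numpy as np

    scale = 0.00392156862745098  # 1 / 255
    std = np.array([0.229, 0.224, 0.225])
    mean = np.array([0.485, 0.456, 0.406])
    alpha, beta = scale / std, -mean / std
    x = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.float32)
    # Fused multiply-add matches the standard normalization
    assert np.allclose(x * alpha + beta, (x * scale - mean) / std, atol=1e-5)
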
+ 1 - 0
mineru/model/table/cls/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 73 - 0
mineru/model/table/cls/paddle_table_cls.py

@@ -0,0 +1,73 @@
+import os
+
+import cv2
+import numpy as np
+import onnxruntime
+from loguru import logger
+
+from mineru.backend.pipeline.model_list import AtomicModel
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+
+class PaddleTableClsModel:
+    def __init__(self):
+        self.sess = onnxruntime.InferenceSession(
+            os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_table_cls), ModelPath.paddle_table_cls)
+        )
+        self.less_length = 256
+        self.cw, self.ch = 224, 224
+        self.std = [0.229, 0.224, 0.225]
+        self.scale = 0.00392156862745098
+        self.mean = [0.485, 0.456, 0.406]
+        self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]
+
+    def preprocess(self, img):
+        # Convert the PIL image to a cv2/NumPy array (the RGB->BGR->RGB round-trip leaves the data unchanged)
+        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        # Resize so that the shorter side is 256
+        h, w = img.shape[:2]
+        scale = 256 / min(h, w)
+        h_resize = round(h * scale)
+        w_resize = round(w * scale)
+        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        # Center-crop to a 224x224 square
+        h, w = img.shape[:2]
+        cw, ch = 224, 224
+        x1 = max(0, (w - cw) // 2)
+        y1 = max(0, (h - ch) // 2)
+        x2 = min(w, x1 + cw)
+        y2 = min(h, y1 + ch)
+        if w < cw or h < ch:
+            raise ValueError(
+                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+            )
+        img = img[y1:y2, x1:x2, ...]
+        # Normalize per channel
+        split_im = list(cv2.split(img))
+        std = [0.229, 0.224, 0.225]
+        scale = 0.00392156862745098
+        mean = [0.485, 0.456, 0.406]
+        alpha = [scale / std[i] for i in range(len(std))]
+        beta = [-mean[i] / std[i] for i in range(len(std))]
+        for c in range(img.shape[2]):
+            split_im[c] = split_im[c].astype(np.float32)
+            split_im[c] *= alpha[c]
+            split_im[c] += beta[c]
+        img = cv2.merge(split_im)
+        # Convert to CHW format
+        img = img.transpose((2, 0, 1))
+        imgs = [img]
+        x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
+    def predict(self, img):
+        x = self.preprocess(img)
+        result = self.sess.run(None, {"x": x})
+        idx = np.argmax(result)
+        conf = float(np.max(result))
+        # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
+        if idx == 0 and conf < 0.9:
+            idx = 1
+        return self.labels[idx]

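The final branch in predict biases low-confidence "wired" results toward the wireless recognizer, which degrades more gracefully on ambiguous tables. A toy restatement of that rule with made-up scores:

    import numpy as np

    labels = ["wired_table", "wireless_table"]

    def pick_label(scores):
        # Mirrors the fallback in PaddleTableClsModel.predict (illustrative only)
        idx, conf = int(np.argmax(scores)), float(np.max(scores))
        if idx == 0 and conf < 0.9:
            idx = 1  # demote low-confidence "wired" to "wireless"
        return labels[idx]

    print(pick_label([0.85, 0.15]))  # wireless_table (wired, but below 0.9)
    print(pick_label([0.97, 0.03]))  # wired_table
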
+ 0 - 89
mineru/model/table/rapid_table.py

@@ -1,89 +0,0 @@
-import os
-import html
-import cv2
-import numpy as np
-from loguru import logger
-from rapid_table import RapidTable, RapidTableInput
-
-from mineru.utils.enum_class import ModelPath
-from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-
-
-def escape_html(input_string):
-    """Escape HTML Entities."""
-    return html.escape(input_string)
-
-
-class RapidTableModel(object):
-    def __init__(self, ocr_engine):
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.table_model = RapidTable(input_args)
-        self.ocr_engine = ocr_engine
-
-
-    def predict(self, image):
-        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-
-        # First check the overall image aspect ratio (height/width)
-        img_height, img_width = bgr_image.shape[:2]
-        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
-        img_is_portrait = img_aspect_ratio > 1.2
-
-        if img_is_portrait:
-
-            det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
-            # Check if table is rotated by analyzing text box aspect ratios
-            is_rotated = False
-            if det_res:
-                vertical_count = 0
-
-                for box_ocr_res in det_res:
-                    p1, p2, p3, p4 = box_ocr_res
-
-                    # Calculate width and height
-                    width = p3[0] - p1[0]
-                    height = p3[1] - p1[1]
-
-                    aspect_ratio = width / height if height > 0 else 1.0
-
-                    # Count vertical vs horizontal text boxes
-                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
-                        vertical_count += 1
-                    # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
-                    #     horizontal_count += 1
-
-                # If we have more vertical text boxes than horizontal ones,
-                # and vertical ones are significant, table might be rotated
-                if vertical_count >= len(det_res) * 0.3:
-                    is_rotated = True
-
-                # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
-
-            # Rotate image if necessary
-            if is_rotated:
-                # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
-                image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
-                bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-
-        # Continue with OCR on potentially rotated image
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-        if ocr_result:
-            ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
-                      len(item) == 2 and isinstance(item[1], tuple)]
-        else:
-            ocr_result = None
-
-
-        if ocr_result:
-            try:
-                table_results = self.table_model(np.asarray(image), ocr_result)
-                html_code = table_results.pred_html
-                table_cell_bboxes = table_results.cell_bboxes
-                logic_points = table_results.logic_points
-                elapse = table_results.elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
-            except Exception as e:
-                logger.exception(e)
-
-        return None, None, None, None

+ 1 - 0
mineru/model/table/rec/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 46 - 0
mineru/model/table/rec/rapid_table.py

@@ -0,0 +1,46 @@
+import os
+import html
+import cv2
+import numpy as np
+from loguru import logger
+from rapid_table import RapidTable, RapidTableInput
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+
+def escape_html(input_string):
+    """Escape HTML Entities."""
+    return html.escape(input_string)
+
+
+class RapidTableModel(object):
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
+        input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
+        self.table_model = RapidTable(input_args)
+        self.ocr_engine = ocr_engine
+
+
+    def predict(self, image):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        # Run OCR on the input image (orientation is now handled upstream by PaddleOrientationClsModel)
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        if ocr_result:
+            ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
+                      len(item) == 2 and isinstance(item[1], tuple)]
+        else:
+            ocr_result = None
+
+        if ocr_result:
+            try:
+                table_results = self.table_model(np.asarray(image), ocr_result)
+                html_code = table_results.pred_html
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+
+        return None, None, None, None

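For reference, the list comprehension in predict converts the OCR engine's [box, (text, score)] items into the [box, text, score] triplets RapidTable expects, HTML-escaping the text along the way. A toy run with made-up data:

    import html

    raw = [[[[0, 0], [10, 0], [10, 10], [0, 10]], ("a<b", 0.99)]]
    converted = [[item[0], html.escape(item[1][0]), item[1][1]]
                 for item in raw if len(item) == 2 and isinstance(item[1], tuple)]
    print(converted)  # [[[[0, 0], [10, 0], [10, 10], [0, 10]], 'a&lt;b', 0.99]]
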
+ 1 - 0
mineru/model/table/rec/unet_table/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 285 - 0
mineru/model/table/rec/unet_table/table_line_rec_utils.py

@@ -0,0 +1,285 @@
+import math
+
+import cv2
+import numpy as np
+from scipy.spatial import distance as dist
+from skimage import measure
+
+
+def get_table_line(binimg, axis=0, lineW=10):
+    ## Extract table lines
+    ## axis=0: horizontal lines
+    ## axis=1: vertical lines
+    labels = measure.label(binimg > 0, connectivity=2)  # 8-connectivity region labeling
+    regions = measure.regionprops(labels)
+    if axis == 1:
+        lineboxes = [
+            min_area_rect(line.coords)
+            for line in regions
+            if line.bbox[2] - line.bbox[0] > lineW
+        ]
+    else:
+        lineboxes = [
+            min_area_rect(line.coords)
+            for line in regions
+            if line.bbox[3] - line.bbox[1] > lineW
+        ]
+    return lineboxes
+
+
+def min_area_rect(coords):
+    """
+    Minimum-area bounding rectangle of a polygon
+    """
+    rect = cv2.minAreaRect(coords[:, ::-1])
+    box = cv2.boxPoints(rect)
+    box = box.reshape((8,)).tolist()
+
+    box = image_location_sort_box(box)
+
+    x1, y1, x2, y2, x3, y3, x4, y4 = box
+    degree, w, h, cx, cy = calculate_center_rotate_angle(box)
+    if w < h:
+        xmin = (x1 + x2) / 2
+        xmax = (x3 + x4) / 2
+        ymin = (y1 + y2) / 2
+        ymax = (y3 + y4) / 2
+
+    else:
+        xmin = (x1 + x4) / 2
+        xmax = (x2 + x3) / 2
+        ymin = (y1 + y4) / 2
+        ymax = (y2 + y3) / 2
+    # degree,w,h,cx,cy = solve(box)
+    # x1,y1,x2,y2,x3,y3,x4,y4 = box
+    # return {'degree':degree,'w':w,'h':h,'cx':cx,'cy':cy}
+    return [xmin, ymin, xmax, ymax]
+
+
+def image_location_sort_box(box):
+    x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
+    pts = (x1, y1), (x2, y2), (x3, y3), (x4, y4)
+    pts = np.array(pts, dtype="float32")
+    (x1, y1), (x2, y2), (x3, y3), (x4, y4) = _order_points(pts)
+    return [x1, y1, x2, y2, x3, y3, x4, y4]
+
+
+def calculate_center_rotate_angle(box):
+    """
+    Coordinates of a w-by-h box rotated by angle around (cx, cy); this partially compensates for in-image skew, though relying on the model remains more reliable
+    x = cx-w/2
+    y = cy-h/2
+    x1-cx = -w/2*cos(angle) +h/2*sin(angle)
+    y1 -cy= -w/2*sin(angle) -h/2*cos(angle)
+
+    h(x1-cx) = -wh/2*cos(angle) +hh/2*sin(angle)
+    w(y1 -cy)= -ww/2*sin(angle) -hw/2*cos(angle)
+    (hh+ww)/2sin(angle) = h(x1-cx)-w(y1 -cy)
+
+    """
+    x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
+    cx = (x1 + x3 + x2 + x4) / 4.0
+    cy = (y1 + y3 + y4 + y2) / 4.0
+    w = (
+        np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+        + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2)
+    ) / 2
+    h = (
+        np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
+        + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2)
+    ) / 2
+    # x = cx-w/2
+    # y = cy-h/2
+    sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2
+    angle = np.arcsin(sinA)
+    return angle, w, h, cx, cy
+
+
+def _order_points(pts):
+    # Sort the points by x coordinate
+    """
+    ---------------------
+    In this project, sorting yields [(xmin,ymin),(xmax,ymin),(xmax,ymax),(xmin,ymax)]
+    Author: Tong_T
+    Source: CSDN
+    Original: https://blog.csdn.net/Tong_T/article/details/81907132
+    Copyright: original article by the author; please include the link when reposting!
+    """
+    x_sorted = pts[np.argsort(pts[:, 0]), :]
+
+    left_most = x_sorted[:2, :]
+    right_most = x_sorted[2:, :]
+    left_most = left_most[np.argsort(left_most[:, 1]), :]
+    (tl, bl) = left_most
+
+    distance = dist.cdist(tl[np.newaxis], right_most, "euclidean")[0]
+    (br, tr) = right_most[np.argsort(distance)[::-1], :]
+
+    return np.array([tl, tr, br, bl], dtype="float32")
+
+
+def sqrt(p1, p2):
+    return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
+
+
+def adjust_lines(lines, alph=50, angle=50):
+    lines_n = len(lines)
+    new_lines = []
+    for i in range(lines_n):
+        x1, y1, x2, y2 = lines[i]
+        cx1, cy1 = (x1 + x2) / 2, (y1 + y2) / 2
+        for j in range(lines_n):
+            if i != j:
+                x3, y3, x4, y4 = lines[j]
+                cx2, cy2 = (x3 + x4) / 2, (y3 + y4) / 2
+                if (x3 < cx1 < x4 or y3 < cy1 < y4) or (
+                    x1 < cx2 < x2 or y1 < cy2 < y2
+                ):  # Check whether the projections of the two lines overlap
+                    continue
+                else:
+                    r = sqrt((x1, y1), (x3, y3))
+                    k = abs((y3 - y1) / (x3 - x1 + 1e-10))
+                    a = math.atan(k) * 180 / math.pi
+                    if r < alph and a < angle:
+                        new_lines.append((x1, y1, x3, y3))
+
+                    r = sqrt((x1, y1), (x4, y4))
+                    k = abs((y4 - y1) / (x4 - x1 + 1e-10))
+                    a = math.atan(k) * 180 / math.pi
+                    if r < alph and a < angle:
+                        new_lines.append((x1, y1, x4, y4))
+
+                    r = sqrt((x2, y2), (x3, y3))
+                    k = abs((y3 - y2) / (x3 - x2 + 1e-10))
+                    a = math.atan(k) * 180 / math.pi
+                    if r < alph and a < angle:
+                        new_lines.append((x2, y2, x3, y3))
+                    r = sqrt((x2, y2), (x4, y4))
+                    k = abs((y4 - y2) / (x4 - x2 + 1e-10))
+                    a = math.atan(k) * 180 / math.pi
+                    if r < alph and a < angle:
+                        new_lines.append((x2, y2, x4, y4))
+    return new_lines
+
+
+def final_adjust_lines(rowboxes, colboxes):
+    nrow = len(rowboxes)
+    ncol = len(colboxes)
+    for i in range(nrow):
+        for j in range(ncol):
+            rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], alpha=20, angle=30)
+            colboxes[j] = line_to_line(colboxes[j], rowboxes[i], alpha=20, angle=30)
+    return rowboxes, colboxes
+
+
+def draw_lines(im, bboxes, color=(0, 0, 0), lineW=3):
+    """
+    boxes: bounding boxes
+    """
+    tmp = np.copy(im)
+    c = color
+    h, w = im.shape[:2]
+
+    for box in bboxes:
+        x1, y1, x2, y2 = box[:4]
+        cv2.line(
+            tmp, (int(x1), int(y1)), (int(x2), int(y2)), c, lineW, lineType=cv2.LINE_AA
+        )
+
+    return tmp
+
+
+def line_to_line(points1, points2, alpha=10, angle=30):
+    """
+    Distance between two line segments
+    """
+    x1, y1, x2, y2 = points1
+    ox1, oy1, ox2, oy2 = points2
+    xy = np.array([(x1, y1), (x2, y2)], dtype="float32")
+    A1, B1, C1 = fit_line(xy)
+    oxy = np.array([(ox1, oy1), (ox2, oy2)], dtype="float32")
+    A2, B2, C2 = fit_line(oxy)
+    flag1 = point_line_cor(np.array([x1, y1], dtype="float32"), A2, B2, C2)
+    flag2 = point_line_cor(np.array([x2, y2], dtype="float32"), A2, B2, C2)
+
+    if (flag1 > 0 and flag2 > 0) or (flag1 < 0 and flag2 < 0):  # both endpoints of one segment lie on the same side of the other
+        if (A1 * B2 - A2 * B1) != 0:
+            x = (B1 * C2 - B2 * C1) / (A1 * B2 - A2 * B1)
+            y = (A2 * C1 - A1 * C2) / (A1 * B2 - A2 * B1)
+            # x, y = round(x, 2), round(y, 2)
+            p = (x, y)  # intersection of the horizontal and vertical lines
+            r0 = sqrt(p, (x1, y1))
+            r1 = sqrt(p, (x2, y2))
+
+            if min(r0, r1) < alpha:  # if the intersection is within alpha of an endpoint, extend the segment to it
+                if r0 < r1:
+                    k = abs((y2 - p[1]) / (x2 - p[0] + 1e-10))
+                    a = math.atan(k) * 180 / math.pi
+                    if a < angle or abs(90 - a) < angle:
+                        points1 = np.array([p[0], p[1], x2, y2], dtype="float32")
+                else:
+                    k = abs((y1 - p[1]) / (x1 - p[0] + 1e-10))
+                    a = math.atan(k) * 180 / math.pi
+                    if a < angle or abs(90 - a) < angle:
+                        points1 = np.array([x1, y1, p[0], p[1]], dtype="float32")
+    return points1
+
+
+def min_area_rect_box(
+    regions, flag=True, W=0, H=0, filtersmall=False, adjust_box=False
+):
+    """
+    Minimum-area bounding rectangles of polygons
+    """
+    boxes = []
+    for region in regions:
+        if region.bbox_area > H * W * 3 / 4:  # filter out overly large cells
+            continue
+        rect = cv2.minAreaRect(region.coords[:, ::-1])
+
+        box = cv2.boxPoints(rect)
+        box = box.reshape((8,)).tolist()
+        box = image_location_sort_box(box)
+        x1, y1, x2, y2, x3, y3, x4, y4 = box
+        angle, w, h, cx, cy = calculate_center_rotate_angle(box)
+        # if adjustBox:
+        #     x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w + 5, h + 5, angle=0, degree=None)
+        #     x1, x4 = max(x1, 0), max(x4, 0)
+        #     y1, y2 = max(y1, 0), max(y2, 0)
+
+        # if w > 32 and h > 32 and flag:
+        #     if abs(angle / np.pi * 180) < 20:
+        #         if filtersmall and (w < 10 or h < 10):
+        #             continue
+        #         boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
+        # else:
+        if w * h < 0.5 * W * H:
+            if filtersmall and (
+                w < 15 or h < 15
+            ):  # or w / h > 30 or h / w > 30): # filter out small cells
+                continue
+            boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
+    return boxes
+
+
+def point_line_cor(p, A, B, C):
+    ## Determine which side of a line a point lies on
+    # General-form line equation: Ax + By + C = 0
+    x, y = p
+    r = A * x + B * y + C
+    return r
+
+
+def fit_line(p):
+    """A = Y2 - Y1
+       B = X1 - X2
+       C = X2*Y1 - X1*Y2
+       AX+BY+C=0
+    General form of the line equation
+    """
+    x1, y1 = p[0]
+    x2, y2 = p[1]
+    A = y2 - y1
+    B = x1 - x2
+    C = x2 * y1 - x1 * y2
+    return A, B, C

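fit_line builds the general-form coefficients A, B, C of the line through two points, and point_line_cor uses the sign of Ax + By + C to tell which side of that line a point falls on. A self-contained restatement with a quick sanity check (illustrative only):

    def fit_line_2pt(p1, p2):
        (x1, y1), (x2, y2) = p1, p2
        # A = y2 - y1, B = x1 - x2, C = x2*y1 - x1*y2, giving Ax + By + C = 0
        return y2 - y1, x1 - x2, x2 * y1 - x1 * y2

    A, B, C = fit_line_2pt((0, 0), (10, 0))  # the x-axis
    assert A * 5 + B * 0 + C == 0            # a point on the line evaluates to 0
    assert (A * 5 + B * 3 + C) * (A * 5 + B * (-3) + C) < 0  # opposite sides flip the sign
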
+ 213 - 0
mineru/model/table/rec/unet_table/table_recover.py

@@ -0,0 +1,213 @@
+from typing import Dict, List, Tuple
+import numpy as np
+
+
+class TableRecover:
+    def __init__(
+        self,
+    ):
+        pass
+
+    def __call__(
+        self, polygons: np.ndarray, rows_thresh=10, col_thresh=15
+    ) -> Dict[int, Dict]:
+        rows = self.get_rows(polygons, rows_thresh)
+        longest_col, each_col_widths, col_nums = self.get_benchmark_cols(
+            rows, polygons, col_thresh
+        )
+        each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)
+        table_res, logic_points_dict = self.get_merge_cells(
+            polygons,
+            rows,
+            row_nums,
+            col_nums,
+            longest_col,
+            each_col_widths,
+            each_row_heights,
+        )
+        logic_points = np.array(
+            [logic_points_dict[i] for i in range(len(polygons))]
+        ).astype(np.int32)
+        return table_res, logic_points
+
+    @staticmethod
+    def get_rows(polygons: np.ndarray, rows_thresh=10) -> Dict[int, List[int]]:
+        """Group boxes into rows, deciding which boxes belong to the same row"""
+        y_axis = polygons[:, 0, 1]
+        if y_axis.size == 1:
+            return {0: [0]}
+
+        concat_y = np.array(list(zip(y_axis, y_axis[1:])))
+        minus_res = concat_y[:, 1] - concat_y[:, 0]
+
+        result = {}
+        split_idxs = np.argwhere(abs(minus_res) > rows_thresh).squeeze()
+        # If everything is on one row, assign all indices to that row
+        if split_idxs.size == 0:
+            return {0: [i for i in range(len(y_axis))]}
+        if split_idxs.ndim == 0:
+            split_idxs = split_idxs[None, ...]
+
+        if max(split_idxs) != len(minus_res):
+            split_idxs = np.append(split_idxs, len(minus_res))
+
+        start_idx = 0
+        for row_num, idx in enumerate(split_idxs):
+            if row_num != 0:
+                start_idx = split_idxs[row_num - 1] + 1
+            result.setdefault(row_num, []).extend(range(start_idx, idx + 1))
+
+        # Compute the IoU of adjacent cells in each row; merge cells whose IoU exceeds 0.2
+        return result
+
+    def get_benchmark_cols(
+        self, rows: Dict[int, List], polygons: np.ndarray, col_thresh=15
+    ) -> Tuple[np.ndarray, List[float], int]:
+        longest_col = max(rows.values(), key=lambda x: len(x))
+        longest_col_points = polygons[longest_col]
+        longest_x_start = list(longest_col_points[:, 0, 0])
+        longest_x_end = list(longest_col_points[:, 2, 0])
+        min_x = longest_x_start[0]
+        max_x = longest_x_end[-1]
+
+        # Update the column boundaries from the current column's starting x coordinate
+        # 2025.2.22 --- fix the benchmark column possibly missing the last column
+        def update_longest_col(col_x_list, cur_v, min_x_, max_x_, insert_last):
+            for i, v in enumerate(col_x_list):
+                if cur_v - col_thresh <= v <= cur_v + col_thresh:
+                    break
+                if cur_v < min_x_:
+                    col_x_list.insert(0, cur_v)
+                    min_x_ = cur_v
+                    break
+                if cur_v > max_x_:
+                    if insert_last:
+                        col_x_list.append(cur_v)
+                    max_x_ = cur_v
+                    break
+                if cur_v < v:
+                    col_x_list.insert(i, cur_v)
+                    break
+            return min_x_, max_x_
+
+        for row_value in rows.values():
+            cur_row_start = list(polygons[row_value][:, 0, 0])
+            cur_row_end = list(polygons[row_value][:, 2, 0])
+            for idx, (cur_v_start, cur_v_end) in enumerate(
+                zip(cur_row_start, cur_row_end)
+            ):
+                min_x, max_x = update_longest_col(
+                    longest_x_start, cur_v_start, min_x, max_x, True
+                )
+                min_x, max_x = update_longest_col(
+                    longest_x_start, cur_v_end, min_x, max_x, False
+                )
+
+        longest_x_start = np.array(longest_x_start)
+        each_col_widths = (longest_x_start[1:] - longest_x_start[:-1]).tolist()
+        each_col_widths.append(max_x - longest_x_start[-1])
+        col_nums = longest_x_start.shape[0]
+        return longest_x_start, each_col_widths, col_nums
+
+    def get_benchmark_rows(
+        self, rows: Dict[int, List], polygons: np.ndarray
+    ) -> Tuple[List[float], int]:
+        leftmost_cell_idxs = [v[0] for v in rows.values()]
+        benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1]
+
+        each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist()
+
+        # Use the maximum cell height in the last row as that row's height
+        bottommost_idxs = list(rows.values())[-1]
+        bottommost_boxes = polygons[bottommost_idxs]
+        # fix self.compute_L2(v[3, :], v[0, :]); v is counterclockwise: v[3] top-right, v[0] top-left, v[1] bottom-left
+        max_height = max([self.compute_L2(v[1, :], v[0, :]) for v in bottommost_boxes])
+        each_row_widths.append(max_height)
+
+        row_nums = benchmark_x.shape[0]
+        return each_row_widths, row_nums
+
+    @staticmethod
+    def compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:
+        return np.linalg.norm(a2 - a1)
+
+    def get_merge_cells(
+        self,
+        polygons: np.ndarray,
+        rows: Dict,
+        row_nums: int,
+        col_nums: int,
+        longest_col: np.ndarray,
+        each_col_widths: List[float],
+        each_row_heights: List[float],
+    ) -> Dict[int, Dict[int, int]]:
+        col_res_merge, row_res_merge = {}, {}
+        logic_points = {}
+        merge_thresh = 10
+        for cur_row, col_list in rows.items():
+            one_col_result, one_row_result = {}, {}
+            for one_col in col_list:
+                box = polygons[one_col]
+                box_width = self.compute_L2(box[3, :], box[0, :])
+
+                # Does not necessarily start at 0; derive the start from accumulated spans and the x coordinate
+                loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0]))
+                col_start = max(sum(one_col_result.values()), loc_col_idx)
+
+                # Work out how many cells to merge in the column direction
+                for i in range(col_start, col_nums):
+                    col_cum_sum = sum(each_col_widths[col_start : i + 1])
+                    if i == col_start and col_cum_sum > box_width:
+                        one_col_result[one_col] = 1
+                        break
+                    elif abs(col_cum_sum - box_width) <= merge_thresh:
+                        one_col_result[one_col] = i + 1 - col_start
+                        break
+                    # A correction is required here, otherwise columns interleave once past the threshold
+                    elif col_cum_sum > box_width:
+                        idx = (
+                            i
+                            if abs(col_cum_sum - box_width)
+                            < abs(col_cum_sum - each_col_widths[i] - box_width)
+                            else i - 1
+                        )
+                        one_col_result[one_col] = idx + 1 - col_start
+                        break
+                else:
+                    one_col_result[one_col] = col_nums - col_start
+                col_end = one_col_result[one_col] + col_start - 1
+                box_height = self.compute_L2(box[1, :], box[0, :])
+                row_start = cur_row
+                for j in range(row_start, row_nums):
+                    row_cum_sum = sum(each_row_heights[row_start : j + 1])
+                    # box_height may span several rows, so try row by row for the closest cumulative height
+                    # If row_cum_sum exceeds box_height on the first try, a row was probably lost
+                    if j == row_start and row_cum_sum > box_height:
+                        one_row_result[one_col] = 1
+                        break
+                    elif abs(box_height - row_cum_sum) <= merge_thresh:
+                        one_row_result[one_col] = j + 1 - row_start
+                        break
+                    # A correction is required here, otherwise rows interleave once past the threshold
+                    elif row_cum_sum > box_height:
+                        idx = (
+                            j
+                            if abs(row_cum_sum - box_height)
+                            < abs(row_cum_sum - each_row_heights[j] - box_height)
+                            else j - 1
+                        )
+                        one_row_result[one_col] = idx + 1 - row_start
+                        break
+                else:
+                    one_row_result[one_col] = row_nums - row_start
+                row_end = one_row_result[one_col] + row_start - 1
+                logic_points[one_col] = np.array(
+                    [row_start, row_end, col_start, col_end]
+                )
+            col_res_merge[cur_row] = one_col_result
+            row_res_merge[cur_row] = one_row_result
+
+        res = {}
+        for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())):
+            res[i] = {k: [cc, r[k]] for k, cc in c.items()}
+        return res, logic_points

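get_rows groups cell polygons into rows by splitting wherever consecutive top-left y coordinates jump by more than rows_thresh. A toy run, assuming only NumPy and the class above (polygons pre-sorted top-to-bottom as in the pipeline):

    import numpy as np

    # Three cells near y=0 and two near y=50 -> two rows
    polygons = np.array([
        [[0, 0], [10, 0], [10, 20], [0, 20]],
        [[12, 1], [22, 1], [22, 21], [12, 21]],
        [[24, 2], [34, 2], [34, 22], [24, 22]],
        [[0, 50], [10, 50], [10, 70], [0, 70]],
        [[12, 51], [22, 51], [22, 71], [12, 71]],
    ])
    print(TableRecover.get_rows(polygons))  # {0: [0, 1, 2], 1: [3, 4]}
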
+ 420 - 0
mineru/model/table/rec/unet_table/table_recover_utils.py

@@ -0,0 +1,420 @@
+from typing import Any, Dict, List, Union, Tuple
+
+import numpy as np
+import shapely
+from shapely.geometry import MultiPoint, Polygon
+
+
+def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray:
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        dt_boxes(array):detected text boxes with shape (N, 4, 2)
+    return:
+        sorted boxes(array) with shape (N, 4, 2)
+    """
+    num_boxes = dt_boxes.shape[0]
+    dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes = list(dt_boxes)
+
+    # Fix adjacent boxes: a later box with a slightly smaller y must not be sorted ahead of an earlier one
+    for i in range(num_boxes - 1):
+        for j in range(i, -1, -1):
+            if (
+                abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
+                and _boxes[j + 1][0][0] < _boxes[j][0][0]
+            ):
+                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
+            else:
+                break
+    return np.array(_boxes)
+
+
+def calculate_iou(
+    box1: Union[np.ndarray, List], box2: Union[np.ndarray, List]
+) -> float:
+    """
+    :param box1: Iterable [xmin,ymin,xmax,ymax]
+    :param box2: Iterable [xmin,ymin,xmax,ymax]
+    :return: iou: float 0-1
+    """
+    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+    # Exit early if the boxes do not intersect
+    if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
+        return 0.0
+    # Compute the intersection
+    inter_x1 = max(b1_x1, b2_x1)
+    inter_y1 = max(b1_y1, b2_y1)
+    inter_x2 = min(b1_x2, b2_x2)
+    inter_y2 = min(b1_y2, b2_y2)
+    i_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
+
+    # Compute the union
+    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
+    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
+    u_area = b1_area + b2_area - i_area
+
+    # Avoid division by zero: if the areas are so small that the union is 0, treat it as a misdetection and discard
+    if u_area == 0:
+        return 1
+    # Check for complete containment
+    iou = i_area / u_area
+    return iou
+
+
+def is_box_contained(
+    box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2
+) -> Union[int, None]:
+    """
+    :param box1: Iterable [xmin,ymin,xmax,ymax]
+    :param box2: Iterable [xmin,ymin,xmax,ymax]
+    :return: 1: box1 is contained 2: box2 is contained None: no contain these
+    """
+    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+    # Exit early if the boxes do not intersect
+    if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
+        return None
+    # Compute the total areas of box2 and box1
+    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
+    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
+
+    # Compute the intersection of box1 and box2
+    intersect_x1 = max(b1_x1, b2_x1)
+    intersect_y1 = max(b1_y1, b2_y1)
+    intersect_x2 = min(b1_x2, b2_x2)
+    intersect_y2 = min(b1_y2, b2_y2)
+
+    # Compute the intersection area
+    intersect_area = max(0, intersect_x2 - intersect_x1) * max(
+        0, intersect_y2 - intersect_y1
+    )
+
+    # Compute the area of each box outside the intersection
+    b1_outside_area = b1_area - intersect_area
+    b2_outside_area = b2_area - intersect_area
+
+    # Ratio of the outside area to each box's total area
+    ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
+    ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
+
+    if ratio_b1 < threshold:
+        return 1
+    if ratio_b2 < threshold:
+        return 2
+    # Check whether the ratios exceed the threshold
+    return None
+
+
+def is_single_axis_contained(
+    box1: Union[np.ndarray, List],
+    box2: Union[np.ndarray, List],
+    axis="x",
+    threshold: float = 0.2,
+) -> Union[int, None]:
+    """
+    :param box1: Iterable [xmin,ymin,xmax,ymax]
+    :param box2: Iterable [xmin,ymin,xmax,ymax]
+    :return: 1: box1 is contained 2: box2 is contained None: no contain these
+    """
+    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+
+    # Compute the overlap along the chosen axis
+    if axis == "x":
+        b1_area = b1_x2 - b1_x1
+        b2_area = b2_x2 - b2_x1
+        i_area = min(b1_x2, b2_x2) - max(b1_x1, b2_x1)
+    else:
+        b1_area = b1_y2 - b1_y1
+        b2_area = b2_y2 - b2_y1
+        i_area = min(b1_y2, b2_y2) - max(b1_y1, b2_y1)
+        # Compute the area outside the overlap
+    b1_outside_area = b1_area - i_area
+    b2_outside_area = b2_area - i_area
+
+    ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
+    ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
+    if ratio_b1 < threshold:
+        return 1
+    if ratio_b2 < threshold:
+        return 2
+    return None
+
+
+def sorted_ocr_boxes(
+    dt_boxes: Union[np.ndarray, list], threshold: float = 0.2
+) -> Tuple[Union[np.ndarray, list], List[int]]:
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        dt_boxes(array):detected text boxes with (xmin, ymin, xmax, ymax)
+    return:
+        sorted boxes(array) with (xmin, ymin, xmax, ymax)
+    """
+    num_boxes = len(dt_boxes)
+    if num_boxes <= 0:
+        return dt_boxes, []
+    indexed_boxes = [(box, idx) for idx, box in enumerate(dt_boxes)]
+    sorted_boxes_with_idx = sorted(indexed_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes, indices = zip(*sorted_boxes_with_idx)
+    indices = list(indices)
+    _boxes = [dt_boxes[i] for i in indices]
+    # Keep the output format consistent with the input format
+    if isinstance(dt_boxes, np.ndarray):
+        _boxes = np.array(_boxes)
+    for i in range(num_boxes - 1):
+        for j in range(i, -1, -1):
+            c_idx = is_single_axis_contained(
+                _boxes[j], _boxes[j + 1], axis="y", threshold=threshold
+            )
+            if (
+                c_idx is not None
+                and _boxes[j + 1][0] < _boxes[j][0]
+                and abs(_boxes[j][1] - _boxes[j + 1][1]) < 20
+            ):
+                _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy()
+                indices[j], indices[j + 1] = indices[j + 1], indices[j]
+            else:
+                break
+    return _boxes, indices
+
+
+def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]:
+    """
+    Convert a 4x2 poly_box to box_4_1 [xmin, ymin, xmax, ymax]
+    :param poly_box:
+    :return:
+    """
+    return [poly_box[0][0], poly_box[0][1], poly_box[2][0], poly_box[2][1]]
+
+
+def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.ndarray):
+    """
+    :param dt_rec_boxes: [[(4, 2) box, text, score]]
+    :param pred_bboxes: shape (N, 4, 2)
+    :return:
+    """
+    matched = {}
+    not_match_ocr_boxes = []
+    for i, gt_box in enumerate(dt_rec_boxes):
+        for j, pred_box in enumerate(pred_bboxes):
+            pred_box = [pred_box[0][0], pred_box[0][1], pred_box[2][0], pred_box[2][1]]
+            ocr_boxes = gt_box[0]
+            # xmin,ymin,xmax,ymax
+            ocr_box = (
+                ocr_boxes[0][0],
+                ocr_boxes[0][1],
+                ocr_boxes[2][0],
+                ocr_boxes[2][1],
+            )
+            contained = is_box_contained(ocr_box, pred_box, 0.6)
+            if contained == 1 or calculate_iou(ocr_box, pred_box) > 0.8:
+                if j not in matched:
+                    matched[j] = [gt_box]
+                else:
+                    matched[j].append(gt_box)
+            else:
+                not_match_ocr_boxes.append(gt_box)
+
+    return matched, not_match_ocr_boxes
+
+
+def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[Any]:
+    """
+        Groups OCR results by row based on the vertical (y-axis) overlap of their bounding boxes.
+    Args:
+        ocr_list (List[Any]): A list of OCR results, where each item is a list containing a bounding box
+            in the format [xmin, ymin, xmax, ymax] and the recognized text.
+        threshold (float, optional): The threshold for determining if two boxes are in the same row,
+            based on their y-axis overlap. Default is 0.2.
+    Returns:
+        List[Any]: A new list of OCR results where texts in the same row are merged, and their bounding
+            boxes are updated to encompass the merged text.
+    """
+    for i in range(len(ocr_list)):
+        if not ocr_list[i]:
+            continue
+
+        for j in range(i + 1, len(ocr_list)):
+            if not ocr_list[j]:
+                continue
+            cur = ocr_list[i]
+            next = ocr_list[j]
+            cur_box = cur[0]
+            next_box = next[0]
+            c_idx = is_single_axis_contained(
+                cur[0], next[0], axis="y", threshold=threshold
+            )
+            if c_idx:
+                dis = max(next_box[0] - cur_box[2], 0)
+                blank_str = int(dis / 10) * " "
+                cur[1] = cur[1] + blank_str + next[1]
+                xmin = min(cur_box[0], next_box[0])
+                xmax = max(cur_box[2], next_box[2])
+                ymin = min(cur_box[1], next_box[1])
+                ymax = max(cur_box[3], next_box[3])
+                cur_box[0] = xmin
+                cur_box[1] = ymin
+                cur_box[2] = xmax
+                cur_box[3] = ymax
+                ocr_list[j] = None
+    ocr_list = [x for x in ocr_list if x]
+    return ocr_list
+
+
+def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float:
+    """Compute the IoU of two polygons
+
+    Args:
+        a (np.ndarray): (4, 2)
+        b (np.ndarray): (4, 2)
+
+    Returns:
+        float: iou
+    """
+    poly1 = Polygon(a).convex_hull
+    poly2 = Polygon(b).convex_hull
+
+    union_poly = np.concatenate((a, b))
+
+    if not poly1.intersects(poly2):
+        return 0.0
+
+    try:
+        inter_area = poly1.intersection(poly2).area
+        union_area = MultiPoint(union_poly).convex_hull.area
+    except shapely.geos.TopologicalError:
+        print("shapely.geos.TopologicalError occurred, iou set to 0")
+        return 0.0
+
+    if union_area == 0:
+        return 0.0
+
+    return float(inter_area) / union_area
+
+
+def merge_adjacent_polys(polygons: np.ndarray) -> np.ndarray:
+    """Merge adjacent boxes whose IoU exceeds the threshold"""
+    combine_iou_thresh = 0.1
+    pair_polygons = list(zip(polygons, polygons[1:, ...]))
+    pair_ious = np.array([compute_poly_iou(p1, p2) for p1, p2 in pair_polygons])
+    idxs = np.argwhere(pair_ious >= combine_iou_thresh)
+
+    if idxs.size <= 0:
+        return polygons
+
+    polygons = combine_two_poly(polygons, idxs)
+
+    # Note: recursive call
+    polygons = merge_adjacent_polys(polygons)
+    return polygons
+
+
+def combine_two_poly(polygons: np.ndarray, idxs: np.ndarray) -> np.ndarray:
+    del_idxs, insert_boxes = [], []
+    idxs = idxs.squeeze(-1)
+    for idx in idxs:
+        # idx and idx + 1 overlap too heavily
+        # Merge them, taking the extreme of each corresponding corner
+        new_poly = []
+        pre_poly, pos_poly = polygons[idx], polygons[idx + 1]
+
+        # Compare the four corners one by one
+        new_poly.append(np.minimum(pre_poly[0], pos_poly[0]))
+
+        x_2 = min(pre_poly[1][0], pos_poly[1][0])
+        y_2 = max(pre_poly[1][1], pos_poly[1][1])
+        new_poly.append([x_2, y_2])
+
+        # 3rd corner
+        new_poly.append(np.maximum(pre_poly[2], pos_poly[2]))
+
+        # 4th corner
+        x_4 = max(pre_poly[3][0], pos_poly[3][0])
+        y_4 = min(pre_poly[3][1], pos_poly[3][1])
+        new_poly.append([x_4, y_4])
+
+        new_poly = np.array(new_poly)
+
+        # Mark the two merged boxes for deletion; queue the new box for insertion
+        del_idxs.extend([idx, idx + 1])
+        insert_boxes.append(new_poly)
+
+    # Consolidate the merged boxes
+    polygons = np.delete(polygons, del_idxs, axis=0)
+
+    insert_boxes = np.array(insert_boxes)
+    polygons = np.append(polygons, insert_boxes, axis=0)
+    polygons = sorted_boxes(polygons)
+    return polygons
+
+
+def plot_html_table(
+    logi_points: Union[np.ndarray, List], cell_box_map: Dict[int, List[str]]
+) -> str:
+    # Initialize the maximum row and column counts
+    max_row = 0
+    max_col = 0
+    # Compute the maximum row and column counts
+    for point in logi_points:
+        max_row = max(max_row, point[1] + 1)  # +1 because the end index is inclusive
+        max_col = max(max_col, point[3] + 1)  # +1 because the end index is inclusive
+
+    # Create a 2D grid to hold the elements of sorted_logi_points
+    grid = [[None] * max_col for _ in range(max_row)]
+
+    valid_start_row = (1 << 16) - 1
+    valid_start_col = (1 << 16) - 1
+    valid_end_col = 0
+    # Fill the grid with the elements of sorted_logi_points
+    for i, logic_point in enumerate(logi_points):
+        row_start, row_end, col_start, col_end = (
+            logic_point[0],
+            logic_point[1],
+            logic_point[2],
+            logic_point[3],
+        )
+        ocr_rec_text_list = cell_box_map.get(i)
+        if ocr_rec_text_list and "".join(ocr_rec_text_list):
+            valid_start_row = min(row_start, valid_start_row)
+            valid_start_col = min(col_start, valid_start_col)
+            valid_end_col = max(col_end, valid_end_col)
+        for row in range(row_start, row_end + 1):
+            for col in range(col_start, col_end + 1):
+                grid[row][col] = (i, row_start, row_end, col_start, col_end)
+
+    # Build the table
+    table_html = "<html><body><table>"
+
+    # Iterate over each row
+    for row in range(max_row):
+        if row < valid_start_row:
+            continue
+        temp = "<tr>"
+        # Iterate over each column
+        for col in range(max_col):
+            if col < valid_start_col or col > valid_end_col:
+                continue
+            if not grid[row][col]:
+                temp += "<td></td>"
+            else:
+                i, row_start, row_end, col_start, col_end = grid[row][col]
+                if not cell_box_map.get(i):
+                    continue
+                if row == row_start and col == col_start:
+                    ocr_rec_text = cell_box_map.get(i)
+                    text = "<br>".join(ocr_rec_text)
+                    # This is the starting cell of the span
+                    row_span = row_end - row_start + 1
+                    col_span = col_end - col_start + 1
+                    cell_content = (
+                        f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
+                    )
+                    temp += cell_content
+
+        table_html = table_html + temp + "</tr>"
+
+    table_html += "</table></body></html>"
+    return table_html

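plot_html_table renders per-cell logic points (row_start, row_end, col_start, col_end) into HTML using rowspan/colspan. A toy example, assuming the function above: a header cell spanning two columns over a two-cell body row:

    logi_points = [
        [0, 0, 0, 1],  # cell 0: row 0, spanning columns 0-1
        [1, 1, 0, 0],  # cell 1: row 1, column 0
        [1, 1, 1, 1],  # cell 2: row 1, column 1
    ]
    cell_box_map = {0: ["Header"], 1: ["A"], 2: ["B"]}
    print(plot_html_table(logi_points, cell_box_map))
    # -> <html><body><table><tr><td rowspan=1 colspan=2>Header</td></tr>
    #    <tr><td rowspan=1 colspan=1>A</td><td rowspan=1 colspan=1>B</td></tr></table></body></html>
    #    (emitted as a single line; wrapped here for readability)
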
+ 149 - 0
mineru/model/table/rec/unet_table/table_structure_unet.py

@@ -0,0 +1,149 @@
+import copy
+import math
+from typing import Optional, Dict, Any, Tuple
+
+import cv2
+import numpy as np
+from skimage import measure
+from .wired_table_rec_utils import OrtInferSession, resize_img
+from .table_line_rec_utils import (
+    get_table_line,
+    final_adjust_lines,
+    min_area_rect_box,
+    draw_lines,
+    adjust_lines,
+)
+from .table_recover_utils import (
+    sorted_ocr_boxes,
+    box_4_2_poly_to_box_4_1,
+)
+
+
+class TSRUnet:
+    def __init__(self, config: Dict):
+        self.K = 1000
+        self.MK = 4000
+        self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+        self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+        self.inp_height = 1024
+        self.inp_width = 1024
+
+        self.session = OrtInferSession(config)
+
+    def __call__(
+        self, img: np.ndarray, **kwargs
+    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+        img_info = self.preprocess(img)
+        pred = self.infer(img_info)
+        polygons, rotated_polygons = self.postprocess(img, pred, **kwargs)
+        if polygons.size == 0:
+            return None, None
+        polygons = polygons.reshape(polygons.shape[0], 4, 2)
+        polygons[:, 3, :], polygons[:, 1, :] = (
+            polygons[:, 1, :].copy(),
+            polygons[:, 3, :].copy(),
+        )
+        rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
+        rotated_polygons[:, 3, :], rotated_polygons[:, 1, :] = (
+            rotated_polygons[:, 1, :].copy(),
+            rotated_polygons[:, 3, :].copy(),
+        )
+        _, idx = sorted_ocr_boxes(
+            [box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
+            threshold=0.4,
+        )
+        polygons = polygons[idx]
+        rotated_polygons = rotated_polygons[idx]
+        return polygons, rotated_polygons
+
+    def preprocess(self, img) -> Dict[str, Any]:
+        scale = (self.inp_height, self.inp_width)
+        img, _, _ = resize_img(img, scale, True)
+        img = img.copy().astype(np.float32)
+        assert img.dtype != np.uint8  # the in-place cv2 ops below require a float image
+        mean = np.float64(self.mean.reshape(1, -1))
+        stdinv = 1 / np.float64(self.std.reshape(1, -1))
+        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+        cv2.subtract(img, mean, img)  # inplace
+        cv2.multiply(img, stdinv, img)  # inplace
+        img = img.transpose(2, 0, 1)
+        images = img[None, :]
+        return {"img": images}
+
+    def infer(self, inputs):
+        # OrtInferSession zips input names with the first axis of its argument,
+        # so the extra leading dim pairs the single input name with the tensor
+        result = self.session(inputs["img"][None, ...])[0][0]
+        result = result[0].astype(np.uint8)
+        return result
+
+    def postprocess(self, img, pred, **kwargs):
+        row = kwargs.get("row", 50) if kwargs else 50
+        col = kwargs.get("col", 30) if kwargs else 30
+        h_lines_threshold = kwargs.get("h_lines_threshold", 100) if kwargs else 100
+        v_lines_threshold = kwargs.get("v_lines_threshold", 15) if kwargs else 15
+        angle = kwargs.get("angle", 50) if kwargs else 50
+        enhance_box_line = kwargs.get("enhance_box_line", True) if kwargs else True
+        morph_close = (
+            kwargs.get("morph_close", enhance_box_line) if kwargs else enhance_box_line
+        )  # whether to apply a morphological close to recover more small boxes
+        more_h_lines = (
+            kwargs.get("more_h_lines", enhance_box_line) if kwargs else enhance_box_line
+        )  # whether to adjust lines to find more horizontal lines
+        more_v_lines = (
+            kwargs.get("more_v_lines", enhance_box_line) if kwargs else enhance_box_line
+        )  # whether to adjust lines to find more vertical lines
+        extend_line = (
+            kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
+        )  # whether to extend line segments so that endpoints connect
+
+        ori_shape = img.shape
+        pred = np.uint8(pred)
+        hpred = copy.deepcopy(pred)  # horizontal-line mask
+        vpred = copy.deepcopy(pred)  # vertical-line mask
+        whereh = np.where(hpred == 1)
+        wherev = np.where(vpred == 2)
+        hpred[wherev] = 0
+        vpred[whereh] = 0
+
+        hpred = cv2.resize(hpred, (ori_shape[1], ori_shape[0]))
+        vpred = cv2.resize(vpred, (ori_shape[1], ori_shape[0]))
+
+        h, w = pred.shape
+        hors_k = int(math.sqrt(w) * 1.2)
+        vert_k = int(math.sqrt(h) * 1.2)
+        hkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (hors_k, 1))
+        vkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_k))
+        vpred = cv2.morphologyEx(
+            vpred, cv2.MORPH_CLOSE, vkernel, iterations=1
+        )  # closing: dilation followed by erosion
+        if morph_close:
+            hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1)
+        colboxes = get_table_line(vpred, axis=1, lineW=col)  # vertical lines
+        rowboxes = get_table_line(hpred, axis=0, lineW=row)  # horizontal lines
+        rboxes_row_, rboxes_col_ = [], []
+        if more_h_lines:
+            rboxes_row_ = adjust_lines(rowboxes, alph=h_lines_threshold, angle=angle)
+        if more_v_lines:
+            rboxes_col_ = adjust_lines(colboxes, alph=v_lines_threshold, angle=angle)
+        rowboxes += rboxes_row_
+        colboxes += rboxes_col_
+        if extend_line:
+            rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
+        line_img = np.zeros(img.shape[:2], dtype="uint8")
+        line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
+
+        polygons = self.cal_region_boxes(line_img)
+        rotated_polygons = polygons.copy()
+        return polygons, rotated_polygons
+
+    def cal_region_boxes(self, tmp):
+        labels = measure.label(tmp < 255, connectivity=2)  # 8-connected component labeling
+        regions = measure.regionprops(labels)
+        ceilboxes = min_area_rect_box(
+            regions,
+            False,
+            tmp.shape[1],
+            tmp.shape[0],
+            filtersmall=True,
+            adjust_box=False,
+        )  # the last argument (adjust_box) is changed to False
+        return np.array(ceilboxes)
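TSRUnet segments the table image into pixel classes (1 = horizontal line, 2 = vertical line, per the np.where calls above), rebuilds the line grid morphologically, and labels the enclosed regions as cell polygons. A hedged usage sketch: the ONNX path mirrors ModelPath.unet_structure from this PR, and the image path is hypothetical:

import cv2
from mineru.model.table.rec.unet_table.table_structure_unet import TSRUnet

model = TSRUnet({"model_path": "models/TabRec/UnetStructure/unet.onnx"})

img = cv2.imread("wired_table.png")  # hypothetical cropped table image (BGR)
polygons, rotated_polygons = model(img, enhance_box_line=True)
if polygons is None:
    print("no cells found")
else:
    print(polygons.shape)  # (N, 4, 2): one 4-point polygon per detected cell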

+ 199 - 0
mineru/model/table/rec/unet_table/unet_table.py

@@ -0,0 +1,199 @@
+import html
+import os
+import time
+import traceback
+from dataclasses import dataclass, asdict
+from typing import List, Optional, Union, Dict, Any
+
+import cv2
+import numpy as np
+from loguru import logger
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+from .table_structure_unet import TSRUnet
+from .table_recover import TableRecover
+from .wired_table_rec_utils import InputType, LoadImage
+from .table_recover_utils import (
+    match_ocr_cell,
+    plot_html_table,
+    box_4_2_poly_to_box_4_1,
+    sorted_ocr_boxes,
+    gather_ocr_list_by_row,
+)
+
+
+@dataclass
+class UnetTableInput:
+    model_path: str
+    device: str = "cpu"
+
+
+@dataclass
+class UnetTableOutput:
+    pred_html: Optional[str] = None
+    cell_bboxes: Optional[np.ndarray] = None
+    logic_points: Optional[np.ndarray] = None
+    elapse: Optional[float] = None
+
+
+class UnetTableRecognition:
+    def __init__(self, config: UnetTableInput):
+        self.table_structure = TSRUnet(asdict(config))
+        self.load_img = LoadImage()
+        self.table_recover = TableRecover()
+
+    def __call__(
+        self,
+        img: InputType,
+        ocr_result: Optional[List[List[Union[List[List[float]], str, float]]]] = None,
+        **kwargs,
+    ) -> UnetTableOutput:
+        s = time.perf_counter()
+        need_ocr = True
+        col_threshold = 15
+        row_threshold = 10
+        if kwargs:
+            need_ocr = kwargs.get("need_ocr", True)
+            col_threshold = kwargs.get("col_threshold", 15)
+            row_threshold = kwargs.get("row_threshold", 10)
+        img = self.load_img(img)
+        polygons, rotated_polygons = self.table_structure(img, **kwargs)
+        if polygons is None:
+            logging.warning("polygons is None.")
+            return UnetTableOutput("", None, None, 0.0)
+
+        try:
+            table_res, logi_points = self.table_recover(
+                rotated_polygons, row_threshold, col_threshold
+            )
+            # Convert coordinates from counter-clockwise to clockwise order so that
+            # downstream processing matches the wireless-table pipeline
+            polygons[:, 1, :], polygons[:, 3, :] = (
+                polygons[:, 3, :].copy(),
+                polygons[:, 1, :].copy(),
+            )
+            if not need_ocr:
+                sorted_polygons, idx_list = sorted_ocr_boxes(
+                    [box_4_2_poly_to_box_4_1(box) for box in polygons]
+                )
+                return UnetTableOutput(
+                    "",
+                    sorted_polygons,
+                    logi_points[idx_list],
+                    time.perf_counter() - s,
+                )
+            cell_box_det_map, not_match_ocr_boxes = match_ocr_cell(ocr_result, polygons)
+            # For detected cells without an OCR result, fill them in directly
+            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
+            # Convert to an intermediate format: fix box coordinates and merge the
+            # physical boxes, logical boxes and OCR boxes into one dict for later steps
+            t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
+            # Sort the OCR results inside each cell and merge same-row fragments so the
+            # output HTML preserves the text's original line breaks
+            t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
+            logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
+            cell_box_det_map = {
+                i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
+                for i, t_box_ocr in enumerate(t_rec_ocr_list)
+            }
+            pred_html = plot_html_table(logi_points, cell_box_det_map)
+            polygons = np.array(polygons).reshape(-1, 8)
+            logi_points = np.array(logi_points)
+            elapse = time.perf_counter() - s
+
+        except Exception:
+            logger.warning(traceback.format_exc())
+            return UnetTableOutput("", None, None, 0.0)
+        return UnetTableOutput(pred_html, polygons, logi_points, elapse)
+
+    def transform_res(
+        self,
+        cell_box_det_map: Dict[int, List[Any]],
+        polygons: np.ndarray,
+        logi_points: List[np.ndarray],
+    ) -> List[Dict[str, Any]]:
+        res = []
+        for i in range(len(polygons)):
+            ocr_res_list = cell_box_det_map.get(i)
+            if not ocr_res_list:
+                continue
+            xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list])
+            ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list])
+            xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list])
+            ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list])
+            dict_res = {
+                # xmin, ymin, xmax, ymax
+                "t_box": [xmin, ymin, xmax, ymax],
+                # row_start,row_end,col_start,col_end
+                "t_logic_box": logi_points[i].tolist(),
+                # [[xmin, ymin, xmax, ymax], text]
+                "t_ocr_res": [
+                    [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
+                    for ocr_det in ocr_res_list
+                ],
+            }
+            res.append(dict_res)
+        return res
+
+    def sort_and_gather_ocr_res(self, res):
+        for i, dict_res in enumerate(res):
+            _, sorted_idx = sorted_ocr_boxes(
+                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threshold=0.3
+            )
+            dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
+            dict_res["t_ocr_res"] = gather_ocr_list_by_row(
+                dict_res["t_ocr_res"], threshold=0.3
+            )
+        return res
+
+    def fill_blank_rec(
+        self,
+        img: np.ndarray,
+        sorted_polygons: np.ndarray,
+        cell_box_map: Dict[int, List[str]],
+    ) -> Dict[int, List[Any]]:
+        """找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
+        for i in range(sorted_polygons.shape[0]):
+            if cell_box_map.get(i):
+                continue
+            box = sorted_polygons[i]
+            cell_box_map[i] = [[box, "", 1]]
+            continue
+        return cell_box_map
+
+
+def escape_html(input_string):
+    """Escape HTML Entities."""
+    return html.escape(input_string)
+
+
+class UnetTableModel:
+    def __init__(self, ocr_engine):
+        model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
+        input_args = UnetTableInput(model_path=model_path)
+        self.table_model = UnetTableRecognition(input_args)
+        self.ocr_engine = ocr_engine
+
+    def predict(self, img):
+        bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+        ocr_result = self.ocr_engine.ocr(bgr_img)[0]
+
+        if ocr_result:
+            ocr_result = [
+                [item[0], escape_html(item[1][0]), item[1][1]]
+                for item in ocr_result
+                if len(item) == 2 and isinstance(item[1], tuple)
+            ]
+        else:
+            ocr_result = None
+        if ocr_result:
+            try:
+                table_results = self.table_model(np.asarray(img), ocr_result)
+                html_code = table_results.pred_html
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+        return None, None, None, None
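UnetTableModel glues the pipeline's OCR engine to the wired-table recognizer: OCR first, HTML-escape the text, then let UnetTableRecognition match text to cells. A sketch of how batch_analyze.py would drive it; the get_atom_model arguments mirror the OCR call in this PR, while the lang value and crop path are assumptions:

from PIL import Image
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.model_list import AtomicModel
from mineru.model.table.rec.unet_table.unet_table import UnetTableModel

ocr_engine = AtomModelSingleton().get_atom_model(
    atom_model_name=AtomicModel.OCR,
    det_db_box_thresh=0.3,
    lang="ch",  # assumed language code
)
table_model = UnetTableModel(ocr_engine)

img = Image.open("table_crop.png")  # hypothetical layout crop (RGB)
html_code, cell_bboxes, logic_points, elapse = table_model.predict(img)
print(html_code)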

+ 391 - 0
mineru/model/table/rec/unet_table/wired_table_rec_utils.py

@@ -0,0 +1,391 @@
+import os
+import traceback
+from enum import Enum
+from io import BytesIO
+from pathlib import Path
+from typing import List, Union, Dict, Any, Tuple
+
+import cv2
+import numpy as np
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+)
+from PIL import Image, UnidentifiedImageError
+
+root_dir = Path(__file__).resolve().parent
+InputType = Union[str, np.ndarray, bytes, Path]
+
+
+class EP(Enum):
+    CPU_EP = "CPUExecutionProvider"
+
+
+class OrtInferSession:
+    def __init__(self, config: Dict[str, Any]):
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            model_path,
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+        sess_opt = SessionOptions()
+        sess_opt.log_severity_level = 4
+        sess_opt.enable_cpu_mem_arena = False
+        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        cpu_nums = os.cpu_count()
+        intra_op_num_threads = config.get("intra_op_num_threads", -1)
+        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+            sess_opt.intra_op_num_threads = intra_op_num_threads
+
+        inter_op_num_threads = config.get("inter_op_num_threads", -1)
+        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def get_metadata(self, key: str = "character") -> list:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        content_list = meta_dict[key].splitlines()
+        return content_list
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+        return EP_list
+
+    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+        input_dict = dict(zip(self.get_input_names(), input_content))
+        try:
+            return self.session.run(None, input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return key in meta_dict
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
+
+
+class LoadImage:
+    """
+    Utility class for loading and converting images from various input types to a numpy ndarray.
+
+    Supported input types:
+        - str or pathlib.Path: Path to an image file.
+        - bytes: Image data in bytes format.
+        - numpy.ndarray: Already loaded image array.
+
+    The class attempts to load the image and convert it to a numpy ndarray in BGR format.
+    Raises LoadImageError for unsupported types or if the image cannot be loaded.
+    """
+    def __init__(self):
+        pass
+
+    def __call__(self, img: InputType) -> np.ndarray:
+        img = self.load_img(img)
+        img = self.convert_img(img)
+        return img
+
+    def load_img(self, img: InputType) -> np.ndarray:
+        if isinstance(img, (str, Path)):
+            self.verify_exist(img)
+            try:
+                img = np.array(Image.open(img))
+            except UnidentifiedImageError as e:
+                raise LoadImageError(f"cannot identify image file {img}") from e
+            return img
+
+        elif isinstance(img, bytes):
+            try:
+                img = np.array(Image.open(BytesIO(img)))
+            except UnidentifiedImageError as e:
+                raise LoadImageError(f"cannot identify image from bytes data") from e
+            return img
+
+        elif isinstance(img, np.ndarray):
+            return img
+
+        else:
+            raise LoadImageError(f"{type(img)} is not supported!")
+
+    def convert_img(self, img: np.ndarray):
+        if img.ndim == 2:
+            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+        if img.ndim == 3:
+            channel = img.shape[2]
+            if channel == 1:
+                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+            if channel == 2:
+                return self.cvt_two_to_three(img)
+
+            if channel == 4:
+                return self.cvt_four_to_three(img)
+
+            if channel == 3:
+                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+            raise LoadImageError(
+                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
+            )
+
+        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
+
+    @staticmethod
+    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
+        """RGBA → BGR"""
+        r, g, b, a = cv2.split(img)
+        new_img = cv2.merge((b, g, r))
+
+        not_a = cv2.bitwise_not(a)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
+        new_img = cv2.add(new_img, not_a)
+        return new_img
+
+    @staticmethod
+    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
+        """gray + alpha → BGR"""
+        img_gray = img[..., 0]
+        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
+
+        img_alpha = img[..., 1]
+        not_a = cv2.bitwise_not(img_alpha)
+        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
+        new_img = cv2.add(new_img, not_a)
+        return new_img
+
+    @staticmethod
+    def verify_exist(file_path: Union[str, Path]):
+        if not Path(file_path).exists():
+            raise LoadImageError(f"{file_path} does not exist.")
+
+
+class LoadImageError(Exception):
+    pass
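LoadImage normalizes every supported input type to a 3-channel BGR ndarray, so downstream code never branches on input format. A small sketch (file names are placeholders):

import numpy as np
from mineru.model.table.rec.unet_table.wired_table_rec_utils import LoadImage

load_img = LoadImage()
bgr_from_path = load_img("table.png")                       # str / Path input
bgr_from_bytes = load_img(open("table.png", "rb").read())   # raw bytes input
bgr_from_array = load_img(np.zeros((32, 32, 4), dtype=np.uint8))  # RGBA -> BGR
print(bgr_from_array.shape)  # (32, 32, 3)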
+
+
+# Pillow >=v9.1.0 use a slightly different naming scheme for filters.
+# Set pillow_interp_codes according to the naming scheme used.
+if Image is not None:
+    if hasattr(Image, "Resampling"):
+        pillow_interp_codes = {
+            "nearest": Image.Resampling.NEAREST,
+            "bilinear": Image.Resampling.BILINEAR,
+            "bicubic": Image.Resampling.BICUBIC,
+            "box": Image.Resampling.BOX,
+            "lanczos": Image.Resampling.LANCZOS,
+            "hamming": Image.Resampling.HAMMING,
+        }
+    else:
+        pillow_interp_codes = {
+            "nearest": Image.NEAREST,
+            "bilinear": Image.BILINEAR,
+            "bicubic": Image.BICUBIC,
+            "box": Image.BOX,
+            "lanczos": Image.LANCZOS,
+            "hamming": Image.HAMMING,
+        }
+
+cv2_interp_codes = {
+    "nearest": cv2.INTER_NEAREST,
+    "bilinear": cv2.INTER_LINEAR,
+    "bicubic": cv2.INTER_CUBIC,
+    "area": cv2.INTER_AREA,
+    "lanczos": cv2.INTER_LANCZOS4,
+}
+
+
+def resize_img(img, scale, keep_ratio=True):
+    if keep_ratio:
+        # When downscaling, "area" interpolation preserves detail better
+        if min(img.shape[:2]) > min(scale):
+            interpolation = "area"
+        else:
+            interpolation = "bicubic"
+        img_new, scale_factor = imrescale(
+            img, scale, return_scale=True, interpolation=interpolation
+        )
+        # the w_scale and h_scale has minor difference
+        # a real fix should be done in the mmcv.imrescale in the future
+        new_h, new_w = img_new.shape[:2]
+        h, w = img.shape[:2]
+        w_scale = new_w / w
+        h_scale = new_h / h
+    else:
+        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
+    return img_new, w_scale, h_scale
+
+
+def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
+    """Resize image while keeping the aspect ratio.
+
+    Args:
+        img (ndarray): The input image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image.
+        interpolation (str): Same as :func:`imresize`.
+        backend (str | None): Same as :func:`imresize`.
+
+    Returns:
+        ndarray: The rescaled image.
+    """
+    h, w = img.shape[:2]
+    new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+    rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend)
+    if return_scale:
+        return rescaled_img, scale_factor
+    else:
+        return rescaled_img
+
+
+def imresize(
+    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
+):
+    """Resize image to a given size.
+
+    Args:
+        img (ndarray): The input image.
+        size (tuple[int]): Target size (w, h).
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+        out (ndarray): The output destination.
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, `cv2` is used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+        `resized_img`.
+    """
+    h, w = img.shape[:2]
+    if backend is None:
+        backend = "cv2"
+    if backend not in ["cv2", "pillow"]:
+        raise ValueError(
+            f"backend: {backend} is not supported for resize."
+            f"Supported backends are 'cv2', 'pillow'"
+        )
+
+    if backend == "pillow":
+        assert img.dtype == np.uint8, "Pillow backend only supports uint8 type"
+        pil_image = Image.fromarray(img)
+        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+        resized_img = np.array(pil_image)
+    else:
+        resized_img = cv2.resize(
+            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
+        )
+    if not return_scale:
+        return resized_img
+    else:
+        w_scale = size[0] / w
+        h_scale = size[1] / h
+        return resized_img, w_scale, h_scale
+
+
+def rescale_size(old_size, scale, return_scale=False):
+    """Calculate the new size to be rescaled to.
+
+    Args:
+        old_size (tuple[int]): The old size (w, h) of image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image size.
+
+    Returns:
+        tuple[int]: The new rescaled image size.
+    """
+    w, h = old_size
+    if isinstance(scale, (float, int)):
+        if scale <= 0:
+            raise ValueError(f"Invalid scale {scale}, must be positive.")
+        scale_factor = scale
+    elif isinstance(scale, tuple):
+        max_long_edge = max(scale)
+        max_short_edge = min(scale)
+        scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
+    else:
+        raise TypeError(
+            f"Scale must be a number or tuple of int, but got {type(scale)}"
+        )
+
+    new_size = _scale_size((w, h), scale_factor)
+
+    if return_scale:
+        return new_size, scale_factor
+    else:
+        return new_size
+
+
+def _scale_size(size, scale):
+    """Rescale a size by a ratio.
+
+    Args:
+        size (tuple[int]): (w, h).
+        scale (float | tuple(float)): Scaling factor.
+
+    Returns:
+        tuple[int]: scaled size.
+    """
+    if isinstance(scale, (float, int)):
+        scale = (scale, scale)
+    w, h = size
+    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
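A worked example of the rescale arithmetic above: fitting a 1920x1080 image into the model's (1024, 1024) input window keeps the aspect ratio by taking the most restrictive edge ratio:

from mineru.model.table.rec.unet_table.wired_table_rec_utils import rescale_size

new_size, scale = rescale_size((1920, 1080), (1024, 1024), return_scale=True)
# scale = min(1024 / 1920, 1024 / 1080) = 0.5333...
# new_size = (int(1920 * scale + 0.5), int(1080 * scale + 0.5)) = (1024, 576)
print(new_size, scale)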

+ 3 - 0
mineru/utils/enum_class.py

@@ -59,6 +59,9 @@ class ModelPath:
     pytorch_paddle = "models/OCR/paddleocr_torch"
     layout_reader = "models/ReadingOrder/layout_reader"
     slanet_plus = "models/TabRec/SlanetPlus/slanet-plus.onnx"
+    unet_structure = "models/TabRec/UnetStructure/unet.onnx"
+    paddle_table_cls = "models/TabCls/paddle_table_cls/PP-LCNet_x1_0_table_cls.onnx"
+    paddle_orientation_classification = "models/OriCls/paddle_orientation_classification/PP-LCNet_x1_0_doc_ori.onnx"
 
 
 class SplitFlag:

+ 1 - 0
pyproject.toml

@@ -35,6 +35,7 @@ dependencies = [
     "json-repair>=0.46.2",
     "opencv-python>=4.11.0.86",
     "fast-langdetect>=0.2.3,<0.3.0",
+    "scikit-image>=0.25.0,<1.0.0",
 ]
 
 [project.optional-dependencies]