
fix: refactor image processing and table classification logic for improved accuracy

myhloli 3 months ago
Parent
Current commit
866ad6ae51

+ 22 - 10
mineru/backend/pipeline/batch_analyze.py

@@ -9,6 +9,7 @@ from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.pdf_image_tools import get_crop_img
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -40,10 +41,7 @@ class BatchAnalyze:
         images = [image for image, _, _ in images_with_extra_info]
 
         # doclayout_yolo
-        layout_images = []
-        for image_index, image in enumerate(images):
-            layout_images.append(image)
-
+        layout_images = images.copy()
 
         images_layout_res += self.model.layout_model.batch_predict(
             layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
@@ -89,7 +87,14 @@ class BatchAnalyze:
                                           })
 
             for table_res in table_res_list:
-                table_img, _ = crop_img(table_res, pil_img)
+                # table_img, _ = crop_img(table_res, pil_img)
+                # bbox = (241, 208, 1475, 2019)
+                scale = 10/3
+                crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
+                crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
+                bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
+                table_img = get_crop_img(bbox, pil_img, scale=scale)
+
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
                                                 'table_img':table_img,
@@ -256,14 +261,22 @@ class BatchAnalyze:
                     atom_model_name=AtomicModel.ImgOrientationCls,
                 )
                 try:
-                    table_img = img_orientation_cls_model.predict(
+                    rotate_label = img_orientation_cls_model.predict(
                         table_res_dict["table_img"]
                     )
                 except Exception as e:
                     logger.warning(
                         f"Image orientation classification failed: {e}, using original image"
                     )
-                    table_img = table_res_dict["table_img"]
+                    rotate_label = "0"
+
+                np_table_img = np.asarray(table_res_dict["table_img"])
+                if rotate_label == "270":
+                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
+                elif rotate_label == "90":
+                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                else:
+                    pass
 
                 # 有线表/无线表分类
                 table_cls_model = atom_model_manager.get_atom_model(
@@ -271,7 +284,7 @@ class BatchAnalyze:
                 )
                 table_cls_score = 0.5
                 try:
-                    table_label, table_cls_score = table_cls_model.predict(table_img)
+                    table_label, table_cls_score = table_cls_model.predict(np_table_img)
                 except Exception as e:
                     table_label = AtomicModel.WirelessTable
                     logger.warning(
@@ -289,8 +302,7 @@ class BatchAnalyze:
                     atom_model_name=table_label,
                     lang=_lang,
                 )
-
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img, table_cls_score)
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(np_table_img, table_cls_score)
                 # 判断是否返回正常
                 if html_code:
                     # 检查html_code是否包含'<table>'和'</table>'
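For readers following the change, a minimal sketch of the new crop-and-orient flow, assuming get_crop_img(bbox, pil_img, scale=...) takes a bbox in un-scaled page coordinates and returns the table crop. The 10/3 factor, the poly indexing, and the rotate-label handling are taken from the diff above; the wrapper function itself is illustrative:

import cv2
import numpy as np
from mineru.utils.pdf_image_tools import get_crop_img  # import path per the diff above

SCALE = 10 / 3  # layout polys are in scaled-image coordinates; the crop bbox is not

def crop_and_orient(table_res, pil_img, ori_cls_model):
    # poly is [x0, y0, x1, y1, x2, y2, x3, y3]; corners 0 and 2 give the bounding box
    x_min, y_min = int(table_res["poly"][0]), int(table_res["poly"][1])
    x_max, y_max = int(table_res["poly"][4]), int(table_res["poly"][5])
    bbox = (int(x_min / SCALE), int(y_min / SCALE),
            int(x_max / SCALE), int(y_max / SCALE))
    table_img = get_crop_img(bbox, pil_img, scale=SCALE)

    # The orientation classifier now returns a label ("0"/"90"/"180"/"270")
    # instead of a rotated image; the caller applies the rotation itself.
    try:
        rotate_label = ori_cls_model.predict(table_img)
    except Exception:
        rotate_label = "0"  # fall back to "no rotation" on failure

    np_table_img = np.asarray(table_img)
    if rotate_label == "270":
        np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
    elif rotate_label == "90":
        np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
    return np_table_img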

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -11,7 +11,7 @@ from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
 from ...model.table.rec.rapid_table import RapidTableModel
-from ...model.table.rec.unet_table.unet_table import UnetTableModel
+from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 

+ 2 - 3
mineru/backend/pipeline/pipeline_analyze.py

@@ -1,7 +1,7 @@
 import os
 import time
 from typing import List, Tuple
-import PIL.Image
+from PIL import Image
 from loguru import logger
 
 from .model_init import MineruPipelineModel
@@ -148,10 +148,9 @@ def doc_analyze(
 
 
 def batch_image_analyze(
-        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
+        images_with_extra_info: List[Tuple[Image.Image, bool, str]],
         formula_enable=True,
         table_enable=True):
-    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
 
     from .batch_analyze import BatchAnalyze
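For context, a hedged usage sketch of the updated batch_image_analyze signature. The meaning of the bool and str fields in each tuple (read here as a per-page OCR flag and a language code) and the page-image loading are assumptions; only the parameter names and types come from the diff:

from PIL import Image
from mineru.backend.pipeline.pipeline_analyze import batch_image_analyze

# (image, flag, lang) per page -- the semantics of the two extra fields are assumed
images_with_extra_info = [
    (Image.open("page_001.png"), True, "en"),
    (Image.open("page_002.png"), True, "ch"),
]
# Return value is not shown in the diff; assumed to be the per-page analysis results.
results = batch_image_analyze(images_with_extra_info, formula_enable=True, table_enable=True)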
 

+ 17 - 18
mineru/model/ori_cls/paddle_ori_cls.py

@@ -1,6 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import os
 
+from PIL import Image
 import cv2
 import numpy as np
 import onnxruntime
@@ -23,15 +24,13 @@ class PaddleOrientationClsModel:
         self.mean = [0.485, 0.456, 0.406]
         self.labels = ["0", "90", "180", "270"]
 
-    def preprocess(self, img):
-        # PIL图像转cv2
-        img = np.array(img)
+    def preprocess(self, input_img):
         # 放大图片,使其最短边长为256
-        h, w = img.shape[:2]
+        h, w = input_img.shape[:2]
         scale = 256 / min(h, w)
         h_resize = round(h * scale)
         w_resize = round(w * scale)
-        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
         # 调整为224*224的正方形
         h, w = img.shape[:2]
         cw, ch = 224, 224
@@ -62,8 +61,15 @@ class PaddleOrientationClsModel:
         x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
 
-    def predict(self, img):
-        bgr_image = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+    def predict(self, input_img):
+        rotate_label = "0"  # Default to 0 if no rotation detected or not portrait
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
         # First check the overall image aspect ratio (height/width)
         img_height, img_width = bgr_image.shape[:2]
         img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
@@ -99,16 +105,9 @@ class PaddleOrientationClsModel:
                 # If we have more vertical text boxes than horizontal ones,
                 # and vertical ones are significant, table might be rotated
                 if is_rotated:
-                    x = self.preprocess(img)
+                    x = self.preprocess(np_img)
                     (result,) = self.sess.run(None, {"x": x})
-                    label = self.labels[np.argmax(result)]
+                    rotate_label = self.labels[np.argmax(result)]
                     # logger.debug(f"Orientation classification result: {label}")
-                    if label == "270":
-                        rotation = cv2.ROTATE_90_CLOCKWISE
-                        img = cv2.rotate(np.asarray(img), rotation)
-                    elif label == "90":
-                        rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
-                        img = cv2.rotate(np.asarray(img), rotation)
-                    else:
-                        pass
-        return img
+
+        return rotate_label
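A short usage sketch of the new contract in paddle_ori_cls.py: predict accepts either a PIL image or an RGB numpy array and returns an orientation label instead of a rotated image. Constructing the model is omitted (its constructor arguments are not shown in this diff); cls_model below stands for an already-initialized PaddleOrientationClsModel:

from PIL import Image
import numpy as np

pil_img = Image.open("table_crop.png")                   # illustrative input
label_from_pil = cls_model.predict(pil_img)              # PIL input is converted internally
label_from_np = cls_model.predict(np.asarray(pil_img))   # ndarray input is used as-is
assert label_from_pil in ("0", "90", "180", "270")       # labels defined in the class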

+ 12 - 7
mineru/model/table/cls/paddle_table_cls.py

@@ -1,5 +1,6 @@
 import os
 
+from PIL import Image
 import cv2
 import numpy as np
 import onnxruntime
@@ -22,15 +23,13 @@ class PaddleTableClsModel:
         self.mean = [0.485, 0.456, 0.406]
         self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]
 
-    def preprocess(self, img):
-        # PIL图像转cv2
-        img = np.array(img)
+    def preprocess(self, input_img):
         # 放大图片,使其最短边长为256
-        h, w = img.shape[:2]
+        h, w = input_img.shape[:2]
         scale = 256 / min(h, w)
         h_resize = round(h * scale)
         w_resize = round(w * scale)
-        img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
         # 调整为224*224的正方形
         h, w = img.shape[:2]
         cw, ch = 224, 224
@@ -61,8 +60,14 @@ class PaddleTableClsModel:
         x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
 
-    def predict(self, img):
-        x = self.preprocess(img)
+    def predict(self, input_img):
+        if isinstance(input_img, Image.Image):
+            np_img = np.asarray(input_img)
+        elif isinstance(input_img, np.ndarray):
+            np_img = input_img
+        else:
+            raise ValueError("Input must be a pillow object or a numpy array.")
+        x = self.preprocess(np_img)
         result = self.sess.run(None, {"x": x})
         idx = np.argmax(result)
         conf = float(np.max(result))
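The table-type classifier follows the same input-normalization pattern, and per its caller in batch_analyze.py above, predict returns a (label, confidence) pair where the label is AtomicModel.WiredTable or AtomicModel.WirelessTable. A hedged sketch; table_cls_model is assumed to be an initialized PaddleTableClsModel, and the absolute import path for AtomicModel is inferred from the relative import in batch_analyze.py:

import numpy as np
from PIL import Image
from mineru.backend.pipeline.model_list import AtomicModel

np_table_img = np.asarray(Image.open("table_crop.png"))  # illustrative input
table_label, table_cls_score = table_cls_model.predict(np_table_img)
if table_label == AtomicModel.WiredTable:
    pass  # route to the wired-table recognizer; otherwise use the wireless model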

+ 0 - 1
mineru/model/table/rec/unet_table/__init__.py

@@ -1 +0,0 @@
-# Copyright (c) Opendatalab. All rights reserved.

+ 0 - 256
mineru/model/table/rec/unet_table/table_line_rec_utils.py

@@ -1,256 +0,0 @@
-import math
-
-import cv2
-import numpy as np
-from scipy.spatial import distance as dist
-from skimage import measure
-
-
-def get_table_line(binimg, axis=0, lineW=10):
-    ##获取表格线
-    ##axis=0 横线
-    ##axis=1 竖线
-    labels = measure.label(binimg > 0, connectivity=2)  # 8连通区域标记
-    regions = measure.regionprops(labels)
-    if axis == 1:
-        lineboxes = [
-            min_area_rect(line.coords)
-            for line in regions
-            if line.bbox[2] - line.bbox[0] > lineW
-        ]
-    else:
-        lineboxes = [
-            min_area_rect(line.coords)
-            for line in regions
-            if line.bbox[3] - line.bbox[1] > lineW
-        ]
-    return lineboxes
-
-
-def min_area_rect(coords):
-    """
-    多边形外接矩形
-    """
-    rect = cv2.minAreaRect(coords[:, ::-1])
-    box = cv2.boxPoints(rect)
-    box = box.reshape((8,)).tolist()
-
-    box = image_location_sort_box(box)
-
-    x1, y1, x2, y2, x3, y3, x4, y4 = box
-    w, h = calculate_center_rotate_angle(box)
-    if w < h:
-        xmin = (x1 + x2) / 2
-        xmax = (x3 + x4) / 2
-        ymin = (y1 + y2) / 2
-        ymax = (y3 + y4) / 2
-
-    else:
-        xmin = (x1 + x4) / 2
-        xmax = (x2 + x3) / 2
-        ymin = (y1 + y4) / 2
-        ymax = (y2 + y3) / 2
-    return [xmin, ymin, xmax, ymax]
-
-
-def image_location_sort_box(box):
-    x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
-    pts = (x1, y1), (x2, y2), (x3, y3), (x4, y4)
-    pts = np.array(pts, dtype="float32")
-    (x1, y1), (x2, y2), (x3, y3), (x4, y4) = _order_points(pts)
-    return [x1, y1, x2, y2, x3, y3, x4, y4]
-
-
-def calculate_center_rotate_angle(box):
-
-    x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
-    w = (
-        np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
-        + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2)
-    ) / 2
-    h = (
-        np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
-        + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2)
-    ) / 2
-
-    return w, h
-
-
-def _order_points(pts):
-    # 根据x坐标对点进行排序
-    """
-    ---------------------
-    本项目中是为了排序后得到[(xmin,ymin),(xmax,ymin),(xmax,ymax),(xmin,ymax)]
-    作者:Tong_T
-    来源:CSDN
-    原文:https://blog.csdn.net/Tong_T/article/details/81907132
-    版权声明:本文为博主原创文章,转载请附上博文链接!
-    """
-    x_sorted = pts[np.argsort(pts[:, 0]), :]
-
-    left_most = x_sorted[:2, :]
-    right_most = x_sorted[2:, :]
-    left_most = left_most[np.argsort(left_most[:, 1]), :]
-    (tl, bl) = left_most
-
-    distance = dist.cdist(tl[np.newaxis], right_most, "euclidean")[0]
-    (br, tr) = right_most[np.argsort(distance)[::-1], :]
-
-    return np.array([tl, tr, br, bl], dtype="float32")
-
-
-def sqrt(p1, p2):
-    return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
-
-
-def adjust_lines(lines, alph=50, angle=50):
-    lines_n = len(lines)
-    new_lines = []
-    for i in range(lines_n):
-        x1, y1, x2, y2 = lines[i]
-        cx1, cy1 = (x1 + x2) / 2, (y1 + y2) / 2
-        for j in range(lines_n):
-            if i != j:
-                x3, y3, x4, y4 = lines[j]
-                cx2, cy2 = (x3 + x4) / 2, (y3 + y4) / 2
-                if (x3 < cx1 < x4 or y3 < cy1 < y4) or (
-                    x1 < cx2 < x2 or y1 < cy2 < y2
-                ):  # 判断两个横线在y方向的投影重不重合
-                    continue
-                else:
-                    r = sqrt((x1, y1), (x3, y3))
-                    k = abs((y3 - y1) / (x3 - x1 + 1e-10))
-                    a = math.atan(k) * 180 / math.pi
-                    if r < alph and a < angle:
-                        new_lines.append((x1, y1, x3, y3))
-
-                    r = sqrt((x1, y1), (x4, y4))
-                    k = abs((y4 - y1) / (x4 - x1 + 1e-10))
-                    a = math.atan(k) * 180 / math.pi
-                    if r < alph and a < angle:
-                        new_lines.append((x1, y1, x4, y4))
-
-                    r = sqrt((x2, y2), (x3, y3))
-                    k = abs((y3 - y2) / (x3 - x2 + 1e-10))
-                    a = math.atan(k) * 180 / math.pi
-                    if r < alph and a < angle:
-                        new_lines.append((x2, y2, x3, y3))
-                    r = sqrt((x2, y2), (x4, y4))
-                    k = abs((y4 - y2) / (x4 - x2 + 1e-10))
-                    a = math.atan(k) * 180 / math.pi
-                    if r < alph and a < angle:
-                        new_lines.append((x2, y2, x4, y4))
-    return new_lines
-
-
-def final_adjust_lines(rowboxes, colboxes):
-    nrow = len(rowboxes)
-    ncol = len(colboxes)
-    for i in range(nrow):
-        for j in range(ncol):
-            rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], alpha=20, angle=30)
-            colboxes[j] = line_to_line(colboxes[j], rowboxes[i], alpha=20, angle=30)
-    return rowboxes, colboxes
-
-
-def draw_lines(im, bboxes, color=(0, 0, 0), lineW=3):
-    """
-    boxes: bounding boxes
-    """
-    tmp = np.copy(im)
-    c = color
-    h, w = im.shape[:2]
-
-    for box in bboxes:
-        x1, y1, x2, y2 = box[:4]
-        cv2.line(
-            tmp, (int(x1), int(y1)), (int(x2), int(y2)), c, lineW, lineType=cv2.LINE_AA
-        )
-
-    return tmp
-
-
-def line_to_line(points1, points2, alpha=10, angle=30):
-    """
-    线段之间的距离
-    """
-    x1, y1, x2, y2 = points1
-    ox1, oy1, ox2, oy2 = points2
-    xy = np.array([(x1, y1), (x2, y2)], dtype="float32")
-    A1, B1, C1 = fit_line(xy)
-    oxy = np.array([(ox1, oy1), (ox2, oy2)], dtype="float32")
-    A2, B2, C2 = fit_line(oxy)
-    flag1 = point_line_cor(np.array([x1, y1], dtype="float32"), A2, B2, C2)
-    flag2 = point_line_cor(np.array([x2, y2], dtype="float32"), A2, B2, C2)
-
-    if (flag1 > 0 and flag2 > 0) or (flag1 < 0 and flag2 < 0):  # 横线或者竖线在竖线或者横线的同一侧
-        if (A1 * B2 - A2 * B1) != 0:
-            x = (B1 * C2 - B2 * C1) / (A1 * B2 - A2 * B1)
-            y = (A2 * C1 - A1 * C2) / (A1 * B2 - A2 * B1)
-            # x, y = round(x, 2), round(y, 2)
-            p = (x, y)  # 横线与竖线的交点
-            r0 = sqrt(p, (x1, y1))
-            r1 = sqrt(p, (x2, y2))
-
-            if min(r0, r1) < alpha:  # 若交点与线起点或者终点的距离小于alpha,则延长线到交点
-                if r0 < r1:
-                    k = abs((y2 - p[1]) / (x2 - p[0] + 1e-10))
-                    a = math.atan(k) * 180 / math.pi
-                    if a < angle or abs(90 - a) < angle:
-                        points1 = np.array([p[0], p[1], x2, y2], dtype="float32")
-                else:
-                    k = abs((y1 - p[1]) / (x1 - p[0] + 1e-10))
-                    a = math.atan(k) * 180 / math.pi
-                    if a < angle or abs(90 - a) < angle:
-                        points1 = np.array([x1, y1, p[0], p[1]], dtype="float32")
-    return points1
-
-
-def min_area_rect_box(
-    regions, flag=True, W=0, H=0, filtersmall=False, adjust_box=False
-):
-    """
-    多边形外接矩形
-    """
-    boxes = []
-    for region in regions:
-        if region.bbox_area > H * W * 3 / 4:  # 过滤大的单元格
-            continue
-        rect = cv2.minAreaRect(region.coords[:, ::-1])
-
-        box = cv2.boxPoints(rect)
-        box = box.reshape((8,)).tolist()
-        box = image_location_sort_box(box)
-        x1, y1, x2, y2, x3, y3, x4, y4 = box
-        w, h = calculate_center_rotate_angle(box)
-
-        if w * h < 0.5 * W * H:
-            if filtersmall and (
-                w < 15 or h < 15
-            ):  # or w / h > 30 or h / w > 30): # 过滤小的单元格
-                continue
-            boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
-    return boxes
-
-
-def point_line_cor(p, A, B, C):
-    ##判断点与线之间的位置关系
-    # 一般式直线方程(Ax+By+c)=0
-    x, y = p
-    r = A * x + B * y + C
-    return r
-
-
-def fit_line(p):
-    """A = Y2 - Y1
-       B = X1 - X2
-       C = X2*Y1 - X1*Y2
-       AX+BY+C=0
-    直线一般方程
-    """
-    x1, y1 = p[0]
-    x2, y2 = p[1]
-    A = y2 - y1
-    B = x1 - x2
-    C = x2 * y1 - x1 * y2
-    return A, B, C

+ 0 - 213
mineru/model/table/rec/unet_table/table_recover.py

@@ -1,213 +0,0 @@
-from typing import Dict, List, Tuple
-import numpy as np
-
-
-class TableRecover:
-    def __init__(
-        self,
-    ):
-        pass
-
-    def __call__(
-        self, polygons: np.ndarray, rows_thresh=10, col_thresh=15
-    ) -> Dict[int, Dict]:
-        rows = self.get_rows(polygons, rows_thresh)
-        longest_col, each_col_widths, col_nums = self.get_benchmark_cols(
-            rows, polygons, col_thresh
-        )
-        each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)
-        table_res, logic_points_dict = self.get_merge_cells(
-            polygons,
-            rows,
-            row_nums,
-            col_nums,
-            longest_col,
-            each_col_widths,
-            each_row_heights,
-        )
-        logic_points = np.array(
-            [logic_points_dict[i] for i in range(len(polygons))]
-        ).astype(np.int32)
-        return table_res, logic_points
-
-    @staticmethod
-    def get_rows(polygons: np.array, rows_thresh=10) -> Dict[int, List[int]]:
-        """对每个框进行行分类,框定哪个是一行的"""
-        y_axis = polygons[:, 0, 1]
-        if y_axis.size == 1:
-            return {0: [0]}
-
-        concat_y = np.array(list(zip(y_axis, y_axis[1:])))
-        minus_res = concat_y[:, 1] - concat_y[:, 0]
-
-        result = {}
-        split_idxs = np.argwhere(abs(minus_res) > rows_thresh).squeeze()
-        # 如果都在一行,则将所有下标设置为同一行
-        if split_idxs.size == 0:
-            return {0: [i for i in range(len(y_axis))]}
-        if split_idxs.ndim == 0:
-            split_idxs = split_idxs[None, ...]
-
-        if max(split_idxs) != len(minus_res):
-            split_idxs = np.append(split_idxs, len(minus_res))
-
-        start_idx = 0
-        for row_num, idx in enumerate(split_idxs):
-            if row_num != 0:
-                start_idx = split_idxs[row_num - 1] + 1
-            result.setdefault(row_num, []).extend(range(start_idx, idx + 1))
-
-        # 计算每一行相邻cell的iou,如果大于0.2,则合并为同一个cell
-        return result
-
-    def get_benchmark_cols(
-        self, rows: Dict[int, List], polygons: np.ndarray, col_thresh=15
-    ) -> Tuple[np.ndarray, List[float], int]:
-        longest_col = max(rows.values(), key=lambda x: len(x))
-        longest_col_points = polygons[longest_col]
-        longest_x_start = list(longest_col_points[:, 0, 0])
-        longest_x_end = list(longest_col_points[:, 2, 0])
-        min_x = longest_x_start[0]
-        max_x = longest_x_end[-1]
-
-        # 根据当前col的起始x坐标,更新col的边界
-        # 2025.2.22 --- 解决最长列可能漏掉最后一列的问题
-        def update_longest_col(col_x_list, cur_v, min_x_, max_x_, insert_last):
-            for i, v in enumerate(col_x_list):
-                if cur_v - col_thresh <= v <= cur_v + col_thresh:
-                    break
-                if cur_v < min_x_:
-                    col_x_list.insert(0, cur_v)
-                    min_x_ = cur_v
-                    break
-                if cur_v > max_x_:
-                    if insert_last:
-                        col_x_list.append(cur_v)
-                    max_x_ = cur_v
-                    break
-                if cur_v < v:
-                    col_x_list.insert(i, cur_v)
-                    break
-            return min_x_, max_x_
-
-        for row_value in rows.values():
-            cur_row_start = list(polygons[row_value][:, 0, 0])
-            cur_row_end = list(polygons[row_value][:, 2, 0])
-            for idx, (cur_v_start, cur_v_end) in enumerate(
-                zip(cur_row_start, cur_row_end)
-            ):
-                min_x, max_x = update_longest_col(
-                    longest_x_start, cur_v_start, min_x, max_x, True
-                )
-                min_x, max_x = update_longest_col(
-                    longest_x_start, cur_v_end, min_x, max_x, False
-                )
-
-        longest_x_start = np.array(longest_x_start)
-        each_col_widths = (longest_x_start[1:] - longest_x_start[:-1]).tolist()
-        each_col_widths.append(max_x - longest_x_start[-1])
-        col_nums = longest_x_start.shape[0]
-        return longest_x_start, each_col_widths, col_nums
-
-    def get_benchmark_rows(
-        self, rows: Dict[int, List], polygons: np.ndarray
-    ) -> Tuple[np.ndarray, List[float], int]:
-        leftmost_cell_idxs = [v[0] for v in rows.values()]
-        benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1]
-
-        each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist()
-
-        # 求出最后一行cell中,最大的高度作为最后一行的高度
-        bottommost_idxs = list(rows.values())[-1]
-        bottommost_boxes = polygons[bottommost_idxs]
-        # fix self.compute_L2(v[3, :], v[0, :]), v为逆时针,即v[3]为右上,v[0]为左上,v[1]为左下
-        max_height = max([self.compute_L2(v[1, :], v[0, :]) for v in bottommost_boxes])
-        each_row_widths.append(max_height)
-
-        row_nums = benchmark_x.shape[0]
-        return each_row_widths, row_nums
-
-    @staticmethod
-    def compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:
-        return np.linalg.norm(a2 - a1)
-
-    def get_merge_cells(
-        self,
-        polygons: np.ndarray,
-        rows: Dict,
-        row_nums: int,
-        col_nums: int,
-        longest_col: np.ndarray,
-        each_col_widths: List[float],
-        each_row_heights: List[float],
-    ) -> Dict[int, Dict[int, int]]:
-        col_res_merge, row_res_merge = {}, {}
-        logic_points = {}
-        merge_thresh = 10
-        for cur_row, col_list in rows.items():
-            one_col_result, one_row_result = {}, {}
-            for one_col in col_list:
-                box = polygons[one_col]
-                box_width = self.compute_L2(box[3, :], box[0, :])
-
-                # 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置
-                loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0]))
-                col_start = max(sum(one_col_result.values()), loc_col_idx)
-
-                # 计算合并多少个列方向单元格
-                for i in range(col_start, col_nums):
-                    col_cum_sum = sum(each_col_widths[col_start : i + 1])
-                    if i == col_start and col_cum_sum > box_width:
-                        one_col_result[one_col] = 1
-                        break
-                    elif abs(col_cum_sum - box_width) <= merge_thresh:
-                        one_col_result[one_col] = i + 1 - col_start
-                        break
-                    # 这里必须进行修正,不然会出现超越阈值范围后列交错
-                    elif col_cum_sum > box_width:
-                        idx = (
-                            i
-                            if abs(col_cum_sum - box_width)
-                            < abs(col_cum_sum - each_col_widths[i] - box_width)
-                            else i - 1
-                        )
-                        one_col_result[one_col] = idx + 1 - col_start
-                        break
-                else:
-                    one_col_result[one_col] = col_nums - col_start
-                col_end = one_col_result[one_col] + col_start - 1
-                box_height = self.compute_L2(box[1, :], box[0, :])
-                row_start = cur_row
-                for j in range(row_start, row_nums):
-                    row_cum_sum = sum(each_row_heights[row_start : j + 1])
-                    # box_height 不确定是几行的高度,所以要逐个试验,找一个最近的几行的高
-                    # 如果第一次row_cum_sum就比box_height大,那么意味着?丢失了一行
-                    if j == row_start and row_cum_sum > box_height:
-                        one_row_result[one_col] = 1
-                        break
-                    elif abs(box_height - row_cum_sum) <= merge_thresh:
-                        one_row_result[one_col] = j + 1 - row_start
-                        break
-                    # 这里必须进行修正,不然会出现超越阈值范围后行交错
-                    elif row_cum_sum > box_height:
-                        idx = (
-                            j
-                            if abs(row_cum_sum - box_height)
-                            < abs(row_cum_sum - each_row_heights[j] - box_height)
-                            else j - 1
-                        )
-                        one_row_result[one_col] = idx + 1 - row_start
-                        break
-                else:
-                    one_row_result[one_col] = row_nums - row_start
-                row_end = one_row_result[one_col] + row_start - 1
-                logic_points[one_col] = np.array(
-                    [row_start, row_end, col_start, col_end]
-                )
-            col_res_merge[cur_row] = one_col_result
-            row_res_merge[cur_row] = one_row_result
-
-        res = {}
-        for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())):
-            res[i] = {k: [cc, r[k]] for k, cc in c.items()}
-        return res, logic_points

+ 0 - 420
mineru/model/table/rec/unet_table/table_recover_utils.py

@@ -1,420 +0,0 @@
-from typing import Any, Dict, List, Union, Tuple
-
-import numpy as np
-import shapely
-from shapely.geometry import MultiPoint, Polygon
-
-
-def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray:
-    """
-    Sort text boxes in order from top to bottom, left to right
-    args:
-        dt_boxes(array):detected text boxes with shape (N, 4, 2)
-    return:
-        sorted boxes(array) with shape (N, 4, 2)
-    """
-    num_boxes = dt_boxes.shape[0]
-    dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
-    _boxes = list(dt_boxes)
-
-    # 解决相邻框,后边比前面y轴小,则会被排到前面去的问题
-    for i in range(num_boxes - 1):
-        for j in range(i, -1, -1):
-            if (
-                abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
-                and _boxes[j + 1][0][0] < _boxes[j][0][0]
-            ):
-                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
-            else:
-                break
-    return np.array(_boxes)
-
-
-def calculate_iou(
-    box1: Union[np.ndarray, List], box2: Union[np.ndarray, List]
-) -> float:
-    """
-    :param box1: Iterable [xmin,ymin,xmax,ymax]
-    :param box2: Iterable [xmin,ymin,xmax,ymax]
-    :return: iou: float 0-1
-    """
-    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
-    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
-    # 不相交直接退出检测
-    if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
-        return 0.0
-    # 计算交集
-    inter_x1 = max(b1_x1, b2_x1)
-    inter_y1 = max(b1_y1, b2_y1)
-    inter_x2 = min(b1_x2, b2_x2)
-    inter_y2 = min(b1_y2, b2_y2)
-    i_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
-
-    # 计算并集
-    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
-    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
-    u_area = b1_area + b2_area - i_area
-
-    # 避免除零错误,如果区域小到乘积为0,认为是错误识别,直接去掉
-    if u_area == 0:
-        return 1
-        # 检查完全包含
-    iou = i_area / u_area
-    return iou
-
-
-def is_box_contained(
-    box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2
-) -> Union[int, None]:
-    """
-    :param box1: Iterable [xmin,ymin,xmax,ymax]
-    :param box2: Iterable [xmin,ymin,xmax,ymax]
-    :return: 1: box1 is contained 2: box2 is contained None: no contain these
-    """
-    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
-    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
-    # 不相交直接退出检测
-    if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
-        return None
-    # 计算box2的总面积
-    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
-    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
-
-    # 计算box1和box2的交集
-    intersect_x1 = max(b1_x1, b2_x1)
-    intersect_y1 = max(b1_y1, b2_y1)
-    intersect_x2 = min(b1_x2, b2_x2)
-    intersect_y2 = min(b1_y2, b2_y2)
-
-    # 计算交集的面积
-    intersect_area = max(0, intersect_x2 - intersect_x1) * max(
-        0, intersect_y2 - intersect_y1
-    )
-
-    # 计算外面的面积
-    b1_outside_area = b1_area - intersect_area
-    b2_outside_area = b2_area - intersect_area
-
-    # 计算外面的面积占box2总面积的比例
-    ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
-    ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
-
-    if ratio_b1 < threshold:
-        return 1
-    if ratio_b2 < threshold:
-        return 2
-    # 判断比例是否大于阈值
-    return None
-
-
-def is_single_axis_contained(
-    box1: Union[np.ndarray, List],
-    box2: Union[np.ndarray, List],
-    axis="x",
-    threshold: float = 0.2,
-) -> Union[int, None]:
-    """
-    :param box1: Iterable [xmin,ymin,xmax,ymax]
-    :param box2: Iterable [xmin,ymin,xmax,ymax]
-    :return: 1: box1 is contained 2: box2 is contained None: no contain these
-    """
-    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
-    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
-
-    # 计算轴重叠大小
-    if axis == "x":
-        b1_area = b1_x2 - b1_x1
-        b2_area = b2_x2 - b2_x1
-        i_area = min(b1_x2, b2_x2) - max(b1_x1, b2_x1)
-    else:
-        b1_area = b1_y2 - b1_y1
-        b2_area = b2_y2 - b2_y1
-        i_area = min(b1_y2, b2_y2) - max(b1_y1, b2_y1)
-        # 计算外面的面积
-    b1_outside_area = b1_area - i_area
-    b2_outside_area = b2_area - i_area
-
-    ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
-    ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
-    if ratio_b1 < threshold:
-        return 1
-    if ratio_b2 < threshold:
-        return 2
-    return None
-
-
-def sorted_ocr_boxes(
-    dt_boxes: Union[np.ndarray, list], threshold: float = 0.2
-) -> Tuple[Union[np.ndarray, list], List[int]]:
-    """
-    Sort text boxes in order from top to bottom, left to right
-    args:
-        dt_boxes(array):detected text boxes with (xmin, ymin, xmax, ymax)
-    return:
-        sorted boxes(array) with (xmin, ymin, xmax, ymax)
-    """
-    num_boxes = len(dt_boxes)
-    if num_boxes <= 0:
-        return dt_boxes, []
-    indexed_boxes = [(box, idx) for idx, box in enumerate(dt_boxes)]
-    sorted_boxes_with_idx = sorted(indexed_boxes, key=lambda x: (x[0][1], x[0][0]))
-    _boxes, indices = zip(*sorted_boxes_with_idx)
-    indices = list(indices)
-    _boxes = [dt_boxes[i] for i in indices]
-    # 避免输出和输入格式不对应,与函数功能不符合
-    if isinstance(dt_boxes, np.ndarray):
-        _boxes = np.array(_boxes)
-    for i in range(num_boxes - 1):
-        for j in range(i, -1, -1):
-            c_idx = is_single_axis_contained(
-                _boxes[j], _boxes[j + 1], axis="y", threshold=threshold
-            )
-            if (
-                c_idx is not None
-                and _boxes[j + 1][0] < _boxes[j][0]
-                and abs(_boxes[j][1] - _boxes[j + 1][1]) < 20
-            ):
-                _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy()
-                indices[j], indices[j + 1] = indices[j + 1], indices[j]
-            else:
-                break
-    return _boxes, indices
-
-
-def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]:
-    """
-    将poly_box转换为box_4_1
-    :param poly_box:
-    :return:
-    """
-    return [poly_box[0][0], poly_box[0][1], poly_box[2][0], poly_box[2][1]]
-
-
-def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.ndarray):
-    """
-    :param dt_rec_boxes: [[(4.2), text, score]]
-    :param pred_bboxes: shap (4,2)
-    :return:
-    """
-    matched = {}
-    not_match_orc_boxes = []
-    for i, gt_box in enumerate(dt_rec_boxes):
-        for j, pred_box in enumerate(pred_bboxes):
-            pred_box = [pred_box[0][0], pred_box[0][1], pred_box[2][0], pred_box[2][1]]
-            ocr_boxes = gt_box[0]
-            # xmin,ymin,xmax,ymax
-            ocr_box = (
-                ocr_boxes[0][0],
-                ocr_boxes[0][1],
-                ocr_boxes[2][0],
-                ocr_boxes[2][1],
-            )
-            contained = is_box_contained(ocr_box, pred_box, 0.6)
-            if contained == 1 or calculate_iou(ocr_box, pred_box) > 0.8:
-                if j not in matched:
-                    matched[j] = [gt_box]
-                else:
-                    matched[j].append(gt_box)
-            else:
-                not_match_orc_boxes.append(gt_box)
-
-    return matched, not_match_orc_boxes
-
-
-def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[Any]:
-    """
-        Groups OCR results by row based on the vertical (y-axis) overlap of their bounding boxes.
-    Args:
-        ocr_list (List[Any]): A list of OCR results, where each item is a list containing a bounding box
-            in the format [xmin, ymin, xmax, ymax] and the recognized text.
-        threshold (float, optional): The threshold for determining if two boxes are in the same row,
-            based on their y-axis overlap. Default is 0.2.
-    Returns:
-        List[Any]: A new list of OCR results where texts in the same row are merged, and their bounding
-            boxes are updated to encompass the merged text.
-    """
-    for i in range(len(ocr_list)):
-        if not ocr_list[i]:
-            continue
-
-        for j in range(i + 1, len(ocr_list)):
-            if not ocr_list[j]:
-                continue
-            cur = ocr_list[i]
-            next = ocr_list[j]
-            cur_box = cur[0]
-            next_box = next[0]
-            c_idx = is_single_axis_contained(
-                cur[0], next[0], axis="y", threshold=threshold
-            )
-            if c_idx:
-                dis = max(next_box[0] - cur_box[2], 0)
-                blank_str = int(dis / 10) * " "
-                cur[1] = cur[1] + blank_str + next[1]
-                xmin = min(cur_box[0], next_box[0])
-                xmax = max(cur_box[2], next_box[2])
-                ymin = min(cur_box[1], next_box[1])
-                ymax = max(cur_box[3], next_box[3])
-                cur_box[0] = xmin
-                cur_box[1] = ymin
-                cur_box[2] = xmax
-                cur_box[3] = ymax
-                ocr_list[j] = None
-    ocr_list = [x for x in ocr_list if x]
-    return ocr_list
-
-
-def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float:
-    """计算两个多边形的IOU
-
-    Args:
-        poly1 (np.ndarray): (4, 2)
-        poly2 (np.ndarray): (4, 2)
-
-    Returns:
-        float: iou
-    """
-    poly1 = Polygon(a).convex_hull
-    poly2 = Polygon(b).convex_hull
-
-    union_poly = np.concatenate((a, b))
-
-    if not poly1.intersects(poly2):
-        return 0.0
-
-    try:
-        inter_area = poly1.intersection(poly2).area
-        union_area = MultiPoint(union_poly).convex_hull.area
-    except shapely.geos.TopologicalError:
-        print("shapely.geos.TopologicalError occured, iou set to 0")
-        return 0.0
-
-    if union_area == 0:
-        return 0.0
-
-    return float(inter_area) / union_area
-
-
-def merge_adjacent_polys(polygons: np.ndarray) -> np.ndarray:
-    """合并相邻iou大于阈值的框"""
-    combine_iou_thresh = 0.1
-    pair_polygons = list(zip(polygons, polygons[1:, ...]))
-    pair_ious = np.array([compute_poly_iou(p1, p2) for p1, p2 in pair_polygons])
-    idxs = np.argwhere(pair_ious >= combine_iou_thresh)
-
-    if idxs.size <= 0:
-        return polygons
-
-    polygons = combine_two_poly(polygons, idxs)
-
-    # 注意:递归调用
-    polygons = merge_adjacent_polys(polygons)
-    return polygons
-
-
-def combine_two_poly(polygons: np.ndarray, idxs: np.ndarray) -> np.ndarray:
-    del_idxs, insert_boxes = [], []
-    idxs = idxs.squeeze(-1)
-    for idx in idxs:
-        # idx 和 idx + 1 是重合度过高的
-        # 合并,取两者各个点的最大值
-        new_poly = []
-        pre_poly, pos_poly = polygons[idx], polygons[idx + 1]
-
-        # 四个点,每个点逐一比较
-        new_poly.append(np.minimum(pre_poly[0], pos_poly[0]))
-
-        x_2 = min(pre_poly[1][0], pos_poly[1][0])
-        y_2 = max(pre_poly[1][1], pos_poly[1][1])
-        new_poly.append([x_2, y_2])
-
-        # 第3个点
-        new_poly.append(np.maximum(pre_poly[2], pos_poly[2]))
-
-        # 第4个点
-        x_4 = max(pre_poly[3][0], pos_poly[3][0])
-        y_4 = min(pre_poly[3][1], pos_poly[3][1])
-        new_poly.append([x_4, y_4])
-
-        new_poly = np.array(new_poly)
-
-        # 删除已经合并的两个框,插入新的框
-        del_idxs.extend([idx, idx + 1])
-        insert_boxes.append(new_poly)
-
-    # 整合合并后的框
-    polygons = np.delete(polygons, del_idxs, axis=0)
-
-    insert_boxes = np.array(insert_boxes)
-    polygons = np.append(polygons, insert_boxes, axis=0)
-    polygons = sorted_boxes(polygons)
-    return polygons
-
-
-def plot_html_table(
-    logi_points: Union[Union[np.ndarray, List]], cell_box_map: Dict[int, List[str]]
-) -> str:
-    # 初始化最大行数和列数
-    max_row = 0
-    max_col = 0
-    # 计算最大行数和列数
-    for point in logi_points:
-        max_row = max(max_row, point[1] + 1)  # 加1是因为结束下标是包含在内的
-        max_col = max(max_col, point[3] + 1)  # 加1是因为结束下标是包含在内的
-
-    # 创建一个二维数组来存储 sorted_logi_points 中的元素
-    grid = [[None] * max_col for _ in range(max_row)]
-
-    valid_start_row = (1 << 16) - 1
-    valid_start_col = (1 << 16) - 1
-    valid_end_col = 0
-    # 将 sorted_logi_points 中的元素填充到 grid 中
-    for i, logic_point in enumerate(logi_points):
-        row_start, row_end, col_start, col_end = (
-            logic_point[0],
-            logic_point[1],
-            logic_point[2],
-            logic_point[3],
-        )
-        ocr_rec_text_list = cell_box_map.get(i)
-        if ocr_rec_text_list and "".join(ocr_rec_text_list):
-            valid_start_row = min(row_start, valid_start_row)
-            valid_start_col = min(col_start, valid_start_col)
-            valid_end_col = max(col_end, valid_end_col)
-        for row in range(row_start, row_end + 1):
-            for col in range(col_start, col_end + 1):
-                grid[row][col] = (i, row_start, row_end, col_start, col_end)
-
-    # 创建表格
-    table_html = "<html><body><table>"
-
-    # 遍历每行
-    for row in range(max_row):
-        if row < valid_start_row:
-            continue
-        temp = "<tr>"
-        # 遍历每一列
-        for col in range(max_col):
-            if col < valid_start_col or col > valid_end_col:
-                continue
-            if not grid[row][col]:
-                temp += "<td></td>"
-            else:
-                i, row_start, row_end, col_start, col_end = grid[row][col]
-                if not cell_box_map.get(i):
-                    continue
-                if row == row_start and col == col_start:
-                    ocr_rec_text = cell_box_map.get(i)
-                    text = "<br>".join(ocr_rec_text)
-                    # 如果是起始单元格
-                    row_span = row_end - row_start + 1
-                    col_span = col_end - col_start + 1
-                    cell_content = (
-                        f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
-                    )
-                    temp += cell_content
-
-        table_html = table_html + temp + "</tr>"
-
-    table_html += "</table></body></html>"
-    return table_html

+ 0 - 149
mineru/model/table/rec/unet_table/table_structure_unet.py

@@ -1,149 +0,0 @@
-import copy
-import math
-from typing import Optional, Dict, Any, Tuple
-
-import cv2
-import numpy as np
-from skimage import measure
-from .wired_table_rec_utils import OrtInferSession, resize_img
-from .table_line_rec_utils import (
-    get_table_line,
-    final_adjust_lines,
-    min_area_rect_box,
-    draw_lines,
-    adjust_lines,
-)
-from .table_recover_utils import (
-    sorted_ocr_boxes,
-    box_4_2_poly_to_box_4_1,
-)
-
-
-class TSRUnet:
-    def __init__(self, config: Dict):
-        self.K = 1000
-        self.MK = 4000
-        self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
-        self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
-        self.inp_height = 1024
-        self.inp_width = 1024
-
-        self.session = OrtInferSession(config)
-
-    def __call__(
-        self, img: np.ndarray, **kwargs
-    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
-        img_info = self.preprocess(img)
-        pred = self.infer(img_info)
-        polygons, rotated_polygons = self.postprocess(img, pred, **kwargs)
-        if polygons.size == 0:
-            return None, None
-        polygons = polygons.reshape(polygons.shape[0], 4, 2)
-        polygons[:, 3, :], polygons[:, 1, :] = (
-            polygons[:, 1, :].copy(),
-            polygons[:, 3, :].copy(),
-        )
-        rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
-        rotated_polygons[:, 3, :], rotated_polygons[:, 1, :] = (
-            rotated_polygons[:, 1, :].copy(),
-            rotated_polygons[:, 3, :].copy(),
-        )
-        _, idx = sorted_ocr_boxes(
-            [box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
-            threshold=0.4,
-        )
-        polygons = polygons[idx]
-        rotated_polygons = rotated_polygons[idx]
-        return polygons, rotated_polygons
-
-    def preprocess(self, img) -> Dict[str, Any]:
-        scale = (self.inp_height, self.inp_width)
-        img, _, _ = resize_img(img, scale, True)
-        img = img.copy().astype(np.float32)
-        assert img.dtype != np.uint8
-        mean = np.float64(self.mean.reshape(1, -1))
-        stdinv = 1 / np.float64(self.std.reshape(1, -1))
-        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
-        cv2.subtract(img, mean, img)  # inplace
-        cv2.multiply(img, stdinv, img)  # inplace
-        img = img.transpose(2, 0, 1)
-        images = img[None, :]
-        return {"img": images}
-
-    def infer(self, input):
-        result = self.session(input["img"][None, ...])[0][0]
-        result = result[0].astype(np.uint8)
-        return result
-
-    def postprocess(self, img, pred, **kwargs):
-        row = kwargs.get("row", 50) if kwargs else 50
-        col = kwargs.get("col", 30) if kwargs else 30
-        h_lines_threshold = kwargs.get("h_lines_threshold", 100) if kwargs else 100
-        v_lines_threshold = kwargs.get("v_lines_threshold", 15) if kwargs else 15
-        angle = kwargs.get("angle", 50) if kwargs else 50
-        enhance_box_line = kwargs.get("enhance_box_line", True) if kwargs else True
-        morph_close = (
-            kwargs.get("morph_close", enhance_box_line) if kwargs else enhance_box_line
-        )  # 是否进行闭合运算以找到更多小的框
-        more_h_lines = (
-            kwargs.get("more_h_lines", enhance_box_line) if kwargs else enhance_box_line
-        )  # 是否调整以找到更多的横线
-        more_v_lines = (
-            kwargs.get("more_v_lines", enhance_box_line) if kwargs else enhance_box_line
-        )  # 是否调整以找到更多的横线
-        extend_line = (
-            kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
-        )  # 是否进行线段延长使得端点连接
-
-        ori_shape = img.shape
-        pred = np.uint8(pred)
-        hpred = copy.deepcopy(pred)  # 横线
-        vpred = copy.deepcopy(pred)  # 竖线
-        whereh = np.where(hpred == 1)
-        wherev = np.where(vpred == 2)
-        hpred[wherev] = 0
-        vpred[whereh] = 0
-
-        hpred = cv2.resize(hpred, (ori_shape[1], ori_shape[0]))
-        vpred = cv2.resize(vpred, (ori_shape[1], ori_shape[0]))
-
-        h, w = pred.shape
-        hors_k = int(math.sqrt(w) * 1.2)
-        vert_k = int(math.sqrt(h) * 1.2)
-        hkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (hors_k, 1))
-        vkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_k))
-        vpred = cv2.morphologyEx(
-            vpred, cv2.MORPH_CLOSE, vkernel, iterations=1
-        )  # 先膨胀后腐蚀的过程
-        if morph_close:
-            hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1)
-        colboxes = get_table_line(vpred, axis=1, lineW=col)  # 竖线
-        rowboxes = get_table_line(hpred, axis=0, lineW=row)  # 横线
-        rboxes_row_, rboxes_col_ = [], []
-        if more_h_lines:
-            rboxes_row_ = adjust_lines(rowboxes, alph=h_lines_threshold, angle=angle)
-        if more_v_lines:
-            rboxes_col_ = adjust_lines(colboxes, alph=v_lines_threshold, angle=angle)
-        rowboxes += rboxes_row_
-        colboxes += rboxes_col_
-        if extend_line:
-            rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
-        line_img = np.zeros(img.shape[:2], dtype="uint8")
-        line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
-
-        polygons = self.cal_region_boxes(line_img)
-        rotated_polygons = polygons.copy()
-        return polygons, rotated_polygons
-
-    def cal_region_boxes(self, tmp):
-        labels = measure.label(tmp < 255, connectivity=2)  # 8连通区域标记
-        regions = measure.regionprops(labels)
-        ceilboxes = min_area_rect_box(
-            regions,
-            False,
-            tmp.shape[1],
-            tmp.shape[0],
-            filtersmall=True,
-            adjust_box=False,
-        )  # 最后一个参数改为False
-        return np.array(ceilboxes)

+ 0 - 417
mineru/model/table/rec/unet_table/unet_table.py

@@ -1,417 +0,0 @@
-import html
-import os
-import time
-import traceback
-from dataclasses import dataclass, asdict
-from pathlib import Path
-from typing import List, Optional, Union, Dict, Any
-
-import cv2
-import numpy as np
-from loguru import logger
-from rapid_table import RapidTableInput, RapidTable
-
-from mineru.utils.enum_class import ModelPath
-from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-from .table_structure_unet import TSRUnet
-from .table_recover import TableRecover
-from .wired_table_rec_utils import InputType, LoadImage
-from .table_recover_utils import (
-    match_ocr_cell,
-    plot_html_table,
-    box_4_2_poly_to_box_4_1,
-    sorted_ocr_boxes,
-    gather_ocr_list_by_row,
-)
-
-
-@dataclass
-class UnetTableInput:
-    model_path: str
-    device: str = "cpu"
-
-
-@dataclass
-class UnetTableOutput:
-    pred_html: Optional[str] = None
-    cell_bboxes: Optional[np.ndarray] = None
-    logic_points: Optional[np.ndarray] = None
-    elapse: Optional[float] = None
-
-
-class UnetTableRecognition:
-    def __init__(self, config: UnetTableInput):
-        self.table_structure = TSRUnet(asdict(config))
-        self.load_img = LoadImage()
-        self.table_recover = TableRecover()
-
-    def __call__(
-        self,
-        img: InputType,
-        ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
-        ocr_engine = None,
-        **kwargs,
-    ) -> UnetTableOutput:
-        s = time.perf_counter()
-        need_ocr = True
-        col_threshold = 15
-        row_threshold = 10
-        if kwargs:
-            need_ocr = kwargs.get("need_ocr", True)
-            col_threshold = kwargs.get("col_threshold", 15)
-            row_threshold = kwargs.get("row_threshold", 10)
-        img = self.load_img(img)
-        polygons, rotated_polygons = self.table_structure(img, **kwargs)
-        if polygons is None:
-            # logger.warning("polygons is None.")
-            return UnetTableOutput("", None, None, 0.0)
-
-        try:
-            table_res, logi_points = self.table_recover(
-                rotated_polygons, row_threshold, col_threshold
-            )
-            # 将坐标由逆时针转为顺时针方向,后续处理与无线表格对齐
-            polygons[:, 1, :], polygons[:, 3, :] = (
-                polygons[:, 3, :].copy(),
-                polygons[:, 1, :].copy(),
-            )
-            if not need_ocr:
-                sorted_polygons, idx_list = sorted_ocr_boxes(
-                    [box_4_2_poly_to_box_4_1(box) for box in polygons]
-                )
-                return UnetTableOutput(
-                    "",
-                    sorted_polygons,
-                    logi_points[idx_list],
-                    time.perf_counter() - s,
-                )
-            cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
-            # 如果有识别框没有ocr结果,直接进行rec补充
-            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map, ocr_engine)
-            # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
-            t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
-            # 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
-            t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
-            # cell_box_map =
-            logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
-            cell_box_det_map = {
-                i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
-                for i, t_box_ocr in enumerate(t_rec_ocr_list)
-            }
-            pred_html = plot_html_table(logi_points, cell_box_det_map)
-            polygons = np.array(polygons).reshape(-1, 8)
-            logi_points = np.array(logi_points)
-            elapse = time.perf_counter() - s
-
-        except Exception:
-            logger.warning(traceback.format_exc())
-            return UnetTableOutput("", None, None, 0.0)
-        return UnetTableOutput(pred_html, polygons, logi_points, elapse)
-
-    def transform_res(
-        self,
-        cell_box_det_map: Dict[int, List[any]],
-        polygons: np.ndarray,
-        logi_points: List[np.ndarray],
-    ) -> List[Dict[str, any]]:
-        res = []
-        for i in range(len(polygons)):
-            ocr_res_list = cell_box_det_map.get(i)
-            if not ocr_res_list:
-                continue
-            xmin = min([ocr_box[0][0][0] for ocr_box in ocr_res_list])
-            ymin = min([ocr_box[0][0][1] for ocr_box in ocr_res_list])
-            xmax = max([ocr_box[0][2][0] for ocr_box in ocr_res_list])
-            ymax = max([ocr_box[0][2][1] for ocr_box in ocr_res_list])
-            dict_res = {
-                # xmin,xmax,ymin,ymax
-                "t_box": [xmin, ymin, xmax, ymax],
-                # row_start,row_end,col_start,col_end
-                "t_logic_box": logi_points[i].tolist(),
-                # [[xmin,xmax,ymin,ymax], text]
-                "t_ocr_res": [
-                    [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
-                    for ocr_det in ocr_res_list
-                ],
-            }
-            res.append(dict_res)
-        return res
-
-    def sort_and_gather_ocr_res(self, res):
-        for i, dict_res in enumerate(res):
-            _, sorted_idx = sorted_ocr_boxes(
-                [ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threshold=0.3
-            )
-            dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
-            dict_res["t_ocr_res"] = gather_ocr_list_by_row(
-                dict_res["t_ocr_res"], threshold=0.3
-            )
-        return res
-
-    def fill_blank_rec(
-        self,
-        img: np.ndarray,
-        sorted_polygons: np.ndarray,
-        cell_box_map: Dict[int, List[str]],
-        ocr_engine
-    ) -> Dict[int, List[Any]]:
-        """找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
-        img_crop_info_list = []
-        img_crop_list = []
-        for i in range(sorted_polygons.shape[0]):
-            if cell_box_map.get(i):
-                continue
-            box = sorted_polygons[i]
-            if ocr_engine is None:
-                logger.warning(f"No OCR engine provided for box {i}: {box}")
-                continue
-            # 从img中截取对应的区域
-            x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
-            if x1 >= x2 or y1 >= y2:
-                logger.warning(f"Invalid box coordinates: {box}")
-                continue
-            img_crop = img[int(y1):int(y2), int(x1):int(x2)]
-            img_crop_list.append(img_crop)
-            img_crop_info_list.append([i, box])
-
-        if len(img_crop_list) > 0:
-            # 进行ocr识别
-            ocr_result = ocr_engine.ocr(img_crop_list, det=False)
-            if not ocr_result or not isinstance(ocr_result, list) or len(ocr_result) == 0:
-                logger.warning("OCR engine returned no results or invalid result for image crops.")
-                return cell_box_map
-            ocr_res_list = ocr_result[0]
-            if not isinstance(ocr_res_list, list) or len(ocr_res_list) != len(img_crop_list):
-                logger.warning("OCR result list length does not match image crop list length.")
-                return cell_box_map
-            for j, ocr_res in enumerate(ocr_res_list):
-                img_crop_info_list[j].append(ocr_res)
-
-            for i, box, ocr_res in img_crop_info_list:
-                # 处理ocr结果
-                ocr_text, ocr_score = ocr_res
-                # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
-                if ocr_score < 0.9:
-                    # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
-                    box = sorted_polygons[i]
-                    cell_box_map[i] = [[box, "", 0.5]]
-                    continue
-                cell_box_map[i] = [[box, ocr_text, ocr_score]]
-
-        return cell_box_map
-
-
-def escape_html(input_string):
-    """Escape HTML Entities."""
-    return html.escape(input_string)
-
-
-class UnetTableModel:
-    def __init__(self, ocr_engine):
-        model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
-        wired_input_args = UnetTableInput(model_path=model_path)
-        self.wired_table_model = UnetTableRecognition(wired_input_args)
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.wireless_table_model = RapidTable(wireless_input_args)
-        self.ocr_engine = ocr_engine
-
-    def predict(self, img, table_cls_score):
-        bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        ocr_result = self.ocr_engine.ocr(bgr_img)[0]
-
-        if ocr_result:
-            ocr_result = [
-                [item[0], escape_html(item[1][0]), item[1][1]]
-                for item in ocr_result
-                if len(item) == 2 and isinstance(item[1], tuple)
-            ]
-        else:
-            ocr_result = None
-        if ocr_result:
-            try:
-                wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
-
-                # viser = VisTable()
-                # save_html_path = f"outputs/output.html"
-                # save_drawed_path = f"outputs/output_table_vis.jpg"
-                # save_logic_path = (
-                #     f"outputs/output_table_vis_logic.jpg"
-                # )
-                # vis_imged = viser(
-                #     np.asarray(img), wired_table_results, save_html_path, save_drawed_path, save_logic_path
-                # )
-
-                wired_html_code = wired_table_results.pred_html
-                wired_table_cell_bboxes = wired_table_results.cell_bboxes
-                wired_logic_points = wired_table_results.logic_points
-                wired_elapse = wired_table_results.elapse
-
-                wireless_table_results = self.wireless_table_model(np.asarray(img), ocr_result)
-                wireless_html_code = wireless_table_results.pred_html
-                wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
-                wireless_logic_points = wireless_table_results.logic_points
-                wireless_elapse = wireless_table_results.elapse
-
-                wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
-                wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
-                # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
-                # Compute the difference in cell counts detected by the two models
-                gap_of_len = wireless_len - wired_len
-                # Decide whether to fall back to the wireless table model's result
-                if (
-                    wired_len <= round(wireless_len * 0.5)  # the wired model found far too few cells (below 50% of the wireless count)
-                    or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # the wired model found noticeably more cells than the wireless one
-                    or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # the counts are close but the wired result is clearly smaller
-                    or (gap_of_len == 0 and wired_len <= 4)  # the counts are identical and the table has at most 4 cells
-                ):
-                    # logger.debug("fall back to wireless table model")
-                    html_code = wireless_html_code
-                    table_cell_bboxes = wireless_table_cell_bboxes
-                    logic_points = wireless_logic_points
-                else:
-                    html_code = wired_html_code
-                    table_cell_bboxes = wired_table_cell_bboxes
-                    logic_points = wired_logic_points
-
-                elapse = wired_elapse + wireless_elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
-            except Exception as e:
-                logger.exception(e)
-        return None, None, None, None
-
-
-class VisTable:
-    def __init__(self):
-        self.load_img = LoadImage()
-
-    def __call__(
-        self,
-        img_path: InputType,
-        table_results,
-        save_html_path: Optional[Union[str, Path]] = None,
-        save_drawed_path: Optional[Union[str, Path]] = None,
-        save_logic_path: Optional[Union[str, Path]] = None,
-    ):
-        if save_html_path:
-            html_with_border = self.insert_border_style(table_results.pred_html)
-            self.save_html(save_html_path, html_with_border)
-
-        table_cell_bboxes = table_results.cell_bboxes
-        table_logic_points = table_results.logic_points
-        if table_cell_bboxes is None:
-            return None
-
-        img = self.load_img(img_path)
-
-        dims_bboxes = table_cell_bboxes.shape[1]
-        if dims_bboxes == 4:
-            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
-        elif dims_bboxes == 8:
-            drawed_img = self.draw_polylines(img, table_cell_bboxes)
-        else:
-            raise ValueError("Shape of table bounding boxes is neither 4 nor 8.")
-
-        if save_drawed_path:
-            self.save_img(save_drawed_path, drawed_img)
-
-        if save_logic_path:
-            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
-            self.plot_rec_box_with_logic_info(
-                img_path, save_logic_path, table_logic_points, polygons
-            )
-        return drawed_img
-
-    def insert_border_style(self, table_html_str: str):
-        style_res = """<meta charset="UTF-8"><style>
-        table {
-            border-collapse: collapse;
-            width: 100%;
-        }
-        th, td {
-            border: 1px solid black;
-            padding: 8px;
-            text-align: center;
-        }
-        th {
-            background-color: #f2f2f2;
-        }
-                    </style>"""
-
-        prefix_table, suffix_table = table_html_str.split("<body>")
-        html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
-        return html_with_border
-
-    def plot_rec_box_with_logic_info(
-        self, img_path, output_path, logic_points, sorted_polygons
-    ):
-        """
-        :param img_path
-        :param output_path
-        :param logic_points: [row_start,row_end,col_start,col_end]
-        :param sorted_polygons: [xmin,ymin,xmax,ymax]
-        :return:
-        """
-        # Read the original image
-        img = img_path
-        img = cv2.copyMakeBorder(
-            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
-        )
-        # Draw a rectangle for each polygon
-        for idx, polygon in enumerate(sorted_polygons):
-            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
-            x0 = round(x0)
-            y0 = round(y0)
-            x1 = round(x1)
-            y1 = round(y1)
-            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
-            # Increase the font size and line width
-            font_scale = 0.9  # was 0.5
-            thickness = 1  # was 1
-            logic_point = logic_points[idx]
-            cv2.putText(
-                img,
-                f"row: {logic_point[0]}-{logic_point[1]}",
-                (x0 + 3, y0 + 8),
-                cv2.FONT_HERSHEY_PLAIN,
-                font_scale,
-                (0, 0, 255),
-                thickness,
-            )
-            cv2.putText(
-                img,
-                f"col: {logic_point[2]}-{logic_point[3]}",
-                (x0 + 3, y0 + 18),
-                cv2.FONT_HERSHEY_PLAIN,
-                font_scale,
-                (0, 0, 255),
-                thickness,
-            )
-            os.makedirs(os.path.dirname(output_path), exist_ok=True)
-            # Save the annotated image
-            self.save_img(output_path, img)
-
-    @staticmethod
-    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
-        img_copy = img.copy()
-        for box in boxes.astype(int):
-            x1, y1, x2, y2 = box
-            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
-        return img_copy
-
-    @staticmethod
-    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
-        img_copy = img.copy()
-        for point in points.astype(int):
-            point = point.reshape(4, 2)
-            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
-        return img_copy
-
-    @staticmethod
-    def save_img(save_path: Union[str, Path], img: np.ndarray):
-        cv2.imwrite(str(save_path), img)
-
-    @staticmethod
-    def save_html(save_path: Union[str, Path], html: str):
-        with open(save_path, "w", encoding="utf-8") as f:
-            f.write(html)
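
Aside: the wired/wireless choice in UnetTableModel.predict above reduces to a cell-count heuristic. A minimal standalone sketch of that rule (the function name is illustrative, not part of this patch):

    def prefer_wireless(wired_len: int, wireless_len: int, table_cls_score: float) -> bool:
        # True -> use the wireless (slanet_plus) result instead of the wired one.
        gap_of_len = wireless_len - wired_len
        return (
            wired_len <= round(wireless_len * 0.5)
            or (round(wireless_len * 1.2) < wired_len < 2 * wireless_len and table_cls_score <= 0.94)
            or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))
            or (gap_of_len == 0 and wired_len <= 4)
        )

    # e.g. prefer_wireless(3, 10, 0.9) is True; prefer_wireless(12, 10, 0.99) is False.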

+ 0 - 391
mineru/model/table/rec/unet_table/wired_table_rec_utils.py

@@ -1,391 +0,0 @@
-import os
-import traceback
-from enum import Enum
-from io import BytesIO
-from pathlib import Path
-from typing import List, Union, Dict, Any, Tuple
-
-import cv2
-import numpy as np
-from onnxruntime import (
-    GraphOptimizationLevel,
-    InferenceSession,
-    SessionOptions,
-    get_available_providers,
-)
-from PIL import Image, UnidentifiedImageError
-
-root_dir = Path(__file__).resolve().parent
-InputType = Union[str, np.ndarray, bytes, Path]
-
-class EP(Enum):
-    CPU_EP = "CPUExecutionProvider"
-
-class OrtInferSession:
-    def __init__(self, config: Dict[str, Any]):
-
-        model_path = config.get("model_path", None)
-        self._verify_model(model_path)
-
-        self.had_providers: List[str] = get_available_providers()
-        EP_list = self._get_ep_list()
-
-        sess_opt = self._init_sess_opts(config)
-        self.session = InferenceSession(
-            model_path,
-            sess_options=sess_opt,
-            providers=EP_list,
-        )
-
-    @staticmethod
-    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
-        sess_opt = SessionOptions()
-        sess_opt.log_severity_level = 4
-        sess_opt.enable_cpu_mem_arena = False
-        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
-
-        cpu_nums = os.cpu_count()
-        intra_op_num_threads = config.get("intra_op_num_threads", -1)
-        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
-            sess_opt.intra_op_num_threads = intra_op_num_threads
-
-        inter_op_num_threads = config.get("inter_op_num_threads", -1)
-        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
-            sess_opt.inter_op_num_threads = inter_op_num_threads
-
-        return sess_opt
-
-    def get_metadata(self, key: str = "character") -> list:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        content_list = meta_dict[key].splitlines()
-        return content_list
-
-    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
-        cpu_provider_opts = {
-            "arena_extend_strategy": "kSameAsRequested",
-        }
-        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
-        return EP_list
-
-    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
-        input_dict = dict(zip(self.get_input_names(), input_content))
-        try:
-            return self.session.run(None, input_dict)
-        except Exception as e:
-            error_info = traceback.format_exc()
-            raise ONNXRuntimeError(error_info) from e
-
-    def get_input_names(self) -> List[str]:
-        return [v.name for v in self.session.get_inputs()]
-
-    def get_output_names(self) -> List[str]:
-        return [v.name for v in self.session.get_outputs()]
-
-    def get_character_list(self, key: str = "character") -> List[str]:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        return meta_dict[key].splitlines()
-
-    def have_key(self, key: str = "character") -> bool:
-        meta_dict = self.session.get_modelmeta().custom_metadata_map
-        if key in meta_dict.keys():
-            return True
-        return False
-
-    @staticmethod
-    def _verify_model(model_path: Union[str, Path, None]):
-        if model_path is None:
-            raise ValueError("model_path is None!")
-
-        model_path = Path(model_path)
-        if not model_path.exists():
-            raise FileNotFoundError(f"{model_path} does not exist.")
-
-        if not model_path.is_file():
-            raise FileExistsError(f"{model_path} is not a file.")
-
-
-class ONNXRuntimeError(Exception):
-    pass
-
-
-class LoadImage:
-    """
-    Utility class for loading and converting images from various input types to a numpy ndarray.
-
-    Supported input types:
-        - str or pathlib.Path: Path to an image file.
-        - bytes: Image data in bytes format.
-        - numpy.ndarray: Already loaded image array.
-
-    The class attempts to load the image and convert it to a numpy ndarray in BGR format.
-    Raises LoadImageError for unsupported types or if the image cannot be loaded.
-    """
-    def __init__(
-        self,
-    ):
-        pass
-
-    def __call__(self, img: InputType) -> np.ndarray:
-        img = self.load_img(img)
-        img = self.convert_img(img)
-        return img
-
-    def load_img(self, img: InputType) -> np.ndarray:
-        if isinstance(img, (str, Path)):
-            self.verify_exist(img)
-            try:
-                img = np.array(Image.open(img))
-            except UnidentifiedImageError as e:
-                raise LoadImageError(f"cannot identify image file {img}") from e
-            return img
-
-        elif isinstance(img, bytes):
-            try:
-                img = np.array(Image.open(BytesIO(img)))
-            except UnidentifiedImageError as e:
-                raise LoadImageError(f"cannot identify image from bytes data") from e
-            return img
-
-        elif isinstance(img, np.ndarray):
-            return img
-
-        else:
-            raise LoadImageError(f"{type(img)} is not supported!")
-
-    def convert_img(self, img: np.ndarray):
-        if img.ndim == 2:
-            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
-        if img.ndim == 3:
-            channel = img.shape[2]
-            if channel == 1:
-                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
-            if channel == 2:
-                return self.cvt_two_to_three(img)
-
-            if channel == 4:
-                return self.cvt_four_to_three(img)
-
-            if channel == 3:
-                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-
-            raise LoadImageError(
-                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
-            )
-
-        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
-
-    @staticmethod
-    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
-        """RGBA → BGR"""
-        r, g, b, a = cv2.split(img)
-        new_img = cv2.merge((b, g, r))
-
-        not_a = cv2.bitwise_not(a)
-        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
-
-        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
-        new_img = cv2.add(new_img, not_a)
-        return new_img
-
-    @staticmethod
-    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
-        """gray + alpha → BGR"""
-        img_gray = img[..., 0]
-        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
-
-        img_alpha = img[..., 1]
-        not_a = cv2.bitwise_not(img_alpha)
-        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
-
-        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
-        new_img = cv2.add(new_img, not_a)
-        return new_img
-
-    @staticmethod
-    def verify_exist(file_path: Union[str, Path]):
-        if not Path(file_path).exists():
-            raise LoadImageError(f"{file_path} does not exist.")
-
-
-class LoadImageError(Exception):
-    pass
-
-
-# Pillow >= v9.1.0 uses a slightly different naming scheme for filters.
-# Set pillow_interp_codes according to the naming scheme used.
-if Image is not None:
-    if hasattr(Image, "Resampling"):
-        pillow_interp_codes = {
-            "nearest": Image.Resampling.NEAREST,
-            "bilinear": Image.Resampling.BILINEAR,
-            "bicubic": Image.Resampling.BICUBIC,
-            "box": Image.Resampling.BOX,
-            "lanczos": Image.Resampling.LANCZOS,
-            "hamming": Image.Resampling.HAMMING,
-        }
-    else:
-        pillow_interp_codes = {
-            "nearest": Image.NEAREST,
-            "bilinear": Image.BILINEAR,
-            "bicubic": Image.BICUBIC,
-            "box": Image.BOX,
-            "lanczos": Image.LANCZOS,
-            "hamming": Image.HAMMING,
-        }
-
-cv2_interp_codes = {
-    "nearest": cv2.INTER_NEAREST,
-    "bilinear": cv2.INTER_LINEAR,
-    "bicubic": cv2.INTER_CUBIC,
-    "area": cv2.INTER_AREA,
-    "lanczos": cv2.INTER_LANCZOS4,
-}
-
-
-def resize_img(img, scale, keep_ratio=True):
-    if keep_ratio:
-        # "area" interpolation is more faithful when downscaling
-        if min(img.shape[:2]) > min(scale):
-            interpolation = "area"
-        else:
-            interpolation = "bicubic"  # bilinear
-        img_new, scale_factor = imrescale(
-            img, scale, return_scale=True, interpolation=interpolation
-        )
-        # the w_scale and h_scale have a minor difference
-        # a real fix should be done in mmcv.imrescale in the future
-        new_h, new_w = img_new.shape[:2]
-        h, w = img.shape[:2]
-        w_scale = new_w / w
-        h_scale = new_h / h
-    else:
-        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
-    return img_new, w_scale, h_scale
-
-
-def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
-    """Resize image while keeping the aspect ratio.
-
-    Args:
-        img (ndarray): The input image.
-        scale (float | tuple[int]): The scaling factor or maximum size.
-            If it is a float number, then the image will be rescaled by this
-            factor, else if it is a tuple of 2 integers, then the image will
-            be rescaled as large as possible within the scale.
-        return_scale (bool): Whether to return the scaling factor besides the
-            rescaled image.
-        interpolation (str): Same as :func:`resize`.
-        backend (str | None): Same as :func:`resize`.
-
-    Returns:
-        ndarray: The rescaled image.
-    """
-    h, w = img.shape[:2]
-    new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
-    rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend)
-    if return_scale:
-        return rescaled_img, scale_factor
-    else:
-        return rescaled_img
-
-
-def imresize(
-    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
-):
-    """Resize image to a given size.
-
-    Args:
-        img (ndarray): The input image.
-        size (tuple[int]): Target size (w, h).
-        return_scale (bool): Whether to return `w_scale` and `h_scale`.
-        interpolation (str): Interpolation method, accepted values are
-            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
-            backend, "nearest", "bilinear" for 'pillow' backend.
-        out (ndarray): The output destination.
-        backend (str | None): The image resize backend type. Options are `cv2`,
-            `pillow`, `None`. If backend is None, the global imread_backend
-            specified by ``mmcv.use_backend()`` will be used. Default: None.
-
-    Returns:
-        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
-        `resized_img`.
-    """
-    h, w = img.shape[:2]
-    if backend is None:
-        backend = "cv2"
-    if backend not in ["cv2", "pillow"]:
-        raise ValueError(
-            f"backend: {backend} is not supported for resize. "
-            f"Supported backends are 'cv2', 'pillow'."
-        )
-
-    if backend == "pillow":
-        assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
-        pil_image = Image.fromarray(img)
-        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
-        resized_img = np.array(pil_image)
-    else:
-        resized_img = cv2.resize(
-            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
-        )
-    if not return_scale:
-        return resized_img
-    else:
-        w_scale = size[0] / w
-        h_scale = size[1] / h
-        return resized_img, w_scale, h_scale
-
-
-def rescale_size(old_size, scale, return_scale=False):
-    """Calculate the new size to be rescaled to.
-
-    Args:
-        old_size (tuple[int]): The old size (w, h) of image.
-        scale (float | tuple[int]): The scaling factor or maximum size.
-            If it is a float number, then the image will be rescaled by this
-            factor, else if it is a tuple of 2 integers, then the image will
-            be rescaled as large as possible within the scale.
-        return_scale (bool): Whether to return the scaling factor besides the
-            rescaled image size.
-
-    Returns:
-        tuple[int]: The new rescaled image size.
-    """
-    w, h = old_size
-    if isinstance(scale, (float, int)):
-        if scale <= 0:
-            raise ValueError(f"Invalid scale {scale}, must be positive.")
-        scale_factor = scale
-    elif isinstance(scale, tuple):
-        max_long_edge = max(scale)
-        max_short_edge = min(scale)
-        scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
-    else:
-        raise TypeError(
-            f"Scale must be a number or tuple of int, but got {type(scale)}"
-        )
-
-    new_size = _scale_size((w, h), scale_factor)
-
-    if return_scale:
-        return new_size, scale_factor
-    else:
-        return new_size
-
-
-def _scale_size(size, scale):
-    """Rescale a size by a ratio.
-
-    Args:
-        size (tuple[int]): (w, h).
-        scale (float | tuple(float)): Scaling factor.
-
-    Returns:
-        tuple[int]: scaled size.
-    """
-    if isinstance(scale, (float, int)):
-        scale = (scale, scale)
-    w, h = size
-    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
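
Aside: rescale_size in the removed wired_table_rec_utils.py fits an image into a (long, short) edge budget while keeping the aspect ratio. A minimal worked example (the numbers are illustrative, not from this patch):

    w, h = 1654, 2339                      # original size (w, h)
    max_long_edge, max_short_edge = 1333, 800
    scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))  # ~0.4837
    new_size = (int(w * scale_factor + 0.5), int(h * scale_factor + 0.5))
    # new_size == (800, 1131): the short edge reaches 800 first, so it is the binding limit.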

+ 11 - 5
mineru/utils/model_utils.py

@@ -377,13 +377,19 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
                 layout_res.remove(res)
 
     # Remove filtered out tables from layout_res
+    # if len(filtered_table_res_list) < len(table_res_list):
+    #     kept_tables = set(id(table) for table in filtered_table_res_list)
+    #     to_remove = [table_indices[i] for i, table in enumerate(table_res_list)
+    #                  if id(table) not in kept_tables]
+    #
+    #     for idx in sorted(to_remove, reverse=True):
+    #         del layout_res[idx]
     if len(filtered_table_res_list) < len(table_res_list):
         kept_tables = set(id(table) for table in filtered_table_res_list)
-        to_remove = [table_indices[i] for i, table in enumerate(table_res_list)
-                     if id(table) not in kept_tables]
-
-        for idx in sorted(to_remove, reverse=True):
-            del layout_res[idx]
+        tables_to_remove = [table for table in table_res_list if id(table) not in kept_tables]
+        for table in tables_to_remove:
+            if table in layout_res:
+                layout_res.remove(table)
 
     # Remove overlaps in OCR and text regions
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
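
Aside on the rewritten table removal above: list.remove() matches by equality rather than identity, so for dict entries it drops the first element that compares equal. A minimal sketch of that behaviour (illustrative data, not from this patch):

    layout_res = [{"category_id": 5}, {"category_id": 5}]
    duplicate = layout_res[1]
    layout_res.remove(duplicate)  # removes index 0, the first dict comparing equal
    assert layout_res == [{"category_id": 5}]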