浏览代码

feat: implement visualization for table results and enhance bounding box drawing logic

myhloli 3 月之前
父节点
当前提交
ac15ad0e61
共有 1 个文件被更改,包括 151 次插入3 次删除
  1. 151 3
      mineru/model/table/rec/unet_table/unet_table.py

+ 151 - 3
mineru/model/table/rec/unet_table/unet_table.py

@@ -3,6 +3,7 @@ import os
 import time
 import traceback
 from dataclasses import dataclass, asdict
+from pathlib import Path
 from typing import List, Optional, Union, Dict, Any
 
 import cv2
@@ -172,7 +173,6 @@ class UnetTableRecognition:
             img_crop = img[int(y1):int(y2), int(x1):int(x2)]
             img_crop_list.append(img_crop)
             img_crop_info_list.append([i, box])
-            continue
 
         if len(img_crop_list) > 0:
             # 进行ocr识别
@@ -187,13 +187,14 @@ class UnetTableRecognition:
             for j, ocr_res in enumerate(ocr_res_list):
                 img_crop_info_list[j].append(ocr_res)
 
-
             for i, box, ocr_res in img_crop_info_list:
                 # 处理ocr结果
                 ocr_text, ocr_score = ocr_res
                 # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
                 if ocr_score < 0.9:
                     # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
+                    box = sorted_polygons[i]
+                    cell_box_map[i] = [[box, "", 0.5]]
                     continue
                 cell_box_map[i] = [[box, ocr_text, ocr_score]]
 
@@ -230,6 +231,17 @@ class UnetTableModel:
         if ocr_result:
             try:
                 wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
+
+                # viser = VisTable()
+                # save_html_path = f"outputs/output.html"
+                # save_drawed_path = f"outputs/output_table_vis.jpg"
+                # save_logic_path = (
+                #     f"outputs/output_table_vis_logic.jpg"
+                # )
+                # vis_imged = viser(
+                #     np.asarray(img), wired_table_results, save_html_path, save_drawed_path, save_logic_path
+                # )
+
                 wired_html_code = wired_table_results.pred_html
                 wired_table_cell_bboxes = wired_table_results.cell_bboxes
                 wired_logic_points = wired_table_results.logic_points
@@ -249,7 +261,7 @@ class UnetTableModel:
                 # 判断是否使用无线表格模型的结果
                 if (
                     wired_len <= round(wireless_len * 0.5)  # 有线模型检测到的单元格数太少(低于无线模型的50%)
-                    or ((wireless_len+1 < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.949)  # 有线模型检测到的单元格数反而更多
+                    or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # 有线模型检测到的单元格数反而更多
                     or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
                     or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
                 ):
@@ -267,3 +279,139 @@ class UnetTableModel:
             except Exception as e:
                 logger.exception(e)
         return None, None, None, None
+
+
+class VisTable:
+    def __init__(self):
+        self.load_img = LoadImage()
+
+    def __call__(
+        self,
+        img_path: InputType,
+        table_results,
+        save_html_path: Optional[Union[str, Path]] = None,
+        save_drawed_path: Optional[Union[str, Path]] = None,
+        save_logic_path: Optional[Union[str, Path]] = None,
+    ):
+        if save_html_path:
+            html_with_border = self.insert_border_style(table_results.pred_html)
+            self.save_html(save_html_path, html_with_border)
+
+        table_cell_bboxes = table_results.cell_bboxes
+        table_logic_points = table_results.logic_points
+        if table_cell_bboxes is None:
+            return None
+
+        img = self.load_img(img_path)
+
+        dims_bboxes = table_cell_bboxes.shape[1]
+        if dims_bboxes == 4:
+            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
+        elif dims_bboxes == 8:
+            drawed_img = self.draw_polylines(img, table_cell_bboxes)
+        else:
+            raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
+
+        if save_drawed_path:
+            self.save_img(save_drawed_path, drawed_img)
+
+        if save_logic_path:
+            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+            self.plot_rec_box_with_logic_info(
+                img_path, save_logic_path, table_logic_points, polygons
+            )
+        return drawed_img
+
+    def insert_border_style(self, table_html_str: str):
+        style_res = """<meta charset="UTF-8"><style>
+        table {
+            border-collapse: collapse;
+            width: 100%;
+        }
+        th, td {
+            border: 1px solid black;
+            padding: 8px;
+            text-align: center;
+        }
+        th {
+            background-color: #f2f2f2;
+        }
+                    </style>"""
+
+        prefix_table, suffix_table = table_html_str.split("<body>")
+        html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
+        return html_with_border
+
+    def plot_rec_box_with_logic_info(
+        self, img_path, output_path, logic_points, sorted_polygons
+    ):
+        """
+        :param img_path
+        :param output_path
+        :param logic_points: [row_start,row_end,col_start,col_end]
+        :param sorted_polygons: [xmin,ymin,xmax,ymax]
+        :return:
+        """
+        # 读取原图
+        img = img_path
+        img = cv2.copyMakeBorder(
+            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+        )
+        # 绘制 polygons 矩形
+        for idx, polygon in enumerate(sorted_polygons):
+            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
+            x0 = round(x0)
+            y0 = round(y0)
+            x1 = round(x1)
+            y1 = round(y1)
+            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
+            # 增大字体大小和线宽
+            font_scale = 0.9  # 原先是0.5
+            thickness = 1  # 原先是1
+            logic_point = logic_points[idx]
+            cv2.putText(
+                img,
+                f"row: {logic_point[0]}-{logic_point[1]}",
+                (x0 + 3, y0 + 8),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+            cv2.putText(
+                img,
+                f"col: {logic_point[2]}-{logic_point[3]}",
+                (x0 + 3, y0 + 18),
+                cv2.FONT_HERSHEY_PLAIN,
+                font_scale,
+                (0, 0, 255),
+                thickness,
+            )
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            # 保存绘制后的图像
+            self.save_img(output_path, img)
+
+    @staticmethod
+    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
+        img_copy = img.copy()
+        for box in boxes.astype(int):
+            x1, y1, x2, y2 = box
+            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
+        img_copy = img.copy()
+        for point in points.astype(int):
+            point = point.reshape(4, 2)
+            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
+        return img_copy
+
+    @staticmethod
+    def save_img(save_path: Union[str, Path], img: np.ndarray):
+        cv2.imwrite(str(save_path), img)
+
+    @staticmethod
+    def save_html(save_path: Union[str, Path], html: str):
+        with open(save_path, "w", encoding="utf-8") as f:
+            f.write(html)