Browse Source

fix: remove unused imports and clean up code in wired_table_rec_utils.py

myhloli 3 months ago
parent
commit
c98cba1e30
1 changed file with 2 additions and 185 deletions
  1. mineru/model/table/rec/unet_table/wired_table_rec_utils.py (+2 -185)

+ 2 - 185
mineru/model/table/rec/unet_table/wired_table_rec_utils.py

@@ -1,12 +1,10 @@
 # -*- encoding: utf-8 -*-
-import math
 import os
-import platform
 import traceback
 from enum import Enum
 from io import BytesIO
 from pathlib import Path
-from typing import List, Union, Dict, Any, Tuple, Optional
+from typing import List, Union, Dict, Any, Tuple
 
 import cv2
 import numpy as np
@@ -15,7 +13,6 @@ from onnxruntime import (
     InferenceSession,
     SessionOptions,
     get_available_providers,
-    get_device,
 )
 from PIL import Image, UnidentifiedImageError
 
@@ -382,184 +379,4 @@ def _scale_size(size, scale):
     if isinstance(scale, (float, int)):
         scale = (scale, scale)
     w, h = size
-    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
-
-
-class ImageOrientationCorrector:
-    """
-    对图片小角度(-90 - + 90度进行修正)
-    """
-
-    def __init__(self):
-        self.img_loader = LoadImage()
-
-    def __call__(self, img: InputType):
-        img = self.img_loader(img)
-        # 取灰度
-        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-        # 二值化
-        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
-        # 边缘检测
-        edges = cv2.Canny(gray, 100, 250, apertureSize=3)
-        # 霍夫变换,摘自https://blog.csdn.net/feilong_csdn/article/details/81586322
-        lines = cv2.HoughLines(edges, 1, np.pi / 180, 0)
-        for rho, theta in lines[0]:
-            a = np.cos(theta)
-            b = np.sin(theta)
-            x0 = a * rho
-            y0 = b * rho
-            x1 = int(x0 + 1000 * (-b))
-            y1 = int(y0 + 1000 * (a))
-            x2 = int(x0 - 1000 * (-b))
-            y2 = int(y0 - 1000 * (a))
-        if x1 == x2 or y1 == y2:
-            return img
-        else:
-            t = float(y2 - y1) / (x2 - x1)
-            # 得到角度后
-            rotate_angle = math.degrees(math.atan(t))
-            if rotate_angle > 45:
-                rotate_angle = -90 + rotate_angle
-            elif rotate_angle < -45:
-                rotate_angle = 90 + rotate_angle
-            # 旋转图像
-            (h, w) = img.shape[:2]
-            center = (w // 2, h // 2)
-            M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
-            return cv2.warpAffine(img, M, (w, h))
-
-
-class VisTable:
-    def __init__(self):
-        self.load_img = LoadImage()
-
-    def __call__(
-        self,
-        img_path: Union[str, Path],
-        table_results,
-        save_html_path: Optional[Union[str, Path]] = None,
-        save_drawed_path: Optional[Union[str, Path]] = None,
-        save_logic_path: Optional[Union[str, Path]] = None,
-    ):
-        if save_html_path:
-            html_with_border = self.insert_border_style(table_results.pred_html)
-            self.save_html(save_html_path, html_with_border)
-
-        table_cell_bboxes = table_results.cell_bboxes
-        table_logic_points = table_results.logic_points
-        if table_cell_bboxes is None:
-            return None
-
-        img = self.load_img(img_path)
-
-        dims_bboxes = table_cell_bboxes.shape[1]
-        if dims_bboxes == 4:
-            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
-        elif dims_bboxes == 8:
-            drawed_img = self.draw_polylines(img, table_cell_bboxes)
-        else:
-            raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
-
-        if save_drawed_path:
-            self.save_img(save_drawed_path, drawed_img)
-
-        if save_logic_path:
-            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
-            self.plot_rec_box_with_logic_info(
-                img_path, save_logic_path, table_logic_points, polygons
-            )
-        return drawed_img
-
-    def insert_border_style(self, table_html_str: str):
-        style_res = """<meta charset="UTF-8"><style>
-        table {
-            border-collapse: collapse;
-            width: 100%;
-        }
-        th, td {
-            border: 1px solid black;
-            padding: 8px;
-            text-align: center;
-        }
-        th {
-            background-color: #f2f2f2;
-        }
-                    </style>"""
-
-        prefix_table, suffix_table = table_html_str.split("<body>")
-        html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
-        return html_with_border
-
-    def plot_rec_box_with_logic_info(
-        self, img_path, output_path, logic_points, sorted_polygons
-    ):
-        """
-        :param img_path
-        :param output_path
-        :param logic_points: [row_start,row_end,col_start,col_end]
-        :param sorted_polygons: [xmin,ymin,xmax,ymax]
-        :return:
-        """
-        # 读取原图
-        img = cv2.imread(img_path)
-        img = cv2.copyMakeBorder(
-            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
-        )
-        # 绘制 polygons 矩形
-        for idx, polygon in enumerate(sorted_polygons):
-            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
-            x0 = round(x0)
-            y0 = round(y0)
-            x1 = round(x1)
-            y1 = round(y1)
-            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
-            # 增大字体大小和线宽
-            font_scale = 0.9  # 原先是0.5
-            thickness = 1  # 原先是1
-            logic_point = logic_points[idx]
-            cv2.putText(
-                img,
-                f"row: {logic_point[0]}-{logic_point[1]}",
-                (x0 + 3, y0 + 8),
-                cv2.FONT_HERSHEY_PLAIN,
-                font_scale,
-                (0, 0, 255),
-                thickness,
-            )
-            cv2.putText(
-                img,
-                f"col: {logic_point[2]}-{logic_point[3]}",
-                (x0 + 3, y0 + 18),
-                cv2.FONT_HERSHEY_PLAIN,
-                font_scale,
-                (0, 0, 255),
-                thickness,
-            )
-            os.makedirs(os.path.dirname(output_path), exist_ok=True)
-            # 保存绘制后的图像
-            self.save_img(output_path, img)
-
-    @staticmethod
-    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
-        img_copy = img.copy()
-        for box in boxes.astype(int):
-            x1, y1, x2, y2 = box
-            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
-        return img_copy
-
-    @staticmethod
-    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
-        img_copy = img.copy()
-        for point in points.astype(int):
-            point = point.reshape(4, 2)
-            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
-        return img_copy
-
-    @staticmethod
-    def save_img(save_path: Union[str, Path], img: np.ndarray):
-        cv2.imwrite(str(save_path), img)
-
-    @staticmethod
-    def save_html(save_path: Union[str, Path], html: str):
-        with open(save_path, "w", encoding="utf-8") as f:
-            f.write(html)
+    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)