|
|
@@ -3,6 +3,7 @@ import os
|
|
|
import time
|
|
|
import traceback
|
|
|
from dataclasses import dataclass, asdict
|
|
|
+from pathlib import Path
|
|
|
from typing import List, Optional, Union, Dict, Any
|
|
|
|
|
|
import cv2
|
|
|
@@ -172,7 +173,6 @@ class UnetTableRecognition:
|
|
|
img_crop = img[int(y1):int(y2), int(x1):int(x2)]
|
|
|
img_crop_list.append(img_crop)
|
|
|
img_crop_info_list.append([i, box])
|
|
|
- continue
|
|
|
|
|
|
if len(img_crop_list) > 0:
|
|
|
# 进行ocr识别
|
|
|
@@ -187,13 +187,14 @@ class UnetTableRecognition:
|
|
|
for j, ocr_res in enumerate(ocr_res_list):
|
|
|
img_crop_info_list[j].append(ocr_res)
|
|
|
|
|
|
-
|
|
|
for i, box, ocr_res in img_crop_info_list:
|
|
|
# 处理ocr结果
|
|
|
ocr_text, ocr_score = ocr_res
|
|
|
# logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
|
|
|
if ocr_score < 0.9:
|
|
|
# logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
|
|
|
+ box = sorted_polygons[i]
|
|
|
+ cell_box_map[i] = [[box, "", 0.5]]
|
|
|
continue
|
|
|
cell_box_map[i] = [[box, ocr_text, ocr_score]]
|
|
|
|
|
|
@@ -230,6 +231,17 @@ class UnetTableModel:
|
|
|
if ocr_result:
|
|
|
try:
|
|
|
wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
|
|
|
+
|
|
|
+ # viser = VisTable()
|
|
|
+ # save_html_path = f"outputs/output.html"
|
|
|
+ # save_drawed_path = f"outputs/output_table_vis.jpg"
|
|
|
+ # save_logic_path = (
|
|
|
+ # f"outputs/output_table_vis_logic.jpg"
|
|
|
+ # )
|
|
|
+ # vis_imged = viser(
|
|
|
+ # np.asarray(img), wired_table_results, save_html_path, save_drawed_path, save_logic_path
|
|
|
+ # )
|
|
|
+
|
|
|
wired_html_code = wired_table_results.pred_html
|
|
|
wired_table_cell_bboxes = wired_table_results.cell_bboxes
|
|
|
wired_logic_points = wired_table_results.logic_points
|
|
|
@@ -249,7 +261,7 @@ class UnetTableModel:
|
|
|
# 判断是否使用无线表格模型的结果
|
|
|
if (
|
|
|
wired_len <= round(wireless_len * 0.5) # 有线模型检测到的单元格数太少(低于无线模型的50%)
|
|
|
- or ((wireless_len+1 < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.949) # 有线模型检测到的单元格数反而更多
|
|
|
+ or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94) # 有线模型检测到的单元格数反而更多
|
|
|
or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75)) # 两者相差不大但有线模型结果较少
|
|
|
or (gap_of_len == 0 and wired_len <= 4) # 单元格数量完全相等且总量小于等于4
|
|
|
):
|
|
|
@@ -267,3 +279,139 @@ class UnetTableModel:
|
|
|
except Exception as e:
|
|
|
logger.exception(e)
|
|
|
return None, None, None, None
|
|
|
+
|
|
|
+
|
|
|
+class VisTable:
|
|
|
+ def __init__(self):
|
|
|
+ self.load_img = LoadImage()
|
|
|
+
|
|
|
+ def __call__(
|
|
|
+ self,
|
|
|
+ img_path: InputType,
|
|
|
+ table_results,
|
|
|
+ save_html_path: Optional[Union[str, Path]] = None,
|
|
|
+ save_drawed_path: Optional[Union[str, Path]] = None,
|
|
|
+ save_logic_path: Optional[Union[str, Path]] = None,
|
|
|
+ ):
|
|
|
+ if save_html_path:
|
|
|
+ html_with_border = self.insert_border_style(table_results.pred_html)
|
|
|
+ self.save_html(save_html_path, html_with_border)
|
|
|
+
|
|
|
+ table_cell_bboxes = table_results.cell_bboxes
|
|
|
+ table_logic_points = table_results.logic_points
|
|
|
+ if table_cell_bboxes is None:
|
|
|
+ return None
|
|
|
+
|
|
|
+ img = self.load_img(img_path)
|
|
|
+
|
|
|
+ dims_bboxes = table_cell_bboxes.shape[1]
|
|
|
+ if dims_bboxes == 4:
|
|
|
+ drawed_img = self.draw_rectangle(img, table_cell_bboxes)
|
|
|
+ elif dims_bboxes == 8:
|
|
|
+ drawed_img = self.draw_polylines(img, table_cell_bboxes)
|
|
|
+ else:
|
|
|
+ raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
|
|
|
+
|
|
|
+ if save_drawed_path:
|
|
|
+ self.save_img(save_drawed_path, drawed_img)
|
|
|
+
|
|
|
+ if save_logic_path:
|
|
|
+ polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
|
|
|
+ self.plot_rec_box_with_logic_info(
|
|
|
+ img_path, save_logic_path, table_logic_points, polygons
|
|
|
+ )
|
|
|
+ return drawed_img
|
|
|
+
|
|
|
+ def insert_border_style(self, table_html_str: str):
|
|
|
+ style_res = """<meta charset="UTF-8"><style>
|
|
|
+ table {
|
|
|
+ border-collapse: collapse;
|
|
|
+ width: 100%;
|
|
|
+ }
|
|
|
+ th, td {
|
|
|
+ border: 1px solid black;
|
|
|
+ padding: 8px;
|
|
|
+ text-align: center;
|
|
|
+ }
|
|
|
+ th {
|
|
|
+ background-color: #f2f2f2;
|
|
|
+ }
|
|
|
+ </style>"""
|
|
|
+
|
|
|
+ prefix_table, suffix_table = table_html_str.split("<body>")
|
|
|
+ html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
|
|
|
+ return html_with_border
|
|
|
+
|
|
|
+ def plot_rec_box_with_logic_info(
|
|
|
+ self, img_path, output_path, logic_points, sorted_polygons
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ :param img_path
|
|
|
+ :param output_path
|
|
|
+ :param logic_points: [row_start,row_end,col_start,col_end]
|
|
|
+ :param sorted_polygons: [xmin,ymin,xmax,ymax]
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ # 读取原图
|
|
|
+ img = img_path
|
|
|
+ img = cv2.copyMakeBorder(
|
|
|
+ img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
|
|
|
+ )
|
|
|
+ # 绘制 polygons 矩形
|
|
|
+ for idx, polygon in enumerate(sorted_polygons):
|
|
|
+ x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
|
|
|
+ x0 = round(x0)
|
|
|
+ y0 = round(y0)
|
|
|
+ x1 = round(x1)
|
|
|
+ y1 = round(y1)
|
|
|
+ cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
|
|
|
+ # 增大字体大小和线宽
|
|
|
+ font_scale = 0.9 # 原先是0.5
|
|
|
+ thickness = 1 # 原先是1
|
|
|
+ logic_point = logic_points[idx]
|
|
|
+ cv2.putText(
|
|
|
+ img,
|
|
|
+ f"row: {logic_point[0]}-{logic_point[1]}",
|
|
|
+ (x0 + 3, y0 + 8),
|
|
|
+ cv2.FONT_HERSHEY_PLAIN,
|
|
|
+ font_scale,
|
|
|
+ (0, 0, 255),
|
|
|
+ thickness,
|
|
|
+ )
|
|
|
+ cv2.putText(
|
|
|
+ img,
|
|
|
+ f"col: {logic_point[2]}-{logic_point[3]}",
|
|
|
+ (x0 + 3, y0 + 18),
|
|
|
+ cv2.FONT_HERSHEY_PLAIN,
|
|
|
+ font_scale,
|
|
|
+ (0, 0, 255),
|
|
|
+ thickness,
|
|
|
+ )
|
|
|
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
|
+ # 保存绘制后的图像
|
|
|
+ self.save_img(output_path, img)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
|
|
|
+ img_copy = img.copy()
|
|
|
+ for box in boxes.astype(int):
|
|
|
+ x1, y1, x2, y2 = box
|
|
|
+ cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
|
|
|
+ return img_copy
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def draw_polylines(img: np.ndarray, points) -> np.ndarray:
|
|
|
+ img_copy = img.copy()
|
|
|
+ for point in points.astype(int):
|
|
|
+ point = point.reshape(4, 2)
|
|
|
+ cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
|
|
|
+ return img_copy
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def save_img(save_path: Union[str, Path], img: np.ndarray):
|
|
|
+ cv2.imwrite(str(save_path), img)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def save_html(save_path: Union[str, Path], html: str):
|
|
|
+ with open(save_path, "w", encoding="utf-8") as f:
|
|
|
+ f.write(html)
|