| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417 |
- import html
- import os
- import time
- import traceback
- from dataclasses import dataclass, asdict
- from pathlib import Path
- from typing import List, Optional, Union, Dict, Any
- import cv2
- import numpy as np
- from loguru import logger
- from rapid_table import RapidTableInput, RapidTable
- from mineru.utils.enum_class import ModelPath
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
- from .table_structure_unet import TSRUnet
- from .table_recover import TableRecover
- from .wired_table_rec_utils import InputType, LoadImage
- from .table_recover_utils import (
- match_ocr_cell,
- plot_html_table,
- box_4_2_poly_to_box_4_1,
- sorted_ocr_boxes,
- gather_ocr_list_by_row,
- )
@dataclass
class UnetTableInput:
    """Settings for the wired-table structure model.

    ``UnetTableRecognition`` converts an instance to a plain dict with
    ``asdict`` and passes that dict to ``TSRUnet``.
    """

    # Path handed to the structure model under the key ``model_path``.
    model_path: str
    # Device identifier handed to the structure model; defaults to "cpu".
    device: str = "cpu"
@dataclass
class UnetTableOutput:
    """Result bundle produced by ``UnetTableRecognition``.

    Every field defaults to ``None`` so a partially filled result can be
    constructed.
    """

    # Table rendered as HTML; the recognizer uses "" when nothing was produced.
    pred_html: Optional[str] = None
    # Physical cell boxes (one row per cell).
    cell_bboxes: Optional[np.ndarray] = None
    # Logical grid position per cell: row_start, row_end, col_start, col_end.
    logic_points: Optional[np.ndarray] = None
    # Elapsed time measured with ``time.perf_counter`` (seconds).
    elapse: Optional[float] = None
class UnetTableRecognition:
    """Recognize wired (ruled) tables.

    The pipeline detects cell polygons with the structure model, recovers
    the logical grid, attaches OCR text to each cell and renders HTML.
    """

    def __init__(self, config: "UnetTableInput"):
        """Create the structure model, the image loader and the grid recoverer.

        Args:
            config: Settings; converted with ``asdict`` and passed to ``TSRUnet``.
        """
        self.table_structure = TSRUnet(asdict(config))
        self.load_img = LoadImage()
        self.table_recover = TableRecover()

    def __call__(
        self,
        img: "InputType",
        ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
        ocr_engine=None,
        **kwargs,
    ) -> "UnetTableOutput":
        """Run table recognition on one image.

        Args:
            img: Image input accepted by ``LoadImage``.
            ocr_result: OCR items that are matched against the detected cells.
            ocr_engine: Optional engine used to recognize cells that received
                no OCR text; it must provide ``ocr(list_of_crops, det=False)``.
            **kwargs: ``need_ocr`` (default ``True``), ``col_threshold``
                (default 15) and ``row_threshold`` (default 10) are read here;
                all keyword arguments are also forwarded to the structure model.

        Returns:
            A ``UnetTableOutput``. When no polygons are detected, or when any
            step raises, the result is ``UnetTableOutput("", None, None, 0.0)``.
        """
        start = time.perf_counter()
        # dict.get already falls back to the defaults when kwargs is empty.
        need_ocr = kwargs.get("need_ocr", True)
        col_threshold = kwargs.get("col_threshold", 15)
        row_threshold = kwargs.get("row_threshold", 10)

        img = self.load_img(img)
        polygons, rotated_polygons = self.table_structure(img, **kwargs)
        if polygons is None:
            return UnetTableOutput("", None, None, 0.0)

        try:
            _, logi_points = self.table_recover(
                rotated_polygons, row_threshold, col_threshold
            )
            # Swap vertices 1 and 3 to turn the counter-clockwise order into a
            # clockwise one, so later processing lines up with the wireless
            # table pipeline.
            polygons[:, 1, :], polygons[:, 3, :] = (
                polygons[:, 3, :].copy(),
                polygons[:, 1, :].copy(),
            )
            if not need_ocr:
                sorted_polygons, idx_list = sorted_ocr_boxes(
                    [box_4_2_poly_to_box_4_1(box) for box in polygons]
                )
                return UnetTableOutput(
                    "",
                    sorted_polygons,
                    logi_points[idx_list],
                    time.perf_counter() - start,
                )

            cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
            # Cells without any OCR match are cropped and sent to recognition.
            cell_box_det_map = self.fill_blank_rec(
                img, polygons, cell_box_det_map, ocr_engine
            )
            # Merge physical box, logical box and OCR items into one dict per
            # cell so the following steps work on a single structure.
            t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
            # Sort the OCR items inside each cell and merge items on the same
            # row, so the HTML keeps the line breaks of the text.
            t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)

            logi_points = [cell["t_logic_box"] for cell in t_rec_ocr_list]
            cell_box_det_map = {
                idx: [box_and_text[1] for box_and_text in cell["t_ocr_res"]]
                for idx, cell in enumerate(t_rec_ocr_list)
            }
            pred_html = plot_html_table(logi_points, cell_box_det_map)
            polygons = np.array(polygons).reshape(-1, 8)
            logi_points = np.array(logi_points)
            elapse = time.perf_counter() - start
        except Exception:
            # Best effort: log the traceback and report an empty result.
            logger.warning(traceback.format_exc())
            return UnetTableOutput("", None, None, 0.0)
        return UnetTableOutput(pred_html, polygons, logi_points, elapse)

    def transform_res(
        self,
        cell_box_det_map: Dict[int, List[Any]],
        polygons: np.ndarray,
        logi_points: List[np.ndarray],
    ) -> List[Dict[str, Any]]:
        """Combine per-cell OCR items with their logical position.

        Cells that have no OCR items are skipped.

        Args:
            cell_box_det_map: Cell index -> list of ``[box, text, ...]`` items,
                where ``box`` is indexed as ``box[corner][axis]``.
            polygons: Detected cell polygons; only its length is used.
            logi_points: Logical position per cell, indexed like ``polygons``.

        Returns:
            One dict per non-empty cell with keys ``t_box``, ``t_logic_box``
            and ``t_ocr_res``.
        """
        res = []
        for i in range(len(polygons)):
            ocr_res_list = cell_box_det_map.get(i)
            if not ocr_res_list:
                continue
            # Corner 0 and corner 2 are used as the two opposite corners.
            xmin = min(ocr_box[0][0][0] for ocr_box in ocr_res_list)
            ymin = min(ocr_box[0][0][1] for ocr_box in ocr_res_list)
            xmax = max(ocr_box[0][2][0] for ocr_box in ocr_res_list)
            ymax = max(ocr_box[0][2][1] for ocr_box in ocr_res_list)
            res.append(
                {
                    # xmin, ymin, xmax, ymax over all OCR boxes of the cell
                    "t_box": [xmin, ymin, xmax, ymax],
                    # row_start, row_end, col_start, col_end
                    "t_logic_box": logi_points[i].tolist(),
                    # [[converted box, text], ...]
                    "t_ocr_res": [
                        [box_4_2_poly_to_box_4_1(ocr_det[0]), ocr_det[1]]
                        for ocr_det in ocr_res_list
                    ],
                }
            )
        return res

    def sort_and_gather_ocr_res(self, res):
        """Order the OCR items of every cell and merge items on the same row.

        Args:
            res: Output of ``transform_res``; modified in place.

        Returns:
            The same list, with each ``t_ocr_res`` sorted and gathered.
        """
        for dict_res in res:
            ocr_items = dict_res["t_ocr_res"]
            _, sorted_idx = sorted_ocr_boxes(
                [ocr_det[0] for ocr_det in ocr_items], threshold=0.3
            )
            ordered = [ocr_items[idx] for idx in sorted_idx]
            dict_res["t_ocr_res"] = gather_ocr_list_by_row(ordered, threshold=0.3)
        return res

    def fill_blank_rec(
        self,
        img: np.ndarray,
        sorted_polygons: np.ndarray,
        cell_box_map: Dict[int, List[Any]],
        ocr_engine,
    ) -> Dict[int, List[Any]]:
        """Recognize text for cells that have no matched OCR result.

        Each such cell is cropped from ``img`` (corner 0 to corner 2 of its
        polygon) and the crops are sent to ``ocr_engine`` in one batch.

        Args:
            img: Image array the polygons refer to.
            sorted_polygons: Cell polygons, indexed as ``[cell][corner][axis]``.
            cell_box_map: Cell index -> OCR items; updated in place.
            ocr_engine: Engine providing ``ocr(list_of_crops, det=False)``, or
                ``None`` to skip recognition.

        Returns:
            ``cell_box_map``. Recognized cells get ``[[box, text, score]]``;
            cells whose score is below 0.9 get ``[[box, "", 0.5]]``.
        """
        img_crop_list = []
        img_crop_info_list = []  # (cell index, polygon) for every crop
        for i in range(sorted_polygons.shape[0]):
            if cell_box_map.get(i):
                continue
            box = sorted_polygons[i]
            if ocr_engine is None:
                logger.warning(f"No OCR engine provided for box {i}: {box}")
                continue
            x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
            # Clamp the start to 0: a negative slice start would wrap around
            # and select the wrong region (or nothing at all).
            left, top = max(int(x1), 0), max(int(y1), 0)
            right, bottom = int(x2), int(y2)
            # Validate after truncation so sub-pixel boxes are rejected too.
            if left >= right or top >= bottom:
                logger.warning(f"Invalid box coordinates: {box}")
                continue
            img_crop = img[top:bottom, left:right]
            if img_crop.size == 0:
                # The box lies completely outside the image.
                logger.warning(f"Invalid box coordinates: {box}")
                continue
            img_crop_list.append(img_crop)
            img_crop_info_list.append((i, box))

        if not img_crop_list:
            return cell_box_map

        # Recognition only (no detection) on the batch of crops.
        ocr_result = ocr_engine.ocr(img_crop_list, det=False)
        if not ocr_result or not isinstance(ocr_result, list):
            logger.warning("OCR engine returned no results or invalid result for image crops.")
            return cell_box_map
        ocr_res_list = ocr_result[0]
        if not isinstance(ocr_res_list, list) or len(ocr_res_list) != len(img_crop_list):
            logger.warning("OCR result list length does not match image crop list length.")
            return cell_box_map

        for (i, box), ocr_res in zip(img_crop_info_list, ocr_res_list):
            ocr_text, ocr_score = ocr_res
            if ocr_score < 0.9:
                # Low confidence: keep the cell but drop the text.
                cell_box_map[i] = [[box, "", 0.5]]
            else:
                cell_box_map[i] = [[box, ocr_text, ocr_score]]
        return cell_box_map
def escape_html(input_string):
    """Return ``input_string`` with HTML-special characters escaped.

    Delegates to ``html.escape`` with its default arguments, so ``&``, ``<``,
    ``>`` and both quote characters are replaced by entities.
    """
    escaped = html.escape(input_string)
    return escaped
class UnetTableModel:
    """Run the wired and the wireless table model and keep the better result."""

    def __init__(self, ocr_engine):
        """Load both table models.

        Args:
            ocr_engine: Engine used for full-image OCR in ``predict`` and
                passed on to the wired model.
        """
        wired_root = auto_download_and_get_model_root_path(ModelPath.unet_structure)
        wired_path = os.path.join(wired_root, ModelPath.unet_structure)
        self.wired_table_model = UnetTableRecognition(
            UnetTableInput(model_path=wired_path)
        )

        wireless_root = auto_download_and_get_model_root_path(ModelPath.slanet_plus)
        wireless_path = os.path.join(wireless_root, ModelPath.slanet_plus)
        self.wireless_table_model = RapidTable(
            RapidTableInput(model_type='slanet_plus', model_path=wireless_path)
        )

        self.ocr_engine = ocr_engine

    def predict(self, img, table_cls_score):
        """Recognize the table in ``img``.

        Args:
            img: RGB image (converted with ``np.array`` / ``np.asarray``).
            table_cls_score: Score compared against 0.94 in one of the
                fall-back rules (presumably the table classifier's
                confidence; confirm against the caller).

        Returns:
            ``(html_code, cell_bboxes, logic_points, elapse)``, or four
            ``None`` values when OCR finds nothing usable or an exception is
            raised while running the table models.
        """
        failure = (None, None, None, None)

        bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        raw_ocr = self.ocr_engine.ocr(bgr_img)[0]
        if not raw_ocr:
            return failure

        # Keep well-formed items only and escape their text for HTML output.
        ocr_result = [
            [entry[0], escape_html(entry[1][0]), entry[1][1]]
            for entry in raw_ocr
            if len(entry) == 2 and isinstance(entry[1], tuple)
        ]
        if not ocr_result:
            return failure

        try:
            wired = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
            wired_html_code = wired.pred_html
            wired_table_cell_bboxes = wired.cell_bboxes
            wired_logic_points = wired.logic_points
            wired_elapse = wired.elapse

            wireless = self.wireless_table_model(np.asarray(img), ocr_result)
            wireless_html_code = wireless.pred_html
            wireless_table_cell_bboxes = wireless.cell_bboxes
            wireless_logic_points = wireless.logic_points
            wireless_elapse = wireless.elapse

            wired_len = 0 if wired_table_cell_bboxes is None else len(wired_table_cell_bboxes)
            wireless_len = 0 if wireless_table_cell_bboxes is None else len(wireless_table_cell_bboxes)
            # Difference in the number of cells found by the two models.
            gap_of_len = wireless_len - wired_len

            # Decide whether to fall back to the wireless result. The rules
            # are checked in order and stop at the first match.
            if wired_len <= round(wireless_len * 0.5):
                # Wired model found too few cells (at most 50% of wireless).
                use_wireless = True
            elif (
                round(wireless_len * 1.2) < wired_len < 2 * wireless_len
                and table_cls_score <= 0.94
            ):
                # Wired model found noticeably more cells than wireless.
                use_wireless = True
            elif 0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75):
                # Counts are close, but the wired result is the smaller one.
                use_wireless = True
            elif gap_of_len == 0 and wired_len <= 4:
                # Same number of cells and no more than 4 in total.
                use_wireless = True
            else:
                use_wireless = False

            if use_wireless:
                html_code = wireless_html_code
                table_cell_bboxes = wireless_table_cell_bboxes
                logic_points = wireless_logic_points
            else:
                html_code = wired_html_code
                table_cell_bboxes = wired_table_cell_bboxes
                logic_points = wired_logic_points

            elapse = wired_elapse + wireless_elapse
            return html_code, table_cell_bboxes, logic_points, elapse
        except Exception as e:
            logger.exception(e)

        return failure
class VisTable:
    """Save visualisations of a table recognition result (HTML and images)."""

    def __init__(self):
        self.load_img = LoadImage()

    def __call__(
        self,
        img_path: "InputType",
        table_results,
        save_html_path: Optional[Union[str, Path]] = None,
        save_drawed_path: Optional[Union[str, Path]] = None,
        save_logic_path: Optional[Union[str, Path]] = None,
    ):
        """Write the requested visualisations and return the annotated image.

        Args:
            img_path: Image input accepted by ``LoadImage``.
            table_results: Object with ``pred_html``, ``cell_bboxes`` and
                ``logic_points`` attributes.
            save_html_path: If set, the bordered HTML is written here.
            save_drawed_path: If set, the image with cell boxes is written here.
            save_logic_path: If set, the image with row/col labels is written here.

        Returns:
            The image with the cell boxes drawn, or ``None`` when
            ``table_results.cell_bboxes`` is ``None``.

        Raises:
            ValueError: If the boxes have neither 4 nor 8 values per row.
        """
        if save_html_path:
            html_with_border = self.insert_border_style(table_results.pred_html)
            self.save_html(save_html_path, html_with_border)

        table_cell_bboxes = table_results.cell_bboxes
        table_logic_points = table_results.logic_points
        if table_cell_bboxes is None:
            return None

        img = self.load_img(img_path)

        dims_bboxes = table_cell_bboxes.shape[1]
        if dims_bboxes == 4:
            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
        elif dims_bboxes == 8:
            drawed_img = self.draw_polylines(img, table_cell_bboxes)
        else:
            raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")

        if save_drawed_path:
            self.save_img(save_drawed_path, drawed_img)

        if save_logic_path:
            if dims_bboxes == 4:
                # Rows hold only four values, so they are used directly.
                polygons = [[box[0], box[1], box[2], box[3]] for box in table_cell_bboxes]
            else:
                # 8-value polygon: take point 0 and point 2 as the rectangle.
                polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
            self.plot_rec_box_with_logic_info(
                img_path, save_logic_path, table_logic_points, polygons
            )
        return drawed_img

    def insert_border_style(self, table_html_str: str):
        """Insert a CSS block that gives the table visible borders.

        The style is placed directly before the first ``<body>`` tag. When
        the markup has no ``<body>`` tag (including an empty or ``None``
        string) the style is put in front of the markup instead of raising.

        Args:
            table_html_str: HTML produced by the table model.

        Returns:
            The HTML string with the style block added.
        """
        style_res = (
            '<meta charset="UTF-8"><style>\n'
            "table {\n"
            "border-collapse: collapse;\n"
            "width: 100%;\n"
            "}\n"
            "th, td {\n"
            "border: 1px solid black;\n"
            "padding: 8px;\n"
            "text-align: center;\n"
            "}\n"
            "th {\n"
            "background-color: #f2f2f2;\n"
            "}\n"
            "</style>"
        )
        html_str = table_html_str or ""
        prefix_table, body_tag, suffix_table = html_str.partition("<body>")
        if not body_tag:
            return f"{style_res}{html_str}"
        return f"{prefix_table}{style_res}<body>{suffix_table}"

    def plot_rec_box_with_logic_info(
        self, img_path, output_path, logic_points, sorted_polygons
    ):
        """Draw every cell rectangle with its logical row/col range and save it.

        :param img_path: image array; any other input is loaded with ``self.load_img``
        :param output_path: file the annotated image is written to
        :param logic_points: per cell [row_start,row_end,col_start,col_end]
        :param sorted_polygons: per cell [xmin,ymin,xmax,ymax]
        :return: None
        """
        if isinstance(img_path, np.ndarray):
            img = img_path
        else:
            img = self.load_img(img_path)
        # Add a 100 px white margin on the right (copyMakeBorder returns a
        # new image).
        img = cv2.copyMakeBorder(
            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )
        font_scale = 0.9
        thickness = 1
        colour = (0, 0, 255)
        for idx, polygon in enumerate(sorted_polygons):
            x0 = round(polygon[0])
            y0 = round(polygon[1])
            x1 = round(polygon[2])
            y1 = round(polygon[3])
            cv2.rectangle(img, (x0, y0), (x1, y1), colour, 1)
            logic_point = logic_points[idx]
            cv2.putText(
                img,
                f"row: {logic_point[0]}-{logic_point[1]}",
                (x0 + 3, y0 + 8),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                colour,
                thickness,
            )
            cv2.putText(
                img,
                f"col: {logic_point[2]}-{logic_point[3]}",
                (x0 + 3, y0 + 18),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                colour,
                thickness,
            )
        # A bare file name has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        self.save_img(output_path, img)

    @staticmethod
    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        """Return a copy of ``img`` with each ``[x1, y1, x2, y2]`` box drawn."""
        img_copy = img.copy()
        # tolist() yields plain Python ints for the OpenCV point arguments.
        for x1, y1, x2, y2 in boxes.astype(int).tolist():
            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
        """Return a copy of ``img`` with each 8-value polygon drawn closed."""
        img_copy = img.copy()
        for point in points.astype(int):
            cv2.polylines(img_copy, [point.reshape(4, 2)], True, (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def save_img(save_path: Union[str, Path], img: np.ndarray):
        """Write ``img`` to ``save_path`` with ``cv2.imwrite``."""
        cv2.imwrite(str(save_path), img)

    @staticmethod
    def save_html(save_path: Union[str, Path], html: str):
        """Write ``html`` to ``save_path`` as UTF-8 text."""
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(html)
|