|
|
@@ -8,6 +8,7 @@ from typing import List, Optional, Union, Dict, Any
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
from loguru import logger
|
|
|
+from rapid_table import RapidTableInput, RapidTable
|
|
|
|
|
|
from mineru.utils.enum_class import ModelPath
|
|
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
@@ -47,6 +48,7 @@ class UnetTableRecognition:
|
|
|
self,
|
|
|
img: InputType,
|
|
|
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
|
|
|
+ ocr_engine = None,
|
|
|
**kwargs,
|
|
|
) -> UnetTableOutput:
|
|
|
s = time.perf_counter()
|
|
|
@@ -84,7 +86,7 @@ class UnetTableRecognition:
|
|
|
)
|
|
|
cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
|
|
|
# 如果有识别框没有ocr结果,直接进行rec补充
|
|
|
- cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
|
|
|
+ cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map, ocr_engine)
|
|
|
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
|
|
|
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
|
|
|
# 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
|
|
|
@@ -150,14 +152,45 @@ class UnetTableRecognition:
|
|
|
img: np.ndarray,
|
|
|
sorted_polygons: np.ndarray,
|
|
|
cell_box_map: Dict[int, List[str]],
|
|
|
+ ocr_engine
|
|
|
) -> Dict[int, List[Any]]:
|
|
|
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
|
|
|
+ img_crop_info_list = []
|
|
|
+ img_crop_list = []
|
|
|
for i in range(sorted_polygons.shape[0]):
|
|
|
if cell_box_map.get(i):
|
|
|
continue
|
|
|
box = sorted_polygons[i]
|
|
|
- cell_box_map[i] = [[box, "", 1]]
|
|
|
+ if ocr_engine is None:
|
|
|
+ logger.warning(f"No OCR engine provided for box {i}: {box}")
|
|
|
+ continue
|
|
|
+ # 从img中截取对应的区域
|
|
|
+ x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
|
|
|
+ if x1 >= x2 or y1 >= y2:
|
|
|
+ logger.warning(f"Invalid box coordinates: {box}")
|
|
|
+ continue
|
|
|
+ img_crop = img[int(y1):int(y2), int(x1):int(x2)]
|
|
|
+ img_crop_list.append(img_crop)
|
|
|
+ img_crop_info_list.append([i, box])
|
|
|
continue
|
|
|
+
|
|
|
+ if len(img_crop_list) > 0:
|
|
|
+ # 进行ocr识别
|
|
|
+ ocr_res_list = ocr_engine.ocr(img_crop_list, det=False)[0]
|
|
|
+ assert len(ocr_res_list) == len(img_crop_list)
|
|
|
+ for j, ocr_res in enumerate(ocr_res_list):
|
|
|
+ img_crop_info_list[j].append(ocr_res)
|
|
|
+
|
|
|
+
|
|
|
+ for i, box, ocr_res in img_crop_info_list:
|
|
|
+ # 处理ocr结果
|
|
|
+ ocr_text, ocr_score = ocr_res
|
|
|
+ # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
|
|
|
+ if ocr_score < 0.9:
|
|
|
+ # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
|
|
|
+ continue
|
|
|
+ cell_box_map[i] = [[box, ocr_text, ocr_score]]
|
|
|
+
|
|
|
return cell_box_map
|
|
|
|
|
|
|
|
|
@@ -169,11 +202,14 @@ def escape_html(input_string):
|
|
|
class UnetTableModel:
|
|
|
def __init__(self, ocr_engine):
|
|
|
model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
|
|
|
- input_args = UnetTableInput(model_path=model_path)
|
|
|
- self.table_model = UnetTableRecognition(input_args)
|
|
|
+ wired_input_args = UnetTableInput(model_path=model_path)
|
|
|
+ self.wired_table_model = UnetTableRecognition(wired_input_args)
|
|
|
+ slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
|
|
|
+ wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
|
|
|
+ self.wireless_table_model = RapidTable(wireless_input_args)
|
|
|
self.ocr_engine = ocr_engine
|
|
|
|
|
|
- def predict(self, img):
|
|
|
+ def predict(self, img, table_cls_score):
|
|
|
bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
|
|
ocr_result = self.ocr_engine.ocr(bgr_img)[0]
|
|
|
|
|
|
@@ -187,11 +223,40 @@ class UnetTableModel:
|
|
|
ocr_result = None
|
|
|
if ocr_result:
|
|
|
try:
|
|
|
- table_results = self.table_model(np.asarray(img), ocr_result)
|
|
|
- html_code = table_results.pred_html
|
|
|
- table_cell_bboxes = table_results.cell_bboxes
|
|
|
- logic_points = table_results.logic_points
|
|
|
- elapse = table_results.elapse
|
|
|
+ wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
|
|
|
+ wired_html_code = wired_table_results.pred_html
|
|
|
+ wired_table_cell_bboxes = wired_table_results.cell_bboxes
|
|
|
+ wired_logic_points = wired_table_results.logic_points
|
|
|
+ wired_elapse = wired_table_results.elapse
|
|
|
+
|
|
|
+ wireless_table_results = self.wireless_table_model(np.asarray(img), ocr_result)
|
|
|
+ wireless_html_code = wireless_table_results.pred_html
|
|
|
+ wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
|
|
|
+ wireless_logic_points = wireless_table_results.logic_points
|
|
|
+ wireless_elapse = wireless_table_results.elapse
|
|
|
+
|
|
|
+ wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
|
|
|
+ wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
|
|
|
+ # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
|
|
|
+ # 计算两种模型检测的单元格数量差异
|
|
|
+ gap_of_len = wireless_len - wired_len
|
|
|
+ # 判断是否使用无线表格模型的结果
|
|
|
+ if (
|
|
|
+ wired_len <= round(wireless_len * 0.5) # 有线模型检测到的单元格数太少(低于无线模型的50%)
|
|
|
+ or (wireless_len < wired_len < (2 * wireless_len) and table_cls_score <= 0.949) # 有线模型检测到的单元格数反而更多
|
|
|
+ or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75)) # 两者相差不大但有线模型结果较少
|
|
|
+ or (gap_of_len == 0 and wired_len <= 4) # 单元格数量完全相等且总量小于等于4
|
|
|
+ ):
|
|
|
+ # logger.debug("fall back to wireless table model")
|
|
|
+ html_code = wireless_html_code
|
|
|
+ table_cell_bboxes = wireless_table_cell_bboxes
|
|
|
+ logic_points = wireless_logic_points
|
|
|
+ else:
|
|
|
+ html_code = wired_html_code
|
|
|
+ table_cell_bboxes = wired_table_cell_bboxes
|
|
|
+ logic_points = wired_logic_points
|
|
|
+
|
|
|
+ elapse = wired_elapse + wireless_elapse
|
|
|
return html_code, table_cell_bboxes, logic_points, elapse
|
|
|
except Exception as e:
|
|
|
logger.exception(e)
|