|
@@ -12,6 +12,8 @@ from typing import Dict, List, Optional, Tuple, Union
|
|
|
import cv2
|
|
import cv2
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
|
|
+from tqdm import tqdm
|
|
|
|
|
+
|
|
|
from .matcher import TableMatch
|
|
from .matcher import TableMatch
|
|
|
from .table_structure import TableStructurer
|
|
from .table_structure import TableStructurer
|
|
|
from mineru.utils.enum_class import ModelPath
|
|
from mineru.utils.enum_class import ModelPath
|
|
@@ -72,40 +74,25 @@ class RapidTable:
|
|
|
raise ValueError(f"{self.model_type} is not supported.")
|
|
raise ValueError(f"{self.model_type} is not supported.")
|
|
|
self.table_matcher = TableMatch()
|
|
self.table_matcher = TableMatch()
|
|
|
|
|
|
|
|
- try:
|
|
|
|
|
- self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
|
|
|
|
|
- except ModuleNotFoundError:
|
|
|
|
|
- self.ocr_engine = None
|
|
|
|
|
-
|
|
|
|
|
- def __call__(
|
|
|
|
|
|
|
+ def predict(
|
|
|
self,
|
|
self,
|
|
|
img: np.ndarray,
|
|
img: np.ndarray,
|
|
|
ocr_result: List[Union[List[List[float]], str, str]] = None,
|
|
ocr_result: List[Union[List[List[float]], str, str]] = None,
|
|
|
) -> RapidTableOutput:
|
|
) -> RapidTableOutput:
|
|
|
- if self.ocr_engine is None and ocr_result is None:
|
|
|
|
|
- raise ValueError(
|
|
|
|
|
- "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ if ocr_result is None:
|
|
|
|
|
+ raise ValueError("OCR result is None")
|
|
|
|
|
|
|
|
s = time.perf_counter()
|
|
s = time.perf_counter()
|
|
|
h, w = img.shape[:2]
|
|
h, w = img.shape[:2]
|
|
|
|
|
|
|
|
- if ocr_result is None:
|
|
|
|
|
- ocr_result = self.ocr_engine(img)
|
|
|
|
|
- ocr_result = list(
|
|
|
|
|
- zip(
|
|
|
|
|
- ocr_result.boxes,
|
|
|
|
|
- ocr_result.txts,
|
|
|
|
|
- ocr_result.scores,
|
|
|
|
|
- )
|
|
|
|
|
- )
|
|
|
|
|
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
|
|
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
|
|
|
|
|
|
|
|
- pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
|
|
|
|
|
|
|
+ pred_structures, cell_bboxes, _ = self.table_structure.process(
|
|
|
|
|
+ copy.deepcopy(img)
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# 适配slanet-plus模型输出的box缩放还原
|
|
# 适配slanet-plus模型输出的box缩放还原
|
|
|
- if self.model_type == ModelType.SLANETPLUS.value:
|
|
|
|
|
- cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
|
|
|
|
|
|
|
+ cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
|
|
|
|
|
|
|
|
pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
|
|
pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
|
|
|
|
|
|
|
@@ -117,6 +104,50 @@ class RapidTable:
|
|
|
elapse = time.perf_counter() - s
|
|
elapse = time.perf_counter() - s
|
|
|
return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
|
|
return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
|
|
|
|
|
|
|
|
|
|
+ def batch_predict(
|
|
|
|
|
+ self,
|
|
|
|
|
+ images: List[np.ndarray],
|
|
|
|
|
+ ocr_results: List[List[Union[List[List[float]], str, str]]],
|
|
|
|
|
+ batch_size: int = 4,
|
|
|
|
|
+ ) -> List[RapidTableOutput]:
|
|
|
|
|
+ """批量处理图像"""
|
|
|
|
|
+ s = time.perf_counter()
|
|
|
|
|
+
|
|
|
|
|
+ batch_dt_boxes = []
|
|
|
|
|
+ batch_rec_res = []
|
|
|
|
|
+
|
|
|
|
|
+ for i, img in enumerate(images):
|
|
|
|
|
+ h, w = img.shape[:2]
|
|
|
|
|
+ dt_boxes, rec_res = self.get_boxes_recs(ocr_results[i], h, w)
|
|
|
|
|
+ batch_dt_boxes.append(dt_boxes)
|
|
|
|
|
+ batch_rec_res.append(rec_res)
|
|
|
|
|
+
|
|
|
|
|
+ # 批量表格结构识别
|
|
|
|
|
+ batch_results = self.table_structure.batch_process(images)
|
|
|
|
|
+
|
|
|
|
|
+ output_results = []
|
|
|
|
|
+ for i, (img, ocr_result, (pred_structures, cell_bboxes, _)) in enumerate(
|
|
|
|
|
+ zip(images, ocr_results, batch_results)
|
|
|
|
|
+ ):
|
|
|
|
|
+ # 适配slanet-plus模型输出的box缩放还原
|
|
|
|
|
+ cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
|
|
|
|
|
+ pred_html = self.table_matcher(
|
|
|
|
|
+ pred_structures, cell_bboxes, batch_dt_boxes[i], batch_rec_res[i]
|
|
|
|
|
+ )
|
|
|
|
|
+ # 过滤掉占位的bbox
|
|
|
|
|
+ mask = ~np.all(cell_bboxes == 0, axis=1)
|
|
|
|
|
+ cell_bboxes = cell_bboxes[mask]
|
|
|
|
|
+
|
|
|
|
|
+ logic_points = self.table_matcher.decode_logic_points(pred_structures)
|
|
|
|
|
+ result = RapidTableOutput(pred_html, cell_bboxes, logic_points, 0)
|
|
|
|
|
+ output_results.append(result)
|
|
|
|
|
+
|
|
|
|
|
+ total_elapse = time.perf_counter() - s
|
|
|
|
|
+ for result in output_results:
|
|
|
|
|
+ result.elapse = total_elapse / len(output_results)
|
|
|
|
|
+
|
|
|
|
|
+ return output_results
|
|
|
|
|
+
|
|
|
def get_boxes_recs(
|
|
def get_boxes_recs(
|
|
|
self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
|
|
self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
|
|
|
) -> Tuple[np.ndarray, Tuple[str, str]]:
|
|
) -> Tuple[np.ndarray, Tuple[str, str]]:
|
|
@@ -201,7 +232,7 @@ class RapidTableModel(object):
|
|
|
|
|
|
|
|
if ocr_result:
|
|
if ocr_result:
|
|
|
try:
|
|
try:
|
|
|
- table_results = self.table_model(np.asarray(image), ocr_result)
|
|
|
|
|
|
|
+ table_results = self.table_model.predict(np.asarray(image), ocr_result)
|
|
|
html_code = table_results.pred_html
|
|
html_code = table_results.pred_html
|
|
|
table_cell_bboxes = table_results.cell_bboxes
|
|
table_cell_bboxes = table_results.cell_bboxes
|
|
|
logic_points = table_results.logic_points
|
|
logic_points = table_results.logic_points
|
|
@@ -211,3 +242,37 @@ class RapidTableModel(object):
|
|
|
logger.exception(e)
|
|
logger.exception(e)
|
|
|
|
|
|
|
|
return None, None, None, None
|
|
return None, None, None, None
|
|
|
|
|
+
|
|
|
|
|
+ def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
|
|
|
|
|
+ """对传入的字典列表进行批量预测,无返回值"""
|
|
|
|
|
+ for index in tqdm(
|
|
|
|
|
+ range(0, len(table_res_list), batch_size),
|
|
|
|
|
+ desc=f"Table Batch Predict, total={len(table_res_list)}, batch_size={batch_size}",
|
|
|
|
|
+ ):
|
|
|
|
|
+ batch_imgs = [
|
|
|
|
|
+ cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
|
|
|
|
|
+ for i in range(index, min(index + batch_size, len(table_res_list)))
|
|
|
|
|
+ ]
|
|
|
|
|
+ batch_ocrs = [
|
|
|
|
|
+ table_res_list[i]["ocr_result"]
|
|
|
|
|
+ for i in range(index, min(index + batch_size, len(table_res_list)))
|
|
|
|
|
+ ]
|
|
|
|
|
+ results = self.table_model.batch_predict(
|
|
|
|
|
+ batch_imgs, batch_ocrs, batch_size=batch_size
|
|
|
|
|
+ )
|
|
|
|
|
+ for i, result in enumerate(results):
|
|
|
|
|
+ if result.pred_html:
|
|
|
|
|
+ # 检查html_code是否包含'<table>'和'</table>'
|
|
|
|
|
+ if '<table>' in result.pred_html and '</table>' in result.pred_html:
|
|
|
|
|
+ # 选用<table>到</table>的内容,放入table_res_dict['table_res']['html']
|
|
|
|
|
+ start_index = result.pred_html.find('<table>')
|
|
|
|
|
+ end_index = result.pred_html.rfind('</table>') + len('</table>')
|
|
|
|
|
+ table_res_list[index + i]['table_res']['html'] = result.pred_html[start_index:end_index]
|
|
|
|
|
+ else:
|
|
|
|
|
+ logger.warning(
|
|
|
|
|
+ 'table recognition processing fails, not found expected HTML table end'
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ logger.warning(
|
|
|
|
|
+ "table recognition processing fails, not get html return"
|
|
|
|
|
+ )
|