|
|
@@ -10,8 +10,14 @@ from .model_init import AtomModelSingleton
|
|
|
from .model_list import AtomicModel
|
|
|
from ...utils.config_reader import get_formula_enable, get_table_enable
|
|
|
from ...utils.model_utils import crop_img, get_res_list_from_layout_res
|
|
|
-from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
|
|
|
+from ...utils.ocr_utils import (
|
|
|
+ get_adjusted_mfdetrec_res,
|
|
|
+ get_ocr_result_list,
|
|
|
+ OcrConfidence,
|
|
|
+ get_rotate_crop_image,
|
|
|
+)
|
|
|
from ...utils.pdf_image_tools import get_crop_img
|
|
|
+from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
|
|
|
|
|
|
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
|
|
|
MFD_BASE_BATCH_SIZE = 1
|
|
|
@@ -193,9 +199,6 @@ class BatchAnalyze:
|
|
|
|
|
|
if dt_boxes is not None and len(dt_boxes) > 0:
|
|
|
# 直接应用原始OCR流程中的关键处理步骤
|
|
|
- from mineru.utils.ocr_utils import (
|
|
|
- merge_det_boxes, update_det_boxes, sorted_boxes
|
|
|
- )
|
|
|
|
|
|
# 1. 排序检测框
|
|
|
if len(dt_boxes) > 0:
|
|
|
@@ -280,9 +283,12 @@ class BatchAnalyze:
|
|
|
logger.warning(
|
|
|
f"Table classification failed: {e}, using default model"
|
|
|
)
|
|
|
- # 遍历表格,获取 OCR 结果
|
|
|
- for table_res_dict in tqdm(table_res_list_all_page, desc="Table OCR"):
|
|
|
- _lang = table_res_dict['lang']
|
|
|
+ rec_img_lang_group = defaultdict(list)
|
|
|
+ # OCR det 过程,顺序执行
|
|
|
+ for index, table_res_dict in enumerate(
|
|
|
+ tqdm(table_res_list_all_page, desc="Table OCR det")
|
|
|
+ ):
|
|
|
+ _lang = table_res_dict["lang"]
|
|
|
ocr_engine = atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.OCR,
|
|
|
det_db_box_thresh=0.5,
|
|
|
@@ -290,26 +296,80 @@ class BatchAnalyze:
|
|
|
lang=_lang,
|
|
|
enable_merge_det_boxes=False,
|
|
|
)
|
|
|
- bgr_image = cv2.cvtColor(np.asarray(table_res_dict["table_img"]), cv2.COLOR_RGB2BGR)
|
|
|
- ocr_result = ocr_engine.ocr(bgr_image)[0]
|
|
|
- if ocr_result:
|
|
|
- ocr_result = [
|
|
|
- [item[0], html.escape(item[1][0]), item[1][1]]
|
|
|
- for item in ocr_result
|
|
|
- if len(item) == 2 and isinstance(item[1], tuple)
|
|
|
- ]
|
|
|
- else:
|
|
|
- logger.warning(
|
|
|
- f"Table OCR result returns None"
|
|
|
+ bgr_image = cv2.cvtColor(
|
|
|
+ np.asarray(table_res_dict["table_img"]), cv2.COLOR_RGB2BGR
|
|
|
+ )
|
|
|
+ ocr_result = ocr_engine.ocr(bgr_image, det=True, rec=False)[0]
|
|
|
+ # 构造需要 OCR 识别的图片字典,包括cropped_img, dt_box, table_id,并按照语言进行分组
|
|
|
+ for dt_box in ocr_result:
|
|
|
+ rec_img_lang_group[_lang].append(
|
|
|
+ {
|
|
|
+ "cropped_img": get_rotate_crop_image(
|
|
|
+ bgr_image, np.asarray(dt_box, dtype=np.float32)
|
|
|
+ ),
|
|
|
+ "dt_box": np.asarray(dt_box, dtype=np.float32),
|
|
|
+ "table_id": index,
|
|
|
+ }
|
|
|
)
|
|
|
- table_res_dict["ocr_result"] = ocr_result
|
|
|
+ # OCR rec,按照语言分批处理
|
|
|
+ for _lang, rec_img_list in rec_img_lang_group.items():
|
|
|
+ ocr_engine = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
+ det_db_box_thresh=0.5,
|
|
|
+ det_db_unclip_ratio=1.6,
|
|
|
+ lang=_lang,
|
|
|
+ enable_merge_det_boxes=False,
|
|
|
+ )
|
|
|
+ cropped_img_list = [item["cropped_img"] for item in rec_img_list]
|
|
|
+ ocr_res_list = ocr_engine.ocr(
|
|
|
+ cropped_img_list, det=False, rec=True, tqdm_enable=True
|
|
|
+ )[0]
|
|
|
+ # 按照 table_id 将识别结果进行回填
|
|
|
+ for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
|
|
|
+ if table_res_list_all_page[img_dict["table_id"]].get("ocr_result"):
|
|
|
+ table_res_list_all_page[img_dict["table_id"]][
|
|
|
+ "ocr_result"
|
|
|
+ ].append(
|
|
|
+ [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ table_res_list_all_page[img_dict["table_id"]]["ocr_result"] = [
|
|
|
+ [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
|
|
|
+ ]
|
|
|
|
|
|
- # 先用无线表格模型
|
|
|
+ # 先对所有表格使用无线表格模型,然后对分类为有线的表格使用有线表格模型
|
|
|
wireless_table_model = atom_model_manager.get_atom_model(
|
|
|
- atom_model_name="wireless_table",
|
|
|
+ atom_model_name=AtomicModel.WirelessTable,
|
|
|
)
|
|
|
|
|
|
wireless_table_model.batch_predict(table_res_list_all_page)
|
|
|
+ for table_res_dict in tqdm(
|
|
|
+ table_res_list_all_page, desc="Wired Table Predict"
|
|
|
+ ):
|
|
|
+ if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
|
|
|
+ wired_table_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.WiredTable,
|
|
|
+ lang=table_res_dict["lang"],
|
|
|
+ )
|
|
|
+ if table_res_dict["table_res"].get("html") is None:
|
|
|
+ logger.warning("Table Wireless Predict Error.")
|
|
|
+ html_code = wired_table_model.predict(
|
|
|
+ table_res_dict["table_img"],
|
|
|
+ table_res_dict["table_res"]["cls_score"],
|
|
|
+ table_res_dict["table_res"]["html"],
|
|
|
+ )
|
|
|
+ # 检查html_code是否包含'<table>'和'</table>'
|
|
|
+ if "<table>" in html_code and "</table>" in html_code:
|
|
|
+ # 选用<table>到</table>的内容,放入table_res_dict['table_res']['html']
|
|
|
+ start_index = html_code.find("<table>")
|
|
|
+ end_index = html_code.rfind("</table>") + len("</table>")
|
|
|
+ table_res_dict["table_res"]["html"] = html_code[
|
|
|
+ start_index:end_index
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ logger.warning(
|
|
|
+ "wired table recognition processing fails, not found expected HTML table end"
|
|
|
+ )
|
|
|
|
|
|
# Create dictionaries to store items by language
|
|
|
need_ocr_lists_by_lang = {} # Dict of lists for each language
|