|
|
@@ -1,3 +1,5 @@
|
|
|
+import html
|
|
|
+
|
|
|
import cv2
|
|
|
from loguru import logger
|
|
|
from tqdm import tqdm
|
|
|
@@ -278,48 +280,36 @@ class BatchAnalyze:
|
|
|
logger.warning(
|
|
|
f"Table classification failed: {e}, using default model"
|
|
|
)
|
|
|
- # 遍历表格,根据分类识别结构
|
|
|
- for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
|
|
|
+ # Iterate over the tables and collect the OCR results
|
|
|
+ for table_res_dict in tqdm(table_res_list_all_page, desc="Table OCR"):
|
|
|
_lang = table_res_dict['lang']
|
|
|
- table_cls_score = 0.5
|
|
|
- try:
|
|
|
- table_label, table_cls_score = table_res_dict['table_res']["cls_label"], table_res_dict['table_res']["cls_score"]
|
|
|
- except Exception as e:
|
|
|
- logger.warning(
|
|
|
- f"Table classification failed: {e}, return error classification result: {table_res_dict}"
|
|
|
- )
|
|
|
- table_label = AtomicModel.WirelessTable
|
|
|
- if table_label not in [
|
|
|
- AtomicModel.WirelessTable,
|
|
|
- AtomicModel.WiredTable,
|
|
|
- ]:
|
|
|
- raise ValueError(
|
|
|
- "Table classification failed, please check the model"
|
|
|
- )
|
|
|
-
|
|
|
- # 根据表格分类结果选择有线表格识别模型和无线表格识别模型
|
|
|
- table_model = atom_model_manager.get_atom_model(
|
|
|
- atom_model_name=table_label,
|
|
|
+ ocr_engine = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
+ det_db_box_thresh=0.5,
|
|
|
+ det_db_unclip_ratio=1.6,
|
|
|
lang=_lang,
|
|
|
+ enable_merge_det_boxes=False,
|
|
|
)
|
|
|
-
|
|
|
- html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict["table_img"], table_cls_score)
|
|
|
- # 判断是否返回正常
|
|
|
- if html_code:
|
|
|
- # 检查html_code是否包含'<table>'和'</table>'
|
|
|
- if '<table>' in html_code and '</table>' in html_code:
|
|
|
- # 选用<table>到</table>的内容,放入table_res_dict['table_res']['html']
|
|
|
- start_index = html_code.find('<table>')
|
|
|
- end_index = html_code.rfind('</table>') + len('</table>')
|
|
|
- table_res_dict['table_res']['html'] = html_code[start_index:end_index]
|
|
|
- else:
|
|
|
- logger.warning(
|
|
|
- 'table recognition processing fails, not found expected HTML table end'
|
|
|
- )
|
|
|
+ bgr_image = cv2.cvtColor(np.asarray(table_res_dict["table_img"]), cv2.COLOR_RGB2BGR)
|
|
|
+ ocr_result = ocr_engine.ocr(bgr_image)[0]
|
|
|
+ if ocr_result:
|
|
|
+ ocr_result = [
|
|
|
+ [item[0], html.escape(item[1][0]), item[1][1]]
|
|
|
+ for item in ocr_result
|
|
|
+ if len(item) == 2 and isinstance(item[1], tuple)
|
|
|
+ ]
|
|
|
else:
|
|
|
logger.warning(
|
|
|
- 'table recognition processing fails, not get html return'
|
|
|
+ f"Table OCR result returns None"
|
|
|
)
|
|
|
+ table_res_dict["ocr_result"] = ocr_result
|
|
|
+
|
|
|
+ # Run the wireless table model first
|
|
|
+ wireless_table_model = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name="wireless_table",
|
|
|
+ )
|
|
|
+
|
|
|
+ wireless_table_model.batch_predict(table_res_list_all_page)
|
|
|
|
|
|
# Create dictionaries to store items by language
|
|
|
need_ocr_lists_by_lang = {} # Dict of lists for each language
|