Jelajahi Sumber

feat: add batch predict for slanet_plus

Sidney233 2 bulan lalu
induk
komit
193d5d8e44
1 file diubah dengan 26 penambahan dan 36 penghapusan
  1. 26 36
      mineru/backend/pipeline/batch_analyze.py

+ 26 - 36
mineru/backend/pipeline/batch_analyze.py

@@ -1,3 +1,5 @@
+import html
+
 import cv2
 from loguru import logger
 from tqdm import tqdm
@@ -278,48 +280,36 @@ class BatchAnalyze:
                 logger.warning(
                     f"Table classification failed: {e}, using default model"
                 )
-            # 遍历表格,根据分类识别结构
-            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+            # 遍历表格,获取 OCR 结果
+            for table_res_dict in tqdm(table_res_list_all_page, desc="Table OCR"):
                 _lang = table_res_dict['lang']
-                table_cls_score = 0.5
-                try:
-                    table_label, table_cls_score = table_res_dict['table_res']["cls_label"], table_res_dict['table_res']["cls_score"]
-                except Exception as e:
-                    logger.warning(
-                        f"Table classification failed: {e}, return error classification result: {table_res_dict}"
-                    )
-                    table_label = AtomicModel.WirelessTable
-                if table_label not in [
-                    AtomicModel.WirelessTable,
-                    AtomicModel.WiredTable,
-                ]:
-                    raise ValueError(
-                        "Table classification failed, please check the model"
-                    )
-
-                # 根据表格分类结果选择有线表格识别模型和无线表格识别模型
-                table_model = atom_model_manager.get_atom_model(
-                    atom_model_name=table_label,
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.OCR,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
                     lang=_lang,
+                    enable_merge_det_boxes=False,
                 )
-
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict["table_img"], table_cls_score)
-                # 判断是否返回正常
-                if html_code:
-                    # 检查html_code是否包含'<table>'和'</table>'
-                    if '<table>' in html_code and '</table>' in html_code:
-                        # 选用<table>到</table>的内容,放入table_res_dict['table_res']['html']
-                        start_index = html_code.find('<table>')
-                        end_index = html_code.rfind('</table>') + len('</table>')
-                        table_res_dict['table_res']['html'] = html_code[start_index:end_index]
-                    else:
-                        logger.warning(
-                            'table recognition processing fails, not found expected HTML table end'
-                        )
+                bgr_image = cv2.cvtColor(np.asarray(table_res_dict["table_img"]), cv2.COLOR_RGB2BGR)
+                ocr_result = ocr_engine.ocr(bgr_image)[0]
+                if ocr_result:
+                    ocr_result = [
+                        [item[0], html.escape(item[1][0]), item[1][1]]
+                        for item in ocr_result
+                        if len(item) == 2 and isinstance(item[1], tuple)
+                    ]
                 else:
                     logger.warning(
-                        'table recognition processing fails, not get html return'
+                        f"Table OCR result returns None"
                     )
+                table_res_dict["ocr_result"] = ocr_result
+
+            # 先用无线表格模型
+            wireless_table_model = atom_model_manager.get_atom_model(
+                atom_model_name="wireless_table",
+            )
+
+            wireless_table_model.batch_predict(table_res_list_all_page)
 
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language