Browse Source

feat: add batch predict for table ocr

Sidney233 2 tháng trước cách đây
mục cha
commit
da1431558a

+ 81 - 21
mineru/backend/pipeline/batch_analyze.py

@@ -10,8 +10,14 @@ from .model_init import AtomModelSingleton
 from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
-from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.ocr_utils import (
+    get_adjusted_mfdetrec_res,
+    get_ocr_result_list,
+    OcrConfidence,
+    get_rotate_crop_image,
+)
 from ...utils.pdf_image_tools import get_crop_img
+from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -193,9 +199,6 @@ class BatchAnalyze:
 
                         if dt_boxes is not None and len(dt_boxes) > 0:
                             # 直接应用原始OCR流程中的关键处理步骤
-                            from mineru.utils.ocr_utils import (
-                                merge_det_boxes, update_det_boxes, sorted_boxes
-                            )
 
                             # 1. 排序检测框
                             if len(dt_boxes) > 0:
@@ -280,9 +283,12 @@ class BatchAnalyze:
                 logger.warning(
                     f"Table classification failed: {e}, using default model"
                 )
-            # 遍历表格,获取 OCR 结果
-            for table_res_dict in tqdm(table_res_list_all_page, desc="Table OCR"):
-                _lang = table_res_dict['lang']
+            rec_img_lang_group = defaultdict(list)
+            # OCR det 过程,顺序执行
+            for index, table_res_dict in enumerate(
+                tqdm(table_res_list_all_page, desc="Table OCR det")
+            ):
+                _lang = table_res_dict["lang"]
                 ocr_engine = atom_model_manager.get_atom_model(
                     atom_model_name=AtomicModel.OCR,
                     det_db_box_thresh=0.5,
@@ -290,26 +296,80 @@ class BatchAnalyze:
                     lang=_lang,
                     enable_merge_det_boxes=False,
                 )
-                bgr_image = cv2.cvtColor(np.asarray(table_res_dict["table_img"]), cv2.COLOR_RGB2BGR)
-                ocr_result = ocr_engine.ocr(bgr_image)[0]
-                if ocr_result:
-                    ocr_result = [
-                        [item[0], html.escape(item[1][0]), item[1][1]]
-                        for item in ocr_result
-                        if len(item) == 2 and isinstance(item[1], tuple)
-                    ]
-                else:
-                    logger.warning(
-                        f"Table OCR result returns None"
+                bgr_image = cv2.cvtColor(
+                    np.asarray(table_res_dict["table_img"]), cv2.COLOR_RGB2BGR
+                )
+                ocr_result = ocr_engine.ocr(bgr_image, det=True, rec=False)[0]
+                # 构造需要 OCR 识别的图片字典,包括cropped_img, dt_box, table_id,并按照语言进行分组
+                for dt_box in ocr_result:
+                    rec_img_lang_group[_lang].append(
+                        {
+                            "cropped_img": get_rotate_crop_image(
+                                bgr_image, np.asarray(dt_box, dtype=np.float32)
+                            ),
+                            "dt_box": np.asarray(dt_box, dtype=np.float32),
+                            "table_id": index,
+                        }
                     )
-                table_res_dict["ocr_result"] = ocr_result
+            # OCR rec,按照语言分批处理
+            for _lang, rec_img_list in rec_img_lang_group.items():
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.OCR,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
+                    lang=_lang,
+                    enable_merge_det_boxes=False,
+                )
+                cropped_img_list = [item["cropped_img"] for item in rec_img_list]
+                ocr_res_list = ocr_engine.ocr(
+                    cropped_img_list, det=False, rec=True, tqdm_enable=True
+                )[0]
+                # 按照 table_id 将识别结果进行回填
+                for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
+                    if table_res_list_all_page[img_dict["table_id"]].get("ocr_result"):
+                        table_res_list_all_page[img_dict["table_id"]][
+                            "ocr_result"
+                        ].append(
+                            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
+                        )
+                    else:
+                        table_res_list_all_page[img_dict["table_id"]]["ocr_result"] = [
+                            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
+                        ]
 
-            # 先用无线表格模型
+            # 先对所有表格使用无线表格模型,然后对分类为有线的表格使用有线表格模型
             wireless_table_model = atom_model_manager.get_atom_model(
-                atom_model_name="wireless_table",
+                atom_model_name=AtomicModel.WirelessTable,
             )
 
             wireless_table_model.batch_predict(table_res_list_all_page)
+            for table_res_dict in tqdm(
+                table_res_list_all_page, desc="Wired Table Predict"
+            ):
+                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                    wired_table_model = atom_model_manager.get_atom_model(
+                        atom_model_name=AtomicModel.WiredTable,
+                        lang=table_res_dict["lang"],
+                    )
+                    if table_res_dict["table_res"].get("html") is None:
+                        logger.warning("Table Wireless Predict Error.")
+                    html_code = wired_table_model.predict(
+                        table_res_dict["table_img"],
+                        table_res_dict["table_res"]["cls_score"],
+                        table_res_dict["table_res"]["html"],
+                    )
+                    # 检查html_code是否包含'<table>'和'</table>'
+                    if "<table>" in html_code and "</table>" in html_code:
+                        # 选用<table>到</table>的内容,放入table_res_dict['table_res']['html']
+                        start_index = html_code.find("<table>")
+                        end_index = html_code.rfind("</table>") + len("</table>")
+                        table_res_dict["table_res"]["html"] = html_code[
+                            start_index:end_index
+                        ]
+                    else:
+                        logger.warning(
+                            "wired table recognition processing fails, not found expected HTML table end"
+                        )
 
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language

+ 1 - 1
mineru/model/table/rec/slanet_plus/main.py

@@ -247,7 +247,7 @@ class RapidTableModel(object):
         """对传入的字典列表进行批量预测,无返回值"""
         for index in tqdm(
             range(0, len(table_res_list), batch_size),
-            desc=f"Table Batch Predict, total={len(table_res_list)}, batch_size={batch_size}",
+            desc=f"Wireless Table Batch Predict, total={len(table_res_list)}, batch_size={batch_size}",
         ):
             batch_imgs = [
                 cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)

+ 3 - 30
mineru/model/table/rec/unet_table/main.py

@@ -248,7 +248,7 @@ class UnetTableModel:
         self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, input_img, table_cls_score):
+    def predict(self, input_img, table_cls_score, wireless_html_code):
         if isinstance(input_img, Image.Image):
             np_img = np.asarray(input_img)
         elif isinstance(input_img, np.ndarray):
@@ -269,29 +269,7 @@ class UnetTableModel:
             try:
                 wired_table_results = self.wired_table_model(np_img, ocr_result)
 
-                # viser = VisTable()
-                # save_html_path = f"outputs/output.html"
-                # save_drawed_path = f"outputs/output_table_vis.jpg"
-                # save_logic_path = (
-                #     f"outputs/output_table_vis_logic.jpg"
-                # )
-                # vis_imged = viser(
-                #     np_img, wired_table_results, save_html_path, save_drawed_path, save_logic_path
-                # )
-
                 wired_html_code = wired_table_results.pred_html
-                wired_table_cell_bboxes = wired_table_results.cell_bboxes
-                wired_logic_points = wired_table_results.logic_points
-                wired_elapse = wired_table_results.elapse
-
-                wireless_table_results = self.wireless_table_model(np_img, ocr_result)
-                wireless_html_code = wireless_table_results.pred_html
-                wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
-                wireless_logic_points = wireless_table_results.logic_points
-                wireless_elapse = wireless_table_results.elapse
-
-                # wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
-                # wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
 
                 wired_len = count_table_cells_physical(wired_html_code)
                 wireless_len = count_table_cells_physical(wireless_html_code)
@@ -308,15 +286,10 @@ class UnetTableModel:
                 ):
                     # logger.debug("fall back to wireless table model")
                     html_code = wireless_html_code
-                    table_cell_bboxes = wireless_table_cell_bboxes
-                    logic_points = wireless_logic_points
                 else:
                     html_code = wired_html_code
-                    table_cell_bboxes = wired_table_cell_bboxes
-                    logic_points = wired_logic_points
 
-                elapse = wired_elapse + wireless_elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
+                return html_code
             except Exception as e:
                 logger.exception(e)
-        return None, None, None, None
+        return None