Browse Source

feat: implement RapidTable model for enhanced table structure prediction and batch processing

myhloli 2 months ago
parent
commit
d0e68a3018

+ 7 - 8
mineru/backend/pipeline/batch_analyze.py

@@ -137,18 +137,17 @@ class BatchAnalyze:
 
             # OCR det step, executed sequentially
             rec_img_lang_group = defaultdict(list)
+            det_ocr_engine = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.OCR,
+                det_db_box_thresh=0.5,
+                det_db_unclip_ratio=1.6,
+                enable_merge_det_boxes=False,
+            )
             for index, table_res_dict in enumerate(
                     tqdm(table_res_list_all_page, desc="Table-ocr det")
             ):
-                ocr_engine = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.OCR,
-                    det_db_box_thresh=0.5,
-                    det_db_unclip_ratio=1.6,
-                    # lang= table_res_dict["lang"],
-                    enable_merge_det_boxes=False,
-                )
                 bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
-                ocr_result = ocr_engine.ocr(bgr_image, rec=False)[0]
+                ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]
                 # Build the dict of images that need OCR recognition (cropped_img, dt_box, table_id) and group them by language
                 for dt_box in ocr_result:
                     rec_img_lang_group[_lang].append(

+ 1 - 0
mineru/backend/pipeline/model_init.py

@@ -10,6 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
+# from ...model.table.rec.RapidTable import RapidTableModel
 from ...model.table.rec.slanet_plus.main import RapidTableModel
 from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath

+ 154 - 0
mineru/model/table/rec/RapidTable.py

@@ -0,0 +1,154 @@
+import html
+import os
+import time
+from pathlib import Path
+from typing import List
+
+import cv2
+import numpy as np
+from loguru import logger
+from rapid_table import ModelType, RapidTable, RapidTableInput
+from rapid_table.utils import RapidTableOutput
+from tqdm import tqdm
+
+from mineru.model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+
+def escape_html(input_string):
+    """Escape HTML Entities."""
+    return html.escape(input_string)
+
+
+class CustomRapidTable(RapidTable):
+    def __init__(self, cfg: RapidTableInput):
+        import logging
+        # Suppress logging at INFO level and below to quiet rapid_table output
+        logging.disable(logging.INFO)
+        super().__init__(cfg)
+    def __call__(self, img_contents, ocr_results=None, batch_size=1):
+        if not isinstance(img_contents, list):
+            img_contents = [img_contents]
+
+        s = time.perf_counter()
+
+        results = RapidTableOutput()
+
+        total_nums = len(img_contents)
+
+        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
+            for start_i in range(0, total_nums, batch_size):
+                end_i = min(total_nums, start_i + batch_size)
+
+                imgs = self._load_imgs(img_contents[start_i:end_i])
+
+                pred_structures, cell_bboxes = self.table_structure(imgs)
+                logic_points = self.table_matcher.decode_logic_points(pred_structures)
+
+                dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
+                pred_htmls = self.table_matcher(
+                    pred_structures, cell_bboxes, dt_boxes, rec_res
+                )
+
+                results.pred_htmls.extend(pred_htmls)
+                # Update the progress bar
+                pbar.update(end_i - start_i)
+
+        elapse = time.perf_counter() - s
+        results.elapse = elapse / total_nums
+        return results
+
+
+class RapidTableModel():
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(
+            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
+            ModelPath.slanet_plus,
+        )
+        input_args = RapidTableInput(
+            model_type=ModelType.SLANETPLUS,
+            model_dir_or_path=slanet_plus_model_path,
+            use_ocr=False
+        )
+        self.table_model = CustomRapidTable(input_args)
+        self.ocr_engine = ocr_engine
+
+    def predict(self, image, ocr_result=None):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        # Run OCR here only if the caller did not supply an ocr_result
+
+        if not ocr_result:
+            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            # Split out bounding boxes, texts, and confidence scores
+            boxes = []
+            texts = []
+            scores = []
+            for item in raw_ocr_result:
+                if len(item) == 3:
+                    boxes.append(item[0])
+                    texts.append(escape_html(item[1]))
+                    scores.append(item[2])
+                elif len(item) == 2 and isinstance(item[1], tuple):
+                    boxes.append(item[0])
+                    texts.append(escape_html(item[1][0]))
+                    scores.append(item[1][1])
+            # Build ocr_result in the format expected by rapid_table
+            ocr_result = [(boxes, texts, scores)]
+
+        if ocr_result:
+            try:
+                table_results = self.table_model(img_contents=np.asarray(image), ocr_results=ocr_result)
+                html_code = table_results.pred_htmls
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+
+        return None, None, None, None
+
+    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
+        not_none_table_res_list = []
+        for table_res in table_res_list:
+            if table_res.get("ocr_result", None):
+                not_none_table_res_list.append(table_res)
+
+        if not_none_table_res_list:
+            img_contents = [table_res["table_img"] for table_res in not_none_table_res_list]
+            ocr_results = []
+            # ocr_results must be built in the format expected by rapid_table
+            for table_res in not_none_table_res_list:
+                raw_ocr_result = table_res["ocr_result"]
+                boxes = []
+                texts = []
+                scores = []
+                for item in raw_ocr_result:
+                    if len(item) == 3:
+                        boxes.append(item[0])
+                        texts.append(escape_html(item[1]))
+                        scores.append(item[2])
+                    elif len(item) == 2 and isinstance(item[1], tuple):
+                        boxes.append(item[0])
+                        texts.append(escape_html(item[1][0]))
+                        scores.append(item[1][1])
+                ocr_results.append((boxes, texts, scores))
+            table_results = self.table_model(img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size)
+
+            for i, result in enumerate(table_results.pred_htmls):
+                if result:
+                    not_none_table_res_list[i]['table_res']['html'] = result
+
+if __name__ == '__main__':
+    ocr_engine = PytorchPaddleOCR(
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        enable_merge_det_boxes=False,
+    )
+    table_model = RapidTableModel(ocr_engine)
+    img_path = Path(r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg")
+    image = cv2.imread(str(img_path))
+    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
+    print(html_code)
+
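A minimal usage sketch, not part of this commit, of how the new batch_predict method could be driven on its own. It assumes each entry holds a single-table crop; the dict keys mirror what the method reads ("table_img", "ocr_result") and the "table_res" sub-dict it writes the predicted HTML into:

    # Hypothetical driver for RapidTableModel.batch_predict (illustrative only).
    import cv2

    from mineru.model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
    from mineru.model.table.rec.RapidTable import RapidTableModel

    ocr_engine = PytorchPaddleOCR(
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        enable_merge_det_boxes=False,
    )
    table_model = RapidTableModel(ocr_engine)

    table_res_list = []
    for path in ["table_0.jpg", "table_1.jpg"]:  # assumed sample crops, one table per image
        img = cv2.imread(path)
        table_res_list.append({
            "table_img": img,
            "ocr_result": ocr_engine.ocr(img)[0],  # raw (box, (text, score)) tuples
            "table_res": {},  # batch_predict stores the predicted HTML under "html" here
        })

    table_model.batch_predict(table_res_list, batch_size=4)
    for table_res in table_res_list:
        print(table_res["table_res"].get("html"))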

+ 1 - 1
mineru/utils/model_utils.py

@@ -434,7 +434,7 @@ def clean_vram(device, vram_threshold=8):
         gc_start = time.time()
         clean_memory(device)
         gc_time = round(time.time() - gc_start, 2)
-        logger.info(f"gc time: {gc_time}")
+        # logger.info(f"gc time: {gc_time}")
 
 
 def get_vram(device):