Bläddra i källkod

fix: add tdqm for wired table, remove import, remove img ori cls lang group

Sidney233 2 månader sedan
förälder
incheckning
832d28e512

+ 11 - 13
mineru/backend/pipeline/batch_analyze.py

@@ -10,10 +10,8 @@ from .model_init import AtomModelSingleton
 from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
-from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
-from ...utils.pdf_image_tools import get_crop_img
 from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
-from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
 from ...utils.pdf_image_tools import get_crop_np_img
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
@@ -195,9 +193,6 @@ class BatchAnalyze:
 
                         if dt_boxes is not None and len(dt_boxes) > 0:
                             # 直接应用原始OCR流程中的关键处理步骤
-                            from mineru.utils.ocr_utils import (
-                                merge_det_boxes, update_det_boxes, sorted_boxes
-                            )
 
                             # 1. 排序检测框
                             if len(dt_boxes) > 0:
@@ -267,7 +262,7 @@ class BatchAnalyze:
                 atom_model_name=AtomicModel.ImgOrientationCls,
             )
             try:
-                img_orientation_cls_model.batch_predict(table_res_list_all_page, atom_model_manager, AtomicModel.OCR, self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
+                img_orientation_cls_model.batch_predict(table_res_list_all_page, self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
             except Exception as e:
                 logger.warning(
                     f"Image orientation classification failed: {e}, using original image"
@@ -338,22 +333,25 @@ class BatchAnalyze:
             wireless_table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.WirelessTable,
             )
-
             wireless_table_model.batch_predict(table_res_list_all_page)
+
+            # 单独拿出有线表格进行预测
+            wired_table_res_list = []
+            for table_res_dict in table_res_list_all_page:
+                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                    wired_table_res_list.append(table_res_dict)
             for table_res_dict in tqdm(
-                table_res_list_all_page, desc="Wired Table Predict"
+                wired_table_res_list, desc="Wired Table Predict"
             ):
                 if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
                     wired_table_model = atom_model_manager.get_atom_model(
                         atom_model_name=AtomicModel.WiredTable,
                         lang=table_res_dict["lang"],
                     )
-                    if table_res_dict["table_res"].get("html") is None:
-                        logger.warning("Table Wireless Predict Error.")
                     html_code = wired_table_model.predict(
                         table_res_dict["table_img"],
-                        table_res_dict["table_res"]["cls_score"],
-                        table_res_dict["table_res"]["html"],
+                        table_res_dict["ocr_result"],
+                        table_res_dict["table_res"].get("html", None)
                     )
                     # 检查html_code是否包含'<table>'和'</table>'
                     if "<table>" in html_code and "</table>" in html_code:

+ 2 - 2
mineru/backend/pipeline/model_init.py

@@ -33,7 +33,7 @@ def table_cls_model_init():
     return PaddleTableClsModel()
 
 
-def wired_table_model_init(lang="ch"):
+def wired_table_model_init(lang=None):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
         atom_model_name=AtomicModel.OCR,
@@ -46,7 +46,7 @@ def wired_table_model_init(lang="ch"):
     return table_model
 
 
-def wireless_table_model_init(lang="ch"):
+def wireless_table_model_init(lang=None):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
         atom_model_name=AtomicModel.OCR,

+ 64 - 79
mineru/model/ori_cls/paddle_ori_cls.py

@@ -174,13 +174,14 @@ class PaddleOrientationClsModel:
         return x
 
     def batch_predict(
-        self, imgs: List[Dict], atom_model_manager, ocr_model_name: str, batch_size: int
+        self, imgs: List[Dict], batch_size: int
     ) -> None:
         """
         批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
         """
-        # 按语言分组,跳过长宽比小于1.2的图片
-        lang_groups = defaultdict(list)
+        RESOLUTION_GROUP_STRIDE = 64
+        # 跳过长宽比小于1.2的图片
+        resolution_groups = defaultdict(list)
         for img in imgs:
             # RGB图像转换BGR
             table_img: np.ndarray = cv2.cvtColor(img["table_img"], cv2.COLOR_RGB2BGR)
@@ -189,89 +190,73 @@ class PaddleOrientationClsModel:
             img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
             img_is_portrait = img_aspect_ratio > 1.2
             if img_is_portrait:
-                lang = img["lang"]
-                lang_groups[lang].append(img)
-
-        # 对每种语言按分辨率分组并批处理
-        for lang, lang_group_img_list in lang_groups.items():
-            if not lang_group_img_list:
-                continue
-
-            # 获取OCR模型
-            ocr_model = atom_model_manager.get_atom_model(
-                atom_model_name=ocr_model_name, det_db_box_thresh=0.3, lang=lang
-            )
-
-            # 按分辨率分组并同时完成padding
-            resolution_groups = defaultdict(list)
-            for img in lang_group_img_list:
                 h, w = img["table_img_bgr"].shape[:2]
-                normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
-                normalized_w = ((w + 32) // 32) * 32
+                normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # 向上取整到RESOLUTION_GROUP_STRIDE的倍数
+                normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
                 group_key = (normalized_h, normalized_w)
                 resolution_groups[group_key].append(img)
 
             # 对每个分辨率组进行批处理
-            for group_key, group_imgs in tqdm(
-                resolution_groups.items(), desc=f"ORI CLS OCR-det {lang}"
-            ):
-
-                # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
-                max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
-                max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
-                target_h = ((max_h + 32 - 1) // 32) * 32
-                target_w = ((max_w + 32 - 1) // 32) * 32
-
-                # 对所有图像进行padding到统一尺寸
-                batch_images = []
-                for img in group_imgs:
-                    table_img_ndarray = img["table_img_bgr"]
-                    h, w = table_img_ndarray.shape[:2]
-                    # 创建目标尺寸的白色背景
-                    padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
-                    # 将原图像粘贴到左上角
-                    padded_img[:h, :w] = table_img_ndarray
-                    batch_images.append(padded_img)
-
-                # 批处理检测
-                det_batch_size = min(len(batch_images), batch_size)  # 增加批处理大小
-                batch_results = ocr_model.text_detector.batch_predict(
-                    batch_images, det_batch_size
-                )
+        for group_key, group_imgs in tqdm(
+            resolution_groups.items(), desc=f"ORI CLS OCR-det"
+        ):
+
+            # 计算目标尺寸(组内最大尺寸,向上取整到RESOLUTION_GROUP_STRIDE的倍数)
+            max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
+            max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
+            target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+            target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+
+            # 对所有图像进行padding到统一尺寸
+            batch_images = []
+            for img in group_imgs:
+                table_img_ndarray = img["table_img_bgr"]
+                h, w = table_img_ndarray.shape[:2]
+                # 创建目标尺寸的白色背景
+                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                # 将原图像粘贴到左上角
+                padded_img[:h, :w] = table_img_ndarray
+                batch_images.append(padded_img)
+
+            # 批处理检测
+            det_batch_size = min(len(batch_images), batch_size)  # 增加批处理大小
+            batch_results = self.ocr_engine.text_detector.batch_predict(
+                batch_images, det_batch_size
+            )
 
-                rotated_imgs = []
-                # 根据批处理结果检测图像是否旋转,旋转的图像放入列表中,继续进行旋转角度的预测
-                for index, (img_info, (dt_boxes, elapse)) in enumerate(
-                    zip(group_imgs, batch_results)
-                ):
-                    vertical_count = 0
-                    for box_ocr_res in dt_boxes:
-                        p1, p2, p3, p4 = box_ocr_res
+            rotated_imgs = []
+            # 根据批处理结果检测图像是否旋转,旋转的图像放入列表中,继续进行旋转角度的预测
+            for index, (img_info, (dt_boxes, elapse)) in enumerate(
+                zip(group_imgs, batch_results)
+            ):
+                vertical_count = 0
+                for box_ocr_res in dt_boxes:
+                    p1, p2, p3, p4 = box_ocr_res
 
-                        # Calculate width and height
-                        width = p3[0] - p1[0]
-                        height = p3[1] - p1[1]
+                    # Calculate width and height
+                    width = p3[0] - p1[0]
+                    height = p3[1] - p1[1]
 
-                        aspect_ratio = width / height if height > 0 else 1.0
+                    aspect_ratio = width / height if height > 0 else 1.0
 
-                        # Count vertical text boxes
-                        if aspect_ratio < 0.8:  # Taller than wide - vertical text
-                            vertical_count += 1
+                    # Count vertical text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1
 
-                    if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
-                        rotated_imgs.append(img_info)
-                if len(rotated_imgs) > 0:
-                    x = self.batch_preprocess(rotated_imgs)
-                    results = self.sess.run(None, {"x": x})
-                    for img_info, res in zip(rotated_imgs, results[0]):
-                        label = self.labels[np.argmax(res)]
-                        if label == "270":
-                            img_info["table_img"] = cv2.rotate(
-                                np.asarray(img_info["table_img"]),
-                                cv2.ROTATE_90_CLOCKWISE,
-                            )
-                        elif label == "90":
-                            img_info["table_img"] = cv2.rotate(
-                                np.asarray(img_info["table_img"]),
-                                cv2.ROTATE_90_COUNTERCLOCKWISE,
-                            )
+                if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
+                    rotated_imgs.append(img_info)
+            if len(rotated_imgs) > 0:
+                x = self.batch_preprocess(rotated_imgs)
+                results = self.sess.run(None, {"x": x})
+                for img_info, res in zip(rotated_imgs, results[0]):
+                    label = self.labels[np.argmax(res)]
+                    if label == "270":
+                        img_info["table_img"] = cv2.rotate(
+                            np.asarray(img_info["table_img"]),
+                            cv2.ROTATE_90_CLOCKWISE,
+                        )
+                    elif label == "90":
+                        img_info["table_img"] = cv2.rotate(
+                            np.asarray(img_info["table_img"]),
+                            cv2.ROTATE_90_COUNTERCLOCKWISE,
+                        )

+ 3 - 2
mineru/model/table/cls/paddle_table_cls.py

@@ -145,9 +145,10 @@ class PaddleTableClsModel:
             for img_res in result[0]:
                 idx = np.argmax(img_res)
                 conf = float(np.max(img_res))
-                if idx == 0 and conf < 0.9:
+                # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
+                if idx == 0 and conf < 0.8:
                     idx = 1
                 label_res.append((self.labels[idx],conf))
         for img_info, (label, conf) in zip(img_info_list, label_res):
             img_info['table_res']["cls_label"] = label
-            img_info['table_res']["cls_score"] = conf
+            img_info['table_res']["cls_score"] = conf

+ 2 - 2
mineru/model/table/rec/slanet_plus/main.py

@@ -270,9 +270,9 @@ class RapidTableModel(object):
                         table_res_list[index + i]['table_res']['html'] = result.pred_html[start_index:end_index]
                     else:
                         logger.warning(
-                            'table recognition processing fails, not found expected HTML table end'
+                            'wireless table recognition processing fails, not found expected HTML table end'
                         )
                 else:
                     logger.warning(
-                        "table recognition processing fails, not get html return"
+                        "wireless table recognition processing fails, not get html return"
                     )

+ 6 - 9
mineru/model/table/rec/unet_table/main.py

@@ -17,7 +17,7 @@ from .table_structure_unet import TSRUnet
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 from .table_recover import TableRecover
-from .utils import InputType, LoadImage, VisTable
+from .utils import InputType, LoadImage
 from .utils_table_recover import (
     match_ocr_cell,
     plot_html_table,
@@ -243,12 +243,9 @@ class UnetTableModel:
         model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
         wired_input_args = WiredTableInput(model_path=model_path)
         self.wired_table_model = WiredTableRecognition(wired_input_args, ocr_engine)
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, input_img, table_cls_score, wireless_html_code):
+    def predict(self, input_img, ocr_result, wireless_html_code):
         if isinstance(input_img, Image.Image):
             np_img = np.asarray(input_img)
         elif isinstance(input_img, np.ndarray):
@@ -256,15 +253,15 @@ class UnetTableModel:
         else:
             raise ValueError("Input must be a pillow object or a numpy array.")
         bgr_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
-        ocr_result = self.ocr_engine.ocr(bgr_img)[0]
-        if ocr_result:
+
+        if ocr_result is None:
+            ocr_result = self.ocr_engine.ocr(bgr_img)[0]
             ocr_result = [
                 [item[0], escape_html(item[1][0]), item[1][1]]
                 for item in ocr_result
                 if len(item) == 2 and isinstance(item[1], tuple)
             ]
-        else:
-            ocr_result = None
+
         if ocr_result:
             try:
                 wired_table_results = self.wired_table_model(np_img, ocr_result)