Răsfoiți Sursa

fix: merge dev

Sidney233 2 luni în urmă
părinte
comite
320cd60c81

+ 17 - 19
mineru/backend/pipeline/batch_analyze.py

@@ -52,22 +52,22 @@ class BatchAnalyze:
             np_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
 
-        if self.formula_enable:
-            # 公式检测
-            images_mfd_res = self.model.mfd_model.batch_predict(
-                np_images, MFD_BASE_BATCH_SIZE
-            )
-
-            # 公式识别
-            images_formula_list = self.model.mfr_model.batch_predict(
-                images_mfd_res,
-                np_images,
-                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
-            )
-            mfr_count = 0
-            for image_index in range(len(np_images)):
-                images_layout_res[image_index] += images_formula_list[image_index]
-                mfr_count += len(images_formula_list[image_index])
+        # if self.formula_enable:
+        #     # 公式检测
+        #     images_mfd_res = self.model.mfd_model.batch_predict(
+        #         np_images, MFD_BASE_BATCH_SIZE
+        #     )
+        #
+        #     # 公式识别
+        #     images_formula_list = self.model.mfr_model.batch_predict(
+        #         images_mfd_res,
+        #         np_images,
+        #         batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
+        #     )
+        #     mfr_count = 0
+        #     for image_index in range(len(np_images)):
+        #         images_layout_res[image_index] += images_formula_list[image_index]
+        #         mfr_count += len(images_formula_list[image_index])
 
         # 清理显存
         # clean_vram(self.model.device, vram_threshold=8)
@@ -326,9 +326,7 @@ class BatchAnalyze:
                 # 按照 table_id 将识别结果进行回填
                 for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
                     if table_res_list_all_page[img_dict["table_id"]].get("ocr_result"):
-                        table_res_list_all_page[img_dict["table_id"]][
-                            "ocr_result"
-                        ].append(
+                        table_res_list_all_page[img_dict["table_id"]]["ocr_result"].append(
                             [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
                         )
                     else:

+ 13 - 19
mineru/model/ori_cls/paddle_ori_cls.py

@@ -182,11 +182,9 @@ class PaddleOrientationClsModel:
         # 按语言分组,跳过长宽比小于1.2的图片
         lang_groups = defaultdict(list)
         for img in imgs:
-            # PIL RGB图像转换BGR
-            table_img: np.ndarray = cv2.cvtColor(
-                np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR
-            )
-            img["table_img_ndarray"] = table_img
+            # RGB图像转换BGR
+            table_img: np.ndarray = cv2.cvtColor(img["table_img"], cv2.COLOR_RGB2BGR)
+            img["table_img_bgr"] = table_img
             img_height, img_width = table_img.shape[:2]
             img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
             img_is_portrait = img_aspect_ratio > 1.2
@@ -207,7 +205,7 @@ class PaddleOrientationClsModel:
             # 按分辨率分组并同时完成padding
             resolution_groups = defaultdict(list)
             for img in lang_group_img_list:
-                h, w = img["table_img_ndarray"].shape[:2]
+                h, w = img["table_img_bgr"].shape[:2]
                 normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
                 normalized_w = ((w + 32) // 32) * 32
                 group_key = (normalized_h, normalized_w)
@@ -219,15 +217,15 @@ class PaddleOrientationClsModel:
             ):
 
                 # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
-                max_h = max(img["table_img_ndarray"].shape[0] for img in group_imgs)
-                max_w = max(img["table_img_ndarray"].shape[1] for img in group_imgs)
+                max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
+                max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
                 target_h = ((max_h + 32 - 1) // 32) * 32
                 target_w = ((max_w + 32 - 1) // 32) * 32
 
                 # 对所有图像进行padding到统一尺寸
                 batch_images = []
                 for img in group_imgs:
-                    table_img_ndarray = img["table_img_ndarray"]
+                    table_img_ndarray = img["table_img_bgr"]
                     h, w = table_img_ndarray.shape[:2]
                     # 创建目标尺寸的白色背景
                     padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
@@ -268,16 +266,12 @@ class PaddleOrientationClsModel:
                     for img_info, res in zip(rotated_imgs, results[0]):
                         label = self.labels[np.argmax(res)]
                         if label == "270":
-                            img_info["table_img"] = Image.fromarray(
-                                cv2.rotate(
-                                    np.asarray(img_info["table_img"]),
-                                    cv2.ROTATE_90_CLOCKWISE,
-                                )
+                            img_info["table_img"] = cv2.rotate(
+                                np.asarray(img_info["table_img"]),
+                                cv2.ROTATE_90_CLOCKWISE,
                             )
                         elif label == "90":
-                            img_info["table_img"] = Image.fromarray(
-                                cv2.rotate(
-                                    np.asarray(img_info["table_img"]),
-                                    cv2.ROTATE_90_COUNTERCLOCKWISE,
-                                )
+                            img_info["table_img"] = cv2.rotate(
+                                np.asarray(img_info["table_img"]),
+                                cv2.ROTATE_90_COUNTERCLOCKWISE,
                             )