Ver código fonte

fix: merge dev

Sidney233 2 meses atrás
pai
commit
0641cc07f7
1 arquivos alterados com 16 adições e 16 exclusões
  1. 16 16
      mineru/backend/pipeline/batch_analyze.py

+ 16 - 16
mineru/backend/pipeline/batch_analyze.py

@@ -52,22 +52,22 @@ class BatchAnalyze:
             np_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
 
-        # if self.formula_enable:
-        #     # 公式检测
-        #     images_mfd_res = self.model.mfd_model.batch_predict(
-        #         np_images, MFD_BASE_BATCH_SIZE
-        #     )
-        #
-        #     # 公式识别
-        #     images_formula_list = self.model.mfr_model.batch_predict(
-        #         images_mfd_res,
-        #         np_images,
-        #         batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
-        #     )
-        #     mfr_count = 0
-        #     for image_index in range(len(np_images)):
-        #         images_layout_res[image_index] += images_formula_list[image_index]
-        #         mfr_count += len(images_formula_list[image_index])
+        if self.formula_enable:
+            # 公式检测
+            images_mfd_res = self.model.mfd_model.batch_predict(
+                np_images, MFD_BASE_BATCH_SIZE
+            )
+
+            # 公式识别
+            images_formula_list = self.model.mfr_model.batch_predict(
+                images_mfd_res,
+                np_images,
+                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
+            )
+            mfr_count = 0
+            for image_index in range(len(np_images)):
+                images_layout_res[image_index] += images_formula_list[image_index]
+                mfr_count += len(images_formula_list[image_index])
 
         # 清理显存
         # clean_vram(self.model.device, vram_threshold=8)