Forráskód Böngészése

feat(magic-pdf): add conditional application of formula detection and recognition

赵小蒙 1 éve
szülő
commit
4c39bcd366
1 módosított fájl, 29 hozzáadás és 28 törlés
  1. 29 28
      magic_pdf/model/pdf_extract_kit.py

+ 29 - 28
magic_pdf/model/pdf_extract_kit.py

@@ -141,34 +141,35 @@ class CustomPEKModel:
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection cost: {layout_cost}")
 
-        # 公式检测
-        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
-        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-            new_item = {
-                'category_id': 13 + int(cla.item()),
-                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                'score': round(float(conf.item()), 2),
-                'latex': '',
-            }
-            layout_res.append(new_item)
-            latex_filling_list.append(new_item)
-            bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
-            mf_image_list.append(bbox_img)
-
-        # 公式识别
-        mfr_start = time.time()
-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
-        mfr_res = []
-        for mf_img in dataloader:
-            mf_img = mf_img.to(self.device)
-            output = self.mfr_model.generate({'image': mf_img})
-            mfr_res.extend(output['pred_str'])
-        for res, latex in zip(latex_filling_list, mfr_res):
-            res['latex'] = latex_rm_whitespace(latex)
-        mfr_cost = round(time.time() - mfr_start, 2)
-        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
+        if self.apply_formula:
+            # 公式检测
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': 13 + int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 2),
+                    'latex': '',
+                }
+                layout_res.append(new_item)
+                latex_filling_list.append(new_item)
+                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+                mf_image_list.append(bbox_img)
+
+            # 公式识别
+            mfr_start = time.time()
+            dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+            dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+            mfr_res = []
+            for mf_img in dataloader:
+                mf_img = mf_img.to(self.device)
+                output = self.mfr_model.generate({'image': mf_img})
+                mfr_res.extend(output['pred_str'])
+            for res, latex in zip(latex_filling_list, mfr_res):
+                res['latex'] = latex_rm_whitespace(latex)
+            mfr_cost = round(time.time() - mfr_start, 2)
+            logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
 
         # ocr识别
         if self.apply_ocr: