|
|
@@ -141,34 +141,35 @@ class CustomPEKModel:
|
|
|
layout_cost = round(time.time() - layout_start, 2)
|
|
|
logger.info(f"layout detection cost: {layout_cost}")
|
|
|
|
|
|
- # 公式检测
|
|
|
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
- new_item = {
|
|
|
- 'category_id': 13 + int(cla.item()),
|
|
|
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
- 'score': round(float(conf.item()), 2),
|
|
|
- 'latex': '',
|
|
|
- }
|
|
|
- layout_res.append(new_item)
|
|
|
- latex_filling_list.append(new_item)
|
|
|
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
- mf_image_list.append(bbox_img)
|
|
|
-
|
|
|
- # 公式识别
|
|
|
- mfr_start = time.time()
|
|
|
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
|
|
- mfr_res = []
|
|
|
- for mf_img in dataloader:
|
|
|
- mf_img = mf_img.to(self.device)
|
|
|
- output = self.mfr_model.generate({'image': mf_img})
|
|
|
- mfr_res.extend(output['pred_str'])
|
|
|
- for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
- res['latex'] = latex_rm_whitespace(latex)
|
|
|
- mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
+ if self.apply_formula:
|
|
|
+ # 公式检测
|
|
|
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
+ new_item = {
|
|
|
+ 'category_id': 13 + int(cla.item()),
|
|
|
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ 'score': round(float(conf.item()), 2),
|
|
|
+ 'latex': '',
|
|
|
+ }
|
|
|
+ layout_res.append(new_item)
|
|
|
+ latex_filling_list.append(new_item)
|
|
|
+ bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
+ mf_image_list.append(bbox_img)
|
|
|
+
|
|
|
+ # 公式识别
|
|
|
+ mfr_start = time.time()
|
|
|
+ dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
+ dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
|
|
+ mfr_res = []
|
|
|
+ for mf_img in dataloader:
|
|
|
+ mf_img = mf_img.to(self.device)
|
|
|
+ output = self.mfr_model.generate({'image': mf_img})
|
|
|
+ mfr_res.extend(output['pred_str'])
|
|
|
+ for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
+ res['latex'] = latex_rm_whitespace(latex)
|
|
|
+ mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
+ logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
|
|
|
# ocr识别
|
|
|
if self.apply_ocr:
|