|
|
@@ -1,6 +1,8 @@
|
|
|
from loguru import logger
|
|
|
import os
|
|
|
import time
|
|
|
+
|
|
|
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
|
|
try:
|
|
|
import cv2
|
|
|
import yaml
|
|
|
@@ -17,14 +19,17 @@ try:
|
|
|
import unimernet.tasks as tasks
|
|
|
from unimernet.processors import load_processor
|
|
|
|
|
|
- from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
- from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
- from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
except ImportError as e:
|
|
|
logger.exception(e)
|
|
|
- logger.error('Required dependency not installed, please install by \n"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
|
|
|
+ logger.error(
|
|
|
+ 'Required dependency not installed, please install by \n'
|
|
|
+ '"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
|
|
|
exit(1)
|
|
|
|
|
|
+from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
|
+from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
+from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
+
|
|
|
|
|
|
def mfd_model_init(weight):
|
|
|
mfd_model = YOLO(weight)
|
|
|
@@ -100,6 +105,7 @@ class CustomPEKModel:
|
|
|
self.device = kwargs.get("device", self.configs["config"]["device"])
|
|
|
logger.info("using device: {}".format(self.device))
|
|
|
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
|
|
|
+ logger.info("using models_dir: {}".format(models_dir))
|
|
|
|
|
|
# 初始化公式识别
|
|
|
if self.apply_formula:
|
|
|
@@ -135,34 +141,35 @@ class CustomPEKModel:
|
|
|
layout_cost = round(time.time() - layout_start, 2)
|
|
|
logger.info(f"layout detection cost: {layout_cost}")
|
|
|
|
|
|
- # 公式检测
|
|
|
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
- new_item = {
|
|
|
- 'category_id': 13 + int(cla.item()),
|
|
|
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
- 'score': round(float(conf.item()), 2),
|
|
|
- 'latex': '',
|
|
|
- }
|
|
|
- layout_res.append(new_item)
|
|
|
- latex_filling_list.append(new_item)
|
|
|
- bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
- mf_image_list.append(bbox_img)
|
|
|
-
|
|
|
- # 公式识别
|
|
|
- mfr_start = time.time()
|
|
|
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
|
|
- mfr_res = []
|
|
|
- for mf_img in dataloader:
|
|
|
- mf_img = mf_img.to(self.device)
|
|
|
- output = self.mfr_model.generate({'image': mf_img})
|
|
|
- mfr_res.extend(output['pred_str'])
|
|
|
- for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
- res['latex'] = latex_rm_whitespace(latex)
|
|
|
- mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
+ if self.apply_formula:
|
|
|
+ # 公式检测
|
|
|
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
+ new_item = {
|
|
|
+ 'category_id': 13 + int(cla.item()),
|
|
|
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
+ 'score': round(float(conf.item()), 2),
|
|
|
+ 'latex': '',
|
|
|
+ }
|
|
|
+ layout_res.append(new_item)
|
|
|
+ latex_filling_list.append(new_item)
|
|
|
+ bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
+ mf_image_list.append(bbox_img)
|
|
|
+
|
|
|
+ # 公式识别
|
|
|
+ mfr_start = time.time()
|
|
|
+ dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
+ dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
|
|
|
+ mfr_res = []
|
|
|
+ for mf_img in dataloader:
|
|
|
+ mf_img = mf_img.to(self.device)
|
|
|
+ output = self.mfr_model.generate({'image': mf_img})
|
|
|
+ mfr_res.extend(output['pred_str'])
|
|
|
+ for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
+ res['latex'] = latex_rm_whitespace(latex)
|
|
|
+ mfr_cost = round(time.time() - mfr_start, 2)
|
|
|
+ logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
|
|
|
|
|
|
# ocr识别
|
|
|
if self.apply_ocr:
|
|
|
@@ -189,8 +196,8 @@ class CustomPEKModel:
|
|
|
paste_x = 50
|
|
|
paste_y = 50
|
|
|
# 创建一个宽高各多50的白色背景
|
|
|
- new_width = xmax - xmin + paste_x*2
|
|
|
- new_height = ymax - ymin + paste_y*2
|
|
|
+ new_width = xmax - xmin + paste_x * 2
|
|
|
+ new_height = ymax - ymin + paste_y * 2
|
|
|
new_image = Image.new('RGB', (new_width, new_height), 'white')
|
|
|
|
|
|
# 裁剪图像
|