|
@@ -1,6 +1,10 @@
|
|
|
import os
|
|
import os
|
|
|
|
|
+import time
|
|
|
|
|
+
|
|
|
|
|
+import cv2
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import yaml
|
|
import yaml
|
|
|
|
|
+from PIL import Image
|
|
|
from ultralytics import YOLO
|
|
from ultralytics import YOLO
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
|
@@ -9,7 +13,9 @@ import unimernet.tasks as tasks
|
|
|
from unimernet.processors import load_processor
|
|
from unimernet.processors import load_processor
|
|
|
import argparse
|
|
import argparse
|
|
|
from torchvision import transforms
|
|
from torchvision import transforms
|
|
|
|
|
+from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
|
|
|
|
|
+from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
|
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
|
|
|
|
|
|
|
|
|
|
|
@@ -31,6 +37,25 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
|
|
|
return model, vis_processor
|
|
return model, vis_processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+class MathDataset(Dataset):
|
|
|
|
|
+ def __init__(self, image_paths, transform=None):
|
|
|
|
|
+ self.image_paths = image_paths
|
|
|
|
|
+ self.transform = transform
|
|
|
|
|
+
|
|
|
|
|
+ def __len__(self):
|
|
|
|
|
+ return len(self.image_paths)
|
|
|
|
|
+
|
|
|
|
|
+ def __getitem__(self, idx):
|
|
|
|
|
+ # if not pil image, then convert to pil image
|
|
|
|
|
+ if isinstance(self.image_paths[idx], str):
|
|
|
|
|
+ raw_image = Image.open(self.image_paths[idx])
|
|
|
|
|
+ else:
|
|
|
|
|
+ raw_image = self.image_paths[idx]
|
|
|
|
|
+ if self.transform:
|
|
|
|
|
+ image = self.transform(raw_image)
|
|
|
|
|
+ return image
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class CustomPEKModel:
|
|
class CustomPEKModel:
|
|
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
|
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
|
|
|
"""
|
|
"""
|
|
@@ -82,6 +107,83 @@ class CustomPEKModel:
|
|
|
|
|
|
|
|
logger.info('DocAnalysis init done!')
|
|
logger.info('DocAnalysis init done!')
|
|
|
|
|
|
|
|
|
|
+ def __call__(self, images):
|
|
|
|
|
+ # layout检测 + 公式检测
|
|
|
|
|
+ doc_layout_result = []
|
|
|
|
|
+ latex_filling_list = []
|
|
|
|
|
+ mf_image_list = []
|
|
|
|
|
+ for idx, img_dict in enumerate(images):
|
|
|
|
|
+ image = img_dict["img"]
|
|
|
|
|
+ img_height, img_width = img_dict["height"], img_dict["width"]
|
|
|
|
|
+ layout_res = self.layout_model(image, ignore_catids=[])
|
|
|
|
|
+ # 公式检测
|
|
|
|
|
+ mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
|
|
|
|
|
+ for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
|
|
|
|
|
+ xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
|
|
|
|
|
+ new_item = {
|
|
|
|
|
+ 'category_id': 13 + int(cla.item()),
|
|
|
|
|
+ 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
|
|
|
|
|
+ 'score': round(float(conf.item()), 2),
|
|
|
|
|
+ 'latex': '',
|
|
|
|
|
+ }
|
|
|
|
|
+ layout_res['layout_dets'].append(new_item)
|
|
|
|
|
+ latex_filling_list.append(new_item)
|
|
|
|
|
+ bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
|
|
|
|
|
+ mf_image_list.append(bbox_img)
|
|
|
|
|
+
|
|
|
|
|
+ layout_res['page_info'] = dict(
|
|
|
|
|
+ page_no=idx,
|
|
|
|
|
+ height=img_height,
|
|
|
|
|
+ width=img_width
|
|
|
|
|
+ )
|
|
|
|
|
+ doc_layout_result.append(layout_res)
|
|
|
|
|
+
|
|
|
|
|
+ # 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
|
|
|
|
|
+ a = time.time()
|
|
|
|
|
+ dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
|
|
|
|
|
+ dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
|
|
|
|
|
+ mfr_res = []
|
|
|
|
|
+ for imgs in dataloader:
|
|
|
|
|
+ imgs = imgs.to(self.device)
|
|
|
|
|
+ output = self.mfr_model.generate({'image': imgs})
|
|
|
|
|
+ mfr_res.extend(output['pred_str'])
|
|
|
|
|
+ for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
|
|
+ res['latex'] = latex_rm_whitespace(latex)
|
|
|
|
|
+ b = time.time()
|
|
|
|
|
+ logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}")
|
|
|
|
|
+
|
|
|
|
|
+ if self.apply_ocr:
|
|
|
|
|
+ # ocr识别
|
|
|
|
|
+ for idx, img_dict in enumerate(images):
|
|
|
|
|
+ image = img_dict["img"]
|
|
|
|
|
+ pil_img = Image.fromarray(image)
|
|
|
|
|
+ single_page_res = doc_layout_result[idx]['layout_dets']
|
|
|
|
|
+ single_page_mfdetrec_res = []
|
|
|
|
|
+ for res in single_page_res:
|
|
|
|
|
+ if int(res['category_id']) in [13, 14]:
|
|
|
|
|
+ xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
|
|
+ xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
|
|
+ single_page_mfdetrec_res.append({
|
|
|
|
|
+ "bbox": [xmin, ymin, xmax, ymax],
|
|
|
|
|
+ })
|
|
|
|
|
+ for res in single_page_res:
|
|
|
|
|
+ if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
|
|
|
|
|
+ xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
|
|
|
|
|
+ xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
|
|
|
|
|
+ crop_box = [xmin, ymin, xmax, ymax]
|
|
|
|
|
+ cropped_img = Image.new('RGB', pil_img.size, 'white')
|
|
|
|
|
+ cropped_img.paste(pil_img.crop(crop_box), crop_box)
|
|
|
|
|
+ cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
|
|
|
|
|
+ ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
|
|
|
|
|
+ if ocr_res:
|
|
|
|
|
+ for box_ocr_res in ocr_res:
|
|
|
|
|
+ p1, p2, p3, p4 = box_ocr_res[0]
|
|
|
|
|
+ text, score = box_ocr_res[1]
|
|
|
|
|
+ doc_layout_result[idx]['layout_dets'].append({
|
|
|
|
|
+ 'category_id': 15,
|
|
|
|
|
+ 'poly': p1 + p2 + p3 + p4,
|
|
|
|
|
+ 'score': round(score, 2),
|
|
|
|
|
+ 'text': text,
|
|
|
|
|
+ })
|
|
|
|
|
|
|
|
- def __call__(self, image):
|
|
|
|
|
- pass
|
|
|
|
|
|
|
+ return doc_layout_result
|