Ver código fonte

update:Complete the parsing logic of PEK

myhloli 1 ano atrás
pai
commit
831db2e009
1 arquivos alterados com 104 adições e 2 exclusões
  1. 104 2
      magic_pdf/model/pdf_extract_kit.py

+ 104 - 2
magic_pdf/model/pdf_extract_kit.py

@@ -1,6 +1,10 @@
 import os
+import time
+
+import cv2
 import numpy as np
 import yaml
+from PIL import Image
 from ultralytics import YOLO
 from loguru import logger
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
@@ -9,7 +13,9 @@ import unimernet.tasks as tasks
 from unimernet.processors import load_processor
 import argparse
 from torchvision import transforms
+from torch.utils.data import Dataset, DataLoader
 
+from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 
 
@@ -31,6 +37,25 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
     return model, vis_processor
 
 
+class MathDataset(Dataset):
+    def __init__(self, image_paths, transform=None):
+        self.image_paths = image_paths
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        # if not pil image, then convert to pil image
+        if isinstance(self.image_paths[idx], str):
+            raw_image = Image.open(self.image_paths[idx])
+        else:
+            raw_image = self.image_paths[idx]
+        if self.transform:
+            image = self.transform(raw_image)
+        return image
+
+
 class CustomPEKModel:
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
         """
@@ -82,6 +107,83 @@ class CustomPEKModel:
 
         logger.info('DocAnalysis init done!')
 
+    def __call__(self, images):
+        # layout检测 + 公式检测
+        doc_layout_result = []
+        latex_filling_list = []
+        mf_image_list = []
+        for idx, img_dict in enumerate(images):
+            image = img_dict["img"]
+            img_height, img_width = img_dict["height"], img_dict["width"]
+            layout_res = self.layout_model(image, ignore_catids=[])
+            # 公式检测
+            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': 13 + int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 2),
+                    'latex': '',
+                }
+                layout_res['layout_dets'].append(new_item)
+                latex_filling_list.append(new_item)
+                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+                mf_image_list.append(bbox_img)
+
+            layout_res['page_info'] = dict(
+                page_no=idx,
+                height=img_height,
+                width=img_width
+            )
+            doc_layout_result.append(layout_res)
+
+        # 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
+        a = time.time()
+        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
+        mfr_res = []
+        for imgs in dataloader:
+            imgs = imgs.to(self.device)
+            output = self.mfr_model.generate({'image': imgs})
+            mfr_res.extend(output['pred_str'])
+        for res, latex in zip(latex_filling_list, mfr_res):
+            res['latex'] = latex_rm_whitespace(latex)
+        b = time.time()
+        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}")
+
+        if self.apply_ocr:
+            # ocr识别
+            for idx, img_dict in enumerate(images):
+                image = img_dict["img"]
+                pil_img = Image.fromarray(image)
+                single_page_res = doc_layout_result[idx]['layout_dets']
+                single_page_mfdetrec_res = []
+                for res in single_page_res:
+                    if int(res['category_id']) in [13, 14]:
+                        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                        single_page_mfdetrec_res.append({
+                            "bbox": [xmin, ymin, xmax, ymax],
+                        })
+                for res in single_page_res:
+                    if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # 需要进行ocr的类别
+                        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                        crop_box = [xmin, ymin, xmax, ymax]
+                        cropped_img = Image.new('RGB', pil_img.size, 'white')
+                        cropped_img.paste(pil_img.crop(crop_box), crop_box)
+                        cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
+                        ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
+                        if ocr_res:
+                            for box_ocr_res in ocr_res:
+                                p1, p2, p3, p4 = box_ocr_res[0]
+                                text, score = box_ocr_res[1]
+                                doc_layout_result[idx]['layout_dets'].append({
+                                    'category_id': 15,
+                                    'poly': p1 + p2 + p3 + p4,
+                                    'score': round(score, 2),
+                                    'text': text,
+                                })
 
-    def __call__(self, image):
-        pass
+        return doc_layout_result