Pārlūkot izejas kodu

update:Modify the PEK module to parse page by page.

myhloli 1 gadu atpakaļ
vecāks
revīzija
2b8db660d1

+ 18 - 14
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,7 +1,9 @@
+import time
+
 import fitz
 import numpy as np
 from loguru import logger
-from magic_pdf.model.model_list import MODEL, MODEL_TYPE
+from magic_pdf.model.model_list import MODEL
 import magic_pdf.model as model_config
 
 
@@ -44,9 +46,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
     return images
 
 
-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.PEK,
-                model_type=MODEL_TYPE.MULTI_PAGE):
+def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.PEK):
     if model_config.__use_inside_model__:
+        model_init_start = time.time()
         if model == MODEL.Paddle:
             from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
             custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
@@ -56,6 +58,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
         else:
             logger.error("Not allow model_name!")
             exit(1)
+        model_init_cost = time.time() - model_init_start
+        logger.info(f"model init cost: {model_init_cost}")
     else:
         logger.error("use_inside_model is False, not allow to use inside model")
         exit(1)
@@ -63,16 +67,16 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
     images = load_images_from_pdf(pdf_bytes)
 
     model_json = []
-    if model_type == MODEL_TYPE.SINGLE_PAGE:
-        for index, img_dict in enumerate(images):
-            img = img_dict["img"]
-            page_width = img_dict["width"]
-            page_height = img_dict["height"]
-            result = custom_model(img)
-            page_info = {"page_no": index, "height": page_height, "width": page_width}
-            page_dict = {"layout_dets": result, "page_info": page_info}
-            model_json.append(page_dict)
-    elif model_type == MODEL_TYPE.MULTI_PAGE:
-        model_json = custom_model(images)
+    doc_analyze_start = time.time()
+    for index, img_dict in enumerate(images):
+        img = img_dict["img"]
+        page_width = img_dict["width"]
+        page_height = img_dict["height"]
+        result = custom_model(img)
+        page_info = {"page_no": index, "height": page_height, "width": page_width}
+        page_dict = {"layout_dets": result, "page_info": page_info}
+        model_json.append(page_dict)
+    doc_analyze_cost = time.time() - doc_analyze_start
+    logger.info(f"doc analyze cost: {doc_analyze_cost}")
 
     return model_json

+ 0 - 7
magic_pdf/model/model_list.py

@@ -1,10 +1,3 @@
 class MODEL:
     Paddle = "pp_structure_v2"
     PEK = "pdf_extract_kit"
-
-
-class MODEL_TYPE:
-    # 单页解析
-    SINGLE_PAGE = 1
-    # 多页解析
-    MULTI_PAGE = 2

+ 64 - 69
magic_pdf/model/pdf_extract_kit.py

@@ -107,83 +107,78 @@ class CustomPEKModel:
 
         logger.info('DocAnalysis init done!')
 
-    def __call__(self, images):
-        # layout检测 + 公式检测
-        doc_layout_result = []
+    def __call__(self, image):
+
         latex_filling_list = []
         mf_image_list = []
-        for idx, img_dict in enumerate(images):
-            image = img_dict["img"]
-            img_height, img_width = img_dict["height"], img_dict["width"]
-            layout_res = self.layout_model(image, ignore_catids=[])
-            # 公式检测
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
-            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': 13 + int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 2),
-                    'latex': '',
-                }
-                layout_res['layout_dets'].append(new_item)
-                latex_filling_list.append(new_item)
-                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
-                mf_image_list.append(bbox_img)
-
-            layout_res['page_info'] = dict(
-                page_no=idx,
-                height=img_height,
-                width=img_width
-            )
-            doc_layout_result.append(layout_res)
 
-        # 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
-        a = time.time()
+        # layout检测
+        layout_start = time.time()
+        layout_res = self.layout_model(image, ignore_catids=[])
+        layout_cost = round(time.time() - layout_start, 2)
+        logger.info(f"layout detection cost: {layout_cost}")
+
+        # 公式检测
+        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
+        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                'category_id': 13 + int(cla.item()),
+                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                'score': round(float(conf.item()), 2),
+                'latex': '',
+            }
+            layout_res.append(new_item)
+            latex_filling_list.append(new_item)
+            bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+            mf_image_list.append(bbox_img)
+
+        # 公式识别
+        mfr_start = time.time()
         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
+        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
         mfr_res = []
-        for imgs in dataloader:
-            imgs = imgs.to(self.device)
-            output = self.mfr_model.generate({'image': imgs})
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            output = self.mfr_model.generate({'image': mf_img})
             mfr_res.extend(output['pred_str'])
         for res, latex in zip(latex_filling_list, mfr_res):
             res['latex'] = latex_rm_whitespace(latex)
-        b = time.time()
-        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {round(b - a, 2)}")
+        mfr_cost = round(time.time() - mfr_start, 2)
+        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
 
         # ocr识别
         if self.apply_ocr:
-            for idx, img_dict in enumerate(images):
-                image = img_dict["img"]
-                pil_img = Image.fromarray(image)
-                single_page_res = doc_layout_result[idx]['layout_dets']
-                single_page_mfdetrec_res = []
-                for res in single_page_res:
-                    if int(res['category_id']) in [13, 14]:
-                        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                        single_page_mfdetrec_res.append({
-                            "bbox": [xmin, ymin, xmax, ymax],
-                        })
-                for res in single_page_res:
-                    if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # 需要进行ocr的类别
-                        xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                        xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                        crop_box = [xmin, ymin, xmax, ymax]
-                        cropped_img = Image.new('RGB', pil_img.size, 'white')
-                        cropped_img.paste(pil_img.crop(crop_box), crop_box)
-                        cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
-                        ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
-                        if ocr_res:
-                            for box_ocr_res in ocr_res:
-                                p1, p2, p3, p4 = box_ocr_res[0]
-                                text, score = box_ocr_res[1]
-                                doc_layout_result[idx]['layout_dets'].append({
-                                    'category_id': 15,
-                                    'poly': p1 + p2 + p3 + p4,
-                                    'score': round(score, 2),
-                                    'text': text,
-                                })
-
-        return doc_layout_result
+            ocr_start = time.time()
+            pil_img = Image.fromarray(image)
+            single_page_mfdetrec_res = []
+            for res in layout_res:
+                if int(res['category_id']) in [13, 14]:
+                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                    single_page_mfdetrec_res.append({
+                        "bbox": [xmin, ymin, xmax, ymax],
+                    })
+            for res in layout_res:
+                if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # 需要进行ocr的类别
+                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                    crop_box = (xmin, ymin, xmax, ymax)
+                    cropped_img = Image.new('RGB', pil_img.size, 'white')
+                    cropped_img.paste(pil_img.crop(crop_box), crop_box)
+                    cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
+                    ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
+                    if ocr_res:
+                        for box_ocr_res in ocr_res:
+                            p1, p2, p3, p4 = box_ocr_res[0]
+                            text, score = box_ocr_res[1]
+                            layout_res.append({
+                                'category_id': 15,
+                                'poly': p1 + p2 + p3 + p4,
+                                'score': round(score, 2),
+                                'text': text,
+                            })
+            ocr_cost = round(time.time() - ocr_start, 2)
+            logger.info(f"ocr cost: {ocr_cost}")
+
+        return layout_res

+ 15 - 11
magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py

@@ -8,6 +8,7 @@ from detectron2.data import MetadataCatalog, DatasetCatalog
 from detectron2.data.datasets import register_coco_instances
 from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor
 
+
 def add_vit_config(cfg):
     """
     Add config for VIT.
@@ -72,14 +73,14 @@ def setup(args):
     cfg.merge_from_list(args.opts)
     cfg.freeze()
     default_setup(cfg, args)
-    
+
     register_coco_instances(
         "scihub_train",
         {},
         cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
         cfg.SCIHUB_DATA_DIR_TRAIN
     )
-    
+
     return cfg
 
 
@@ -94,10 +95,11 @@ class DotDict(dict):
         if isinstance(value, dict):
             value = DotDict(value)
         return value
-    
+
     def __setattr__(self, key, value):
         self[key] = value
-        
+
+
 class Layoutlmv3_Predictor(object):
     def __init__(self, weights, config_file):
         layout_args = {
@@ -113,14 +115,16 @@ class Layoutlmv3_Predictor(object):
         layout_args = DotDict(layout_args)
 
         cfg = setup(layout_args)
-        self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption", "table_footnote", "isolate_formula", "formula_caption"]
+        self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
+                        "table_footnote", "isolate_formula", "formula_caption"]
         MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
         self.predictor = DefaultPredictor(cfg)
-        
+
     def __call__(self, image, ignore_catids=[]):
-        page_layout_result = {
-            "layout_dets": []
-        }
+        # page_layout_result = {
+        #     "layout_dets": []
+        # }
+        layout_dets = []
         outputs = self.predictor(image)
         boxes = outputs["instances"].to("cpu")._fields["pred_boxes"].tensor.tolist()
         labels = outputs["instances"].to("cpu")._fields["pred_classes"].tolist()
@@ -128,7 +132,7 @@ class Layoutlmv3_Predictor(object):
         for bbox_idx in range(len(boxes)):
             if labels[bbox_idx] in ignore_catids:
                 continue
-            page_layout_result["layout_dets"].append({
+            layout_dets.append({
                 "category_id": labels[bbox_idx],
                 "poly": [
                     boxes[bbox_idx][0], boxes[bbox_idx][1],
@@ -138,4 +142,4 @@ class Layoutlmv3_Predictor(object):
                 ],
                 "score": scores[bbox_idx]
             })
-        return page_layout_result
+        return layout_dets

+ 4 - 3
magic_pdf/model/pek_sub_modules/self_modify.py

@@ -136,9 +136,10 @@ class ModifiedPaddleOCR(PaddleOCR):
             logger.error('When input a list of images, det must be false')
             exit(0)
         if cls == True and self.use_angle_cls == False:
-            logger.warning(
-                'Since the angle classifier is not initialized, it will not be used during the forward process'
-            )
+            pass
+            # logger.warning(
+            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
+            # )
 
         img = check_img(img)
         # for infer pdf file