Ver código fonte

Merge pull request #119 from icecraft/feat/parallel_paddle

feat: parallelize paddle
myhloli 1 ano atrás
pai
commit
c96aa88d13
1 arquivo alterado com 55 adições e 49 exclusões
  1. 55 49
      magic_pdf/model/doc_analyze_by_pp_structurev2.py

+ 55 - 49
magic_pdf/model/doc_analyze_by_pp_structurev2.py

@@ -7,6 +7,7 @@ from PIL import Image
 from loguru import logger
 import numpy as np
 
+
 def region_to_bbox(region):
     x0 = region[0][0]
     y0 = region[0][1]
@@ -22,12 +23,14 @@ def dict_compare(d1, d2):
 def remove_duplicates_dicts(lst):
     unique_dicts = []
     for dict_item in lst:
-        if not any(dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts):
+        if not any(
+            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
+        ):
             unique_dicts.append(dict_item)
     return unique_dicts
-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
-    ocr_engine = PPStructure(table=False, ocr=ocr, show_log=show_log)
 
+
+def load_imags_from_pdf(pdf_bytes: bytes, dpi=200):
     imgs = []
     with fitz.open("pdf", pdf_bytes) as doc:
         for index in range(0, doc.page_count):
@@ -42,23 +45,20 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
 
             img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
             img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-            img_dict = {
-                "img": img,
-                "width": pm.width,
-                "height": pm.height
-            }
+            img_dict = {"img": img, "width": pm.width, "height": pm.height}
             imgs.append(img_dict)
 
-    model_json = []
-    for index, img_dict in enumerate(imgs):
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        result = ocr_engine(img)
+
+class CustomPaddleModel:
+    def __init___(self, ocr: bool = False, show_log: bool = False):
+        self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
+
+    def __call__(self, img):
+        result = self.model(img)
         spans = []
         for line in result:
-            line.pop('img')
-            '''
+            line.pop("img")
+            """
             为paddle输出适配type no.    
             title: 0 # 标题
             text: 1 # 文本
@@ -71,54 +71,60 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
             figure_caption: 4 # 图片描述
             table: 5 # 表格
             table_caption: 6 # 表格描述
-            '''
-            if line['type'] == 'title':
-                line['category_id'] = 0
-            elif line['type'] in ['text', 'reference']:
-                line['category_id'] = 1
-            elif line['type'] == 'figure':
-                line['category_id'] = 3
-            elif line['type'] == 'figure_caption':
-                line['category_id'] = 4
-            elif line['type'] == 'table':
-                line['category_id'] = 5
-            elif line['type'] == 'table_caption':
-                line['category_id'] = 6
-            elif line['type'] == 'equation':
-                line['category_id'] = 8
-            elif line['type'] in ['header', 'footer']:
-                line['category_id'] = 2
+            """
+            if line["type"] == "title":
+                line["category_id"] = 0
+            elif line["type"] in ["text", "reference"]:
+                line["category_id"] = 1
+            elif line["type"] == "figure":
+                line["category_id"] = 3
+            elif line["type"] == "figure_caption":
+                line["category_id"] = 4
+            elif line["type"] == "table":
+                line["category_id"] = 5
+            elif line["type"] == "table_caption":
+                line["category_id"] = 6
+            elif line["type"] == "equation":
+                line["category_id"] = 8
+            elif line["type"] in ["header", "footer"]:
+                line["category_id"] = 2
             else:
                 logger.warning(f"unknown type: {line['type']}")
 
             # 兼容不输出score的paddleocr版本
             if line.get("score") is None:
-                line['score'] = 0.5 + random.random() * 0.5
+                line["score"] = 0.5 + random.random() * 0.5
 
-            res = line.pop('res', None)
+            res = line.pop("res", None)
             if res is not None and len(res) > 0:
                 for span in res:
-                    new_span = {'category_id': 15,
-                                'bbox': region_to_bbox(span['text_region']),
-                                'score': span['confidence'],
-                                'text': span['text']
-                                }
+                    new_span = {
+                        "category_id": 15,
+                        "bbox": region_to_bbox(span["text_region"]),
+                        "score": span["confidence"],
+                        "text": span["text"],
+                    }
                     spans.append(new_span)
 
         if len(spans) > 0:
             result.extend(spans)
 
         result = remove_duplicates_dicts(result)
+        return result
+
+
+def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
+    imgs = load_imags_from_pdf(pdf_bytes)
+    custom_paddle =  CustomPaddleModel()
 
-        page_info = {
-            "page_no": index,
-            "height": page_height,
-            "width": page_width
-        }
-        page_dict = {
-            "layout_dets": result,
-            "page_info": page_info
-        }
+    model_json = []
+    for index, img_dict in enumerate(imgs):
+        img = img_dict["img"]
+        page_width = img_dict["width"]
+        page_height = img_dict["height"]
+        result = custom_paddle(img)
+        page_info = {"page_no": index, "height": page_height, "width": page_width}
+        page_dict = {"layout_dets": result, "page_info": page_info}
 
         model_json.append(page_dict)