瀏覽代碼

Merge pull request #1272 from myhloli/add-llm-aided

perf(layout): optimize layout detection for PDF extraction
Xiaomeng Zhao 11 月之前
父節點
當前提交
5bbd07a195
共有 2 個文件被更改,包括 23 次插入20 次删除
  1. 22 20
      magic_pdf/model/pdf_extract_kit.py
  2. 1 0
      magic_pdf/post_proc/__init__.py

+ 22 - 20
magic_pdf/model/pdf_extract_kit.py

@@ -171,6 +171,10 @@ class CustomPEKModel:
 
     def __call__(self, image):
 
+        pil_img = Image.fromarray(image)
+        width, height = pil_img.size
+        # logger.info(f'width: {width}, height: {height}')
+
         # layout检测
         layout_start = time.time()
         layout_res = []
@@ -179,30 +183,28 @@ class CustomPEKModel:
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            img_pil = Image.fromarray(image)
-            width, height = img_pil.size
-            # logger.info(f'width: {width}, height: {height}')
-            input_res = {"poly":[0,0,width,0,width,height,0,height]}
-            new_image, useful_list = crop_img(input_res, img_pil, crop_paste_x=width//2, crop_paste_y=0)
-            paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-            layout_res = self.layout_model.predict(new_image)
-            for res in layout_res:
-                p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-                p1 = p1 - paste_x + xmin
-                p2 = p2 - paste_y + ymin
-                p3 = p3 - paste_x + xmin
-                p4 = p4 - paste_y + ymin
-                p5 = p5 - paste_x + xmin
-                p6 = p6 - paste_y + ymin
-                p7 = p7 - paste_x + xmin
-                p8 = p8 - paste_y + ymin
-                res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            if height > width:
+                input_res = {"poly":[0,0,width,0,width,height,0,height]}
+                new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
+                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+                layout_res = self.layout_model.predict(new_image)
+                for res in layout_res:
+                    p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+                    p1 = p1 - paste_x + xmin
+                    p2 = p2 - paste_y + ymin
+                    p3 = p3 - paste_x + xmin
+                    p4 = p4 - paste_y + ymin
+                    p5 = p5 - paste_x + xmin
+                    p6 = p6 - paste_y + ymin
+                    p7 = p7 - paste_x + xmin
+                    p8 = p8 - paste_y + ymin
+                    res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            else:
+                layout_res = self.layout_model.predict(image)
 
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f'layout detection time: {layout_cost}')
 
-        pil_img = Image.fromarray(image)
-
         if self.apply_formula:
             # 公式检测
             mfd_start = time.time()

+ 1 - 0
magic_pdf/post_proc/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.