Sfoglia il codice sorgente

Merge pull request #635 from myhloli/dev

refactor(pdf_extract): use Image.crop directly with layout detection
Xiaomeng Zhao 1 anno fa
parent
commit
19a74db869
1 file modificato con 5 aggiunte e 4 eliminazioni
  1. +5 -4
      magic_pdf/model/pdf_extract_kit.py

+ 5 - 4
magic_pdf/model/pdf_extract_kit.py

@@ -32,7 +32,7 @@ except ImportError as e:
     exit(1)
 
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
+from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 from magic_pdf.model.ppTableModel import ppTableModel
@@ -264,6 +264,8 @@ class CustomPEKModel:
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection cost: {layout_cost}")
 
+        pil_img = Image.fromarray(image)
+
         if self.apply_formula:
             # 公式检测
             mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
@@ -277,7 +279,8 @@ class CustomPEKModel:
                 }
                 layout_res.append(new_item)
                 latex_filling_list.append(new_item)
-                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+                # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 mf_image_list.append(bbox_img)
 
             # 公式识别
@@ -325,8 +328,6 @@ class CustomPEKModel:
             return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
             return return_image, return_list
 
-        pil_img = Image.fromarray(image)
-
         # ocr识别
         if self.apply_ocr:
             ocr_start = time.time()