Przeglądaj źródła

Merge branch 'opendatalab:dev' into dev

linfeng 1 rok temu
rodzic
commit
076a2a1463

+ 24 - 21
magic_pdf/model/pdf_extract_kit.py

@@ -32,7 +32,7 @@ except ImportError as e:
     exit(1)
 
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
+from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 from magic_pdf.model.ppTableModel import ppTableModel
@@ -63,7 +63,8 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)
     model = task.build_model(cfg)
-    model = model.to(_device_)
+    model.to(_device_)
+    model.eval()
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]
@@ -154,6 +155,23 @@ def atom_model_init(model_name: str, **kwargs):
     return atom_model
 
 
+#  Unified crop img logic
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+    # Create a white background with an additional width and height of 50
+    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+    # Crop image
+    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+    cropped_img = input_pil_img.crop(crop_box)
+    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    return return_image, return_list
+
+
 class CustomPEKModel:
 
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
@@ -264,6 +282,8 @@ class CustomPEKModel:
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f"layout detection cost: {layout_cost}")
 
+        pil_img = Image.fromarray(image)
+
         if self.apply_formula:
             # 公式检测
             mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
@@ -277,7 +297,8 @@ class CustomPEKModel:
                 }
                 layout_res.append(new_item)
                 latex_filling_list.append(new_item)
-                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+                # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 mf_image_list.append(bbox_img)
 
             # 公式识别
@@ -309,24 +330,6 @@ class CustomPEKModel:
             elif int(res['category_id']) in [5]:
                 table_res_list.append(res)
 
-        #  Unified crop img logic
-        def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
-            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
-            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-            # Create a white background with an additional width and height of 50
-            crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
-            crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
-
-            # Crop image
-            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-            cropped_img = input_pil_img.crop(crop_box)
-            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-            return return_image, return_list
-
-        pil_img = Image.fromarray(image)
-
         # ocr识别
         if self.apply_ocr:
             ocr_start = time.time()

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -10,6 +10,6 @@ config:
 weights:
   layout: Layout/model_final.pth
   mfd: MFD/weights.pt
-  mfr: MFR/unimernet_base
+  mfr: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
   TableMaster: TabRec/TableMaster