Przeglądaj źródła

fix(pdf_extract_kit):change unimernet base -> small

myhloli 1 rok temu
rodzic
commit
f2a3a49541

+ 18 - 17
magic_pdf/model/pdf_extract_kit.py

@@ -63,7 +63,7 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)
     model = task.build_model(cfg)
-    model = model.to(_device_)
+    model.to(_device_)
     model.eval()
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
@@ -155,6 +155,23 @@ def atom_model_init(model_name: str, **kwargs):
     return atom_model
 
 
+#  Unified crop img logic
+def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+    # Create a white background with an additional width and height of 50
+    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
+    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
+    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+    # Crop image
+    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+    cropped_img = input_pil_img.crop(crop_box)
+    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    return return_image, return_list
+
+
 class CustomPEKModel:
 
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
@@ -313,22 +330,6 @@ class CustomPEKModel:
             elif int(res['category_id']) in [5]:
                 table_res_list.append(res)
 
-        #  Unified crop img logic
-        def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
-            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
-            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-            # Create a white background with an additional width and height of 50
-            crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
-            crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
-
-            # Crop image
-            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-            cropped_img = input_pil_img.crop(crop_box)
-            return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-            return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-            return return_image, return_list
-
         # ocr识别
         if self.apply_ocr:
             ocr_start = time.time()

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -10,6 +10,6 @@ config:
 weights:
   layout: Layout/model_final.pth
   mfd: MFD/weights.pt
-  mfr: MFR/unimernet_base
+  mfr: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
   TableMaster: TabRec/TableMaster