소스 검색

fix: add enable_merge_det_boxes parameter to model initialization for improved box merging control

myhloli 3 달 전
부모
커밋
1d55925954

+ 30 - 8
mineru/backend/pipeline/model_init.py

@@ -19,7 +19,11 @@ from ...utils.models_download_utils import auto_download_and_get_model_root_path
 def img_orientation_cls_model_init():
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
-        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang="ch_lite"
+        atom_model_name=AtomicModel.OCR,
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        lang="ch_lite",
+        enable_merge_det_boxes=False
     )
     cls_model = PaddleOrientationClsModel(ocr_engine)
     return cls_model
@@ -32,7 +36,11 @@ def table_cls_model_init():
 def wired_table_model_init(lang=None):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
-        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
+        atom_model_name=AtomicModel.OCR,
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        lang=lang,
+        enable_merge_det_boxes=False
     )
     table_model = UnetTableModel(ocr_engine)
     return table_model
@@ -41,7 +49,11 @@ def wired_table_model_init(lang=None):
 def wireless_table_model_init(lang=None):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
-        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
+        atom_model_name=AtomicModel.OCR,
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        lang=lang,
+        enable_merge_det_boxes=False
     )
     table_model = RapidTableModel(ocr_engine)
     return table_model
@@ -67,21 +79,23 @@ def doclayout_yolo_model_init(weight, device='cpu'):
 
 def ocr_model_init(det_db_box_thresh=0.3,
                    lang=None,
-                   use_dilation=True,
                    det_db_unclip_ratio=1.8,
+                   enable_merge_det_boxes=True
                    ):
     if lang is not None and lang != '':
         model = PytorchPaddleOCR(
             det_db_box_thresh=det_db_box_thresh,
             lang=lang,
-            use_dilation=use_dilation,
+            use_dilation=True,
             det_db_unclip_ratio=det_db_unclip_ratio,
+            enable_merge_det_boxes=enable_merge_det_boxes,
         )
     else:
         model = PytorchPaddleOCR(
             det_db_box_thresh=det_db_box_thresh,
-            use_dilation=use_dilation,
+            use_dilation=True,
             det_db_unclip_ratio=det_db_unclip_ratio,
+            enable_merge_det_boxes=enable_merge_det_boxes,
         )
     return model
 
@@ -99,8 +113,14 @@ class AtomModelSingleton:
 
         lang = kwargs.get('lang', None)
 
-        if atom_model_name in [AtomicModel.OCR, AtomicModel.WiredTable, AtomicModel.WirelessTable]:
+        if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
             key = (atom_model_name, lang)
+        elif atom_model_name in [AtomicModel.OCR]:
+            key = (atom_model_name,
+                   kwargs.get('det_db_box_thresh', 0.3),
+                   lang, kwargs.get('det_db_unclip_ratio', 1.8),
+                   kwargs.get('enable_merge_det_boxes', True)
+                   )
         else:
             key = atom_model_name
 
@@ -127,8 +147,10 @@ def atom_model_init(model_name: str, **kwargs):
         )
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
-            kwargs.get('det_db_box_thresh'),
+            kwargs.get('det_db_box_thresh', 0.3),
             kwargs.get('lang'),
+            kwargs.get('det_db_unclip_ratio', 1.8),
+            kwargs.get('enable_merge_det_boxes', True)
         )
     elif model_name == AtomicModel.WirelessTable:
         atom_model = wireless_table_model_init(

+ 0 - 1
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -212,7 +212,6 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N
         atom_model_manager = AtomModelSingleton()
         ocr_model = atom_model_manager.get_atom_model(
             atom_model_name='ocr',
-            ocr_show_log=False,
             det_db_box_thresh=0.3,
             lang=lang
         )

+ 5 - 2
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -56,6 +56,7 @@ class PytorchPaddleOCR(TextSystem):
         args = parser.parse_args(args)
 
         self.lang = kwargs.get('lang', 'ch')
+        self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
 
         device = get_device()
         if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
@@ -135,7 +136,8 @@ class PytorchPaddleOCR(TextSystem):
                         continue
                     dt_boxes = sorted_boxes(dt_boxes)
                     # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
-                    dt_boxes = merge_det_boxes(dt_boxes)
+                    if self.enable_merge_det_boxes:
+                        dt_boxes = merge_det_boxes(dt_boxes)
                     if mfd_res:
                         dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                     tmp_res = [box.tolist() for box in dt_boxes]
@@ -172,7 +174,8 @@ class PytorchPaddleOCR(TextSystem):
         dt_boxes = sorted_boxes(dt_boxes)
 
         # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
-        dt_boxes = merge_det_boxes(dt_boxes)
+        if self.enable_merge_det_boxes:
+            dt_boxes = merge_det_boxes(dt_boxes)
 
         if mfd_res:
             dt_boxes = update_det_boxes(dt_boxes, mfd_res)

+ 1 - 2
mineru/model/table/cls/paddle_table_cls.py

@@ -24,8 +24,7 @@ class PaddleTableClsModel:
 
     def preprocess(self, img):
         # PIL图像转cv2
-        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = np.array(img)
         # 放大图片,使其最短边长为256
         h, w = img.shape[:2]
         scale = 256 / min(h, w)