fix: refactor image handling to use numpy arrays instead of PIL images

myhloli 2 months ago
parent
commit
2fcffcb0af
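
The gist of the change: each page is converted to a numpy array once, and crops are taken by array slicing instead of PIL's crop(). A minimal standalone sketch of that equivalence (the names below are illustrative, not the project's API):

    import numpy as np
    from PIL import Image

    # Convert a PIL page image to an RGB numpy array once, up front.
    pil_page = Image.new("RGB", (200, 100), color="white")
    np_page = np.asarray(pil_page)                # shape (H, W, 3), RGB channel order

    # PIL crops with a (left, upper, right, lower) box ...
    pil_crop = pil_page.crop((10, 20, 60, 80))

    # ... which maps to numpy slicing: rows (y) first, then columns (x).
    np_crop = np_page[20:80, 10:60]

    assert np_crop.shape == (60, 50, 3)
    assert np.array_equal(np.asarray(pil_crop), np_crop)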

+ 20 - 21
mineru/backend/pipeline/batch_analyze.py

@@ -9,7 +9,7 @@ from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
-from ...utils.pdf_image_tools import get_crop_img
+from ...utils.pdf_image_tools import get_crop_np_img
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -38,29 +38,28 @@ class BatchAnalyze:
         )
         atom_model_manager = AtomModelSingleton()
 
-        images = [image for image, _, _ in images_with_extra_info]
+        np_images = [np.asarray(image) for image, _, _ in images_with_extra_info]
 
         # doclayout_yolo
-        layout_images = images.copy()
 
         images_layout_res += self.model.layout_model.batch_predict(
-            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
+            np_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
 
         if self.formula_enable:
             # Formula detection
             images_mfd_res = self.model.mfd_model.batch_predict(
-                images, MFD_BASE_BATCH_SIZE
+                np_images, MFD_BASE_BATCH_SIZE
             )
 
             # Formula recognition
             images_formula_list = self.model.mfr_model.batch_predict(
                 images_mfd_res,
-                images,
+                np_images,
                 batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
             )
             mfr_count = 0
-            for image_index in range(len(images)):
+            for image_index in range(len(np_images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
                 mfr_count += len(images_formula_list[image_index])
 
@@ -69,10 +68,10 @@ class BatchAnalyze:
 
         ocr_res_list_all_page = []
         table_res_list_all_page = []
-        for index in range(len(images)):
+        for index in range(len(np_images)):
             _, ocr_enable, _lang = images_with_extra_info[index]
             layout_res = images_layout_res[index]
-            pil_img = images[index]
+            np_img = np_images[index]
 
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
@@ -81,7 +80,7 @@ class BatchAnalyze:
             ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
                                           'lang':_lang,
                                           'ocr_enable':ocr_enable,
-                                          'pil_img':pil_img,
+                                          'np_img':np_img,
                                           'single_page_mfdetrec_res':single_page_mfdetrec_res,
                                           'layout_res':layout_res,
                                           })
@@ -93,7 +92,7 @@ class BatchAnalyze:
                 crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
                 crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
                 bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
-                table_img = get_crop_img(bbox, pil_img, scale=scale)
+                table_img = get_crop_np_img(bbox, np_img, scale=scale)
 
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
@@ -111,17 +110,17 @@ class BatchAnalyze:
 
                 for res in ocr_res_list_dict['ocr_res_list']:
                     new_image, useful_list = crop_img(
-                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
+                        res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
                     )
                     adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                         ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                     )
 
                     # Convert to BGR
-                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                    bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
                     all_cropped_images_info.append((
-                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
+                        bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
                     ))
 
             # Group by language
@@ -186,7 +185,7 @@ class BatchAnalyze:
 
                     # Process batch results
                     for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
-                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
+                        bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
 
                         if dt_boxes is not None and len(dt_boxes) > 0:
                             # Directly apply the key processing steps from the original OCR pipeline
@@ -217,7 +216,7 @@ class BatchAnalyze:
 
                             if ocr_res:
                                 ocr_result_list = get_ocr_result_list(
-                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
+                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
                                 )
 
                                 ocr_res_list_dict['layout_res'].extend(ocr_result_list)
@@ -235,21 +234,21 @@ class BatchAnalyze:
                 )
                 for res in ocr_res_list_dict['ocr_res_list']:
                     new_image, useful_list = crop_img(
-                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
+                        res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
                     )
                     adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                         ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                     )
                     # OCR-det
-                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                    bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                     ocr_res = ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                        bgr_image, mfd_res=adjusted_mfdetrec_res, rec=False
                     )[0]
 
                     # Integration results
                     if ocr_res:
                         ocr_result_list = get_ocr_result_list(
-                            ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
+                            ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
                         )
 
                         ocr_res_list_dict['layout_res'].extend(ocr_result_list)
@@ -273,7 +272,7 @@ class BatchAnalyze:
                     )
                     rotate_label = "0"
 
-                np_table_img = np.asarray(table_res_dict["table_img"])
+                np_table_img = table_res_dict["table_img"]
                 if rotate_label == "270":
                     np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
                 elif rotate_label == "90":
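
Because crop_img now yields numpy arrays, a cropped region can be handed to OpenCV directly, without the np.asarray() round-trip the old PIL path needed before cv2.cvtColor and cv2.rotate. A small sketch of that flow (dummy data, not the pipeline's own objects):

    import cv2
    import numpy as np

    rgb_crop = np.zeros((60, 50, 3), dtype=np.uint8)  # stand-in for a cropped RGB page region
    rgb_crop[..., 0] = 255                            # pure red in RGB channel order

    # No PIL-to-array conversion needed: the array goes straight into OpenCV.
    bgr_crop = cv2.cvtColor(rgb_crop, cv2.COLOR_RGB2BGR)
    assert bgr_crop[0, 0].tolist() == [0, 0, 255]     # red now sits in the last channel (BGR)

    # The same array also feeds cv2.rotate directly, as in the table-rotation branch.
    rotated = cv2.rotate(bgr_crop, cv2.ROTATE_90_CLOCKWISE)
    assert rotated.shape[:2] == (50, 60)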

+ 2 - 2
mineru/model/mfr/unimernet/Unimernet.py

@@ -70,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            pil_img = images[image_index]
+            image = images[image_index]
             formula_list = []
 
             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -84,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                bbox_img = image[ymin:ymax, xmin:xmax]
                 area = (xmax - xmin) * (ymax - ymin)
 
                 curr_idx = len(mf_image_list)
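
One semantic difference to keep in mind with this replacement: PIL's crop() returns a new image, whereas numpy slicing returns a view into the page array, so a later in-place write to the crop would also modify the page. An explicit .copy() restores the old behavior when the crop must be independent; a tiny illustration:

    import numpy as np

    page = np.zeros((100, 200, 3), dtype=np.uint8)
    bbox_img = page[10:20, 30:40]            # basic slicing: a view, no data copied
    bbox_img[:] = 255
    assert page[15, 35, 0] == 255            # the write shows up in the page array

    independent = page[10:20, 30:40].copy()  # .copy() when the crop must not alias the page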

+ 2 - 2
mineru/utils/ocr_utils.py

@@ -330,10 +330,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
     return adjusted_mfdetrec_res
 
 
-def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
+def get_ocr_result_list(ocr_res, useful_list, ocr_enable, bgr_image, lang):
     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
     ocr_result_list = []
-    ori_im = new_image.copy()
+    ori_im = bgr_image.copy()
     for box_ocr_res in ocr_res:
 
         if len(box_ocr_res) == 2:

+ 19 - 0
mineru/utils/pdf_image_tools.py

@@ -1,6 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 from io import BytesIO
 
+import numpy as np
 import pypdfium2 as pdfium
 from loguru import logger
 from PIL import Image
@@ -91,6 +92,24 @@ def get_crop_img(bbox: tuple, pil_img, scale=2):
     return pil_img.crop(scale_bbox)
 
 
+def get_crop_np_img(bbox: tuple, input_img, scale=2):
+
+    if isinstance(input_img, Image.Image):
+        np_img = np.asarray(input_img)
+    elif isinstance(input_img, np.ndarray):
+        np_img = input_img
+    else:
+        raise ValueError("Input must be a PIL Image or a numpy array.")
+
+    scale_bbox = (
+        int(bbox[0] * scale),
+        int(bbox[1] * scale),
+        int(bbox[2] * scale),
+        int(bbox[3] * scale),
+    )
+
+    return np_img[scale_bbox[1]:scale_bbox[3], scale_bbox[0]:scale_bbox[2]]
+
 def images_bytes_to_pdf_bytes(image_bytes):
     # In-memory buffer
     pdf_buffer = BytesIO()
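
A usage sketch for the new get_crop_np_img helper, based on the code above: bbox is given in unscaled page coordinates and is multiplied by scale before slicing, and either a PIL image or a numpy array can be passed in (assuming the package is importable as mineru):

    import numpy as np
    from PIL import Image

    from mineru.utils.pdf_image_tools import get_crop_np_img

    page = Image.new("RGB", (400, 300))                # page rendered at scale=2
    crop = get_crop_np_img((10, 20, 60, 80), page, scale=2)

    # scale_bbox becomes (20, 40, 120, 160), so the slice is 120 rows by 100 columns.
    assert isinstance(crop, np.ndarray)
    assert crop.shape == (120, 100, 3)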