Explorar el Código

refactor: improve image handling by transitioning from NumPy arrays to PIL images in cropping functions

myhloli hace 5 meses
padre
commit
101b12a10a
Se han modificado 3 ficheros con 58 adiciones y 51 borrados
  1. 31 32
      mineru/backend/pipeline/batch_analyze.py
  2. 9 11
      mineru/cli/common.py
  3. 18 8
      mineru/utils/model_utils.py

+ 31 - 32
mineru/backend/pipeline/batch_analyze.py

@@ -71,7 +71,7 @@ class BatchAnalyze:
         for index in range(len(images)):
             _, ocr_enable, _lang = images_with_extra_info[index]
             layout_res = images_layout_res[index]
-            np_array_img = images[index]
+            pil_img = images[index]
 
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
@@ -80,13 +80,13 @@ class BatchAnalyze:
             ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
                                           'lang':_lang,
                                           'ocr_enable':ocr_enable,
-                                          'np_array_img':np_array_img,
+                                          'pil_img':pil_img,
                                           'single_page_mfdetrec_res':single_page_mfdetrec_res,
                                           'layout_res':layout_res,
                                           })
 
             for table_res in table_res_list:
-                table_img, _ = crop_img(table_res, np_array_img)
+                table_img, _ = crop_img(table_res, pil_img)
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
                                                 'table_img':table_img,
@@ -103,14 +103,14 @@ class BatchAnalyze:
 
                         for res in ocr_res_list_dict['ocr_res_list']:
                             new_image, useful_list = crop_img(
-                                res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
+                                res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
                             )
                             adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                                 ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                             )
 
                             # BGR转换
-                            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
+                            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
 
                             all_cropped_images_info.append((
                                 new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
@@ -215,37 +215,36 @@ class BatchAnalyze:
                         )
                         for res in ocr_res_list_dict['ocr_res_list']:
                             new_image, useful_list = crop_img(
-                                res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
+                                res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
                             )
                             adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                                 ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                             )
-
-                        # OCR-det
-                        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
-                        ocr_res = ocr_model.ocr(
-                            new_image, mfd_res=adjusted_mfdetrec_res, rec=False
-                        )[0]
-
-                        # Integration results
-                        if ocr_res:
-                            ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
-                                                                  new_image, _lang)
-
-                            if res["category_id"] == 3:
-                                # ocr_result_list中所有bbox的面积之和
-                                ocr_res_area = sum(
-                                    get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
-                                # 求ocr_res_area和res的面积的比值
-                                res_area = get_coords_and_area(res)[4]
-                                if res_area > 0:
-                                    ratio = ocr_res_area / res_area
-                                    if ratio > 0.25:
-                                        res["category_id"] = 1
-                                    else:
-                                        continue
-
-                            ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+                            # OCR-det
+                            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                            ocr_res = ocr_model.ocr(
+                                new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                            )[0]
+
+                            # Integration results
+                            if ocr_res:
+                                ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
+                                                                      new_image, _lang)
+
+                                if res["category_id"] == 3:
+                                    # ocr_result_list中所有bbox的面积之和
+                                    ocr_res_area = sum(
+                                        get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
+                                    # 求ocr_res_area和res的面积的比值
+                                    res_area = get_coords_and_area(res)[4]
+                                    if res_area > 0:
+                                        ratio = ocr_res_area / res_area
+                                        if ratio > 0.25:
+                                            res["category_id"] = 1
+                                        else:
+                                            continue
+
+                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
 
         # 表格识别 table recognition
         if self.table_enable:

+ 9 - 11
mineru/cli/common.py

@@ -8,13 +8,13 @@ import pypdfium2 as pdfium
 from loguru import logger
 
 from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
-from ..api.vlm_middle_json_mkcontent import union_make
-from ..backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
-from ..backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
-from ..data.data_reader_writer import FileBasedDataWriter
-from ..utils.draw_bbox import draw_layout_bbox, draw_span_bbox
-from ..utils.enum_class import MakeMode
-from ..utils.pdf_image_tools import images_bytes_to_pdf_bytes
+from mineru.api.vlm_middle_json_mkcontent import union_make
+from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
+from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
+from mineru.data.data_reader_writer import FileBasedDataWriter
+from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
+from mineru.utils.enum_class import MakeMode
+from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
 
 pdf_suffixes = [".pdf"]
 image_suffixes = [".png", ".jpeg", ".jpg"]
@@ -211,11 +211,9 @@ def do_parse(
 
 
 if __name__ == "__main__":
-    pdf_path = "../../demo/demo2.pdf"
+    pdf_path = "../../demo/pdfs/demo2.pdf"
     with open(pdf_path, "rb") as f:
         try:
-            result = do_parse("./output", Path(pdf_path).stem, f.read())
+           do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"],)
         except Exception as e:
             logger.exception(e)
-        # dict转成json
-        print(json.dumps(result, ensure_ascii=False, indent=4))

+ 18 - 8
mineru/utils/model_utils.py

@@ -1,13 +1,14 @@
 import time
 import torch
 import gc
+from PIL import Image
 from loguru import logger
 import numpy as np
 
 from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
 
 
-def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
+def crop_img(input_res, input_img, crop_paste_x=0, crop_paste_y=0):
 
     crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
     crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
@@ -16,15 +17,24 @@ def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
     crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
     crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
 
-    # Create a white background array
-    return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
+    if isinstance(input_img, np.ndarray):
 
-    # Crop the original image using numpy slicing
-    cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
+        # Create a white background array
+        return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
 
-    # Paste the cropped image onto the white background
-    return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
-    crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
+        # Crop the original image using numpy slicing
+        cropped_img = input_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
+
+        # Paste the cropped image onto the white background
+        return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
+        crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
+    else:
+        # Create a white background array
+        return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+        # Crop image
+        crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+        cropped_img = input_img.crop(crop_box)
+        return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
 
     return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
                    crop_new_height]