Преглед изворни кода

refactor(magic_pdf): replace PIL with NumPy for image processing

- Remove most PIL usage across multiple files (the "pillow" mode of cut_image_to_pil_image still uses PIL)
- Convert image processing functions to use NumPy arrays
- Update crop_img function to work with NumPy arrays
- Modify image loading and resizing to use NumPy and OpenCV
- Clean up unused imports, commented-out code, and comments related to PIL
myhloli пре 8 месеци
родитељ
комит
1b34f7e4ff

+ 5 - 8
magic_pdf/data/utils.py

@@ -3,10 +3,8 @@ import fitz
 import numpy as np
 from loguru import logger
 
-from magic_pdf.utils.annotations import ImportPIL
 
 
-@ImportPIL
 def fitz_doc_to_image(doc, dpi=200) -> dict:
     """Convert fitz.Document to image, Then convert the image to numpy array.
 
@@ -17,7 +15,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     Returns:
         dict:  {'img': numpy array, 'width': width, 'height': height }
     """
-    from PIL import Image
     mat = fitz.Matrix(dpi / 72, dpi / 72)
     pm = doc.get_pixmap(matrix=mat, alpha=False)
 
@@ -25,8 +22,8 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     if pm.width > 4500 or pm.height > 4500:
         pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-    img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-    img = np.array(img)
+    # Convert pixmap samples directly to numpy array
+    img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
 
     img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
 
@@ -34,7 +31,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
 
 @ImportPIL
 def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
-    from PIL import Image
     images = []
     with fitz.open('pdf', pdf_bytes) as doc:
         pdf_page_num = doc.page_count
@@ -57,8 +53,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
                 if pm.width > 4500 or pm.height > 4500:
                     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-                img = np.array(img)
+                # Convert pixmap samples directly to numpy array
+                img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
+
                 img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
             else:
                 img_dict = {'img': [], 'width': 0, 'height': 0}

+ 11 - 6
magic_pdf/libs/pdf_image_tools.py

@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
     # 截取图片
     pix = page.get_pixmap(clip=rect, matrix=zoom)
 
-    # 将字节数据转换为文件对象
-    image_file = BytesIO(pix.tobytes(output='png'))
-    # 使用 Pillow 打开图像
-    pil_image = Image.open(image_file)
     if mode == "cv2":
-        image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
+        # 直接转换为numpy数组供cv2使用
+        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
+        # PyMuPDF使用RGB顺序,而cv2使用BGR顺序
+        if pix.n == 3 or pix.n == 4:
+            image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+        else:
+            image_result = img_array
     elif mode == "pillow":
-        image_result = pil_image
+        # 将字节数据转换为文件对象
+        image_file = BytesIO(pix.tobytes(output='png'))
+        # 使用 Pillow 打开图像
+        image_result = Image.open(image_file)
     else:
         raise ValueError(f"mode: {mode} is not supported.")
 

+ 5 - 116
magic_pdf/model/batch_analyze.py

@@ -1,23 +1,15 @@
 import time
 
 import cv2
-import numpy as np
 import torch
 from loguru import logger
-from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-# from magic_pdf.data.dataset import Dataset
-# from magic_pdf.libs.clean_memory import clean_memory
-# from magic_pdf.libs.config_reader import get_device
-# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-# from magic_pdf.operators.models import InferenceResult
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -31,7 +23,6 @@ class BatchAnalyze:
 
     def __call__(self, images: list) -> list:
         images_layout_res = []
-
         layout_start_time = time.time()
         if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
@@ -41,36 +32,14 @@ class BatchAnalyze:
         elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
             layout_images = []
-            modified_images = []
             for image_index, image in enumerate(images):
-                pil_img = Image.fromarray(image)
-                # width, height = pil_img.size
-                # if height > width:
-                #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                #     new_image, useful_list = crop_img(
-                #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                #     )
-                #     layout_images.append(new_image)
-                #     modified_images.append([image_index, useful_list])
-                # else:
-                layout_images.append(pil_img)
+                layout_images.append(image)
 
             images_layout_res += self.model.layout_model.batch_predict(
                 # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
                 layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )
 
-            for image_index, useful_list in modified_images:
-                for res in images_layout_res[image_index]:
-                    for i in range(len(res['poly'])):
-                        if i % 2 == 0:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[0] + useful_list[2]
-                            )
-                        else:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[1] + useful_list[3]
-                            )
         logger.info(
             f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
         )
@@ -111,7 +80,7 @@ class BatchAnalyze:
         # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
         for index in range(len(images)):
             layout_res = images_layout_res[index]
-            pil_img = Image.fromarray(images[index])
+            np_array_img = images[index]
 
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
@@ -121,14 +90,14 @@ class BatchAnalyze:
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(
-                    res, pil_img, crop_paste_x=50, crop_paste_y=50
+                    res, np_array_img, crop_paste_x=50, crop_paste_y=50
                 )
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                     single_page_mfdetrec_res, useful_list
                 )
 
                 # OCR recognition
-                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
                 if self.model.apply_ocr:
                     ocr_res = self.model.ocr_model.ocr(
@@ -150,7 +119,7 @@ class BatchAnalyze:
             if self.model.apply_table:
                 table_start = time.time()
                 for res in table_res_list:
-                    new_image, _ = crop_img(res, pil_img)
+                    new_image, _ = crop_img(res, np_array_img)
                     single_table_start_time = time.time()
                     html_code = None
                     if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
@@ -197,83 +166,3 @@ class BatchAnalyze:
             logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
 
         return images_layout_res
-
-
-# def doc_batch_analyze(
-#     dataset: Dataset,
-#     ocr: bool = False,
-#     show_log: bool = False,
-#     start_page_id=0,
-#     end_page_id=None,
-#     lang=None,
-#     layout_model=None,
-#     formula_enable=None,
-#     table_enable=None,
-#     batch_ratio: int | None = None,
-# ) -> InferenceResult:
-#     """Perform batch analysis on a document dataset.
-#
-#     Args:
-#         dataset (Dataset): The dataset containing document pages to be analyzed.
-#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-#         show_log (bool, optional): Flag to enable logging. Defaults to False.
-#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-#         lang (str, optional): Language for OCR. Defaults to None.
-#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
-#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
-#         table_enable (optional): Flag to enable table detection. Defaults to None.
-#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-#
-#     Raises:
-#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-#
-#     Returns:
-#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-#     """
-#
-#     if not torch.cuda.is_available():
-#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-#
-#     lang = None if lang == '' else lang
-#     # TODO: auto detect batch size
-#     batch_ratio = 1 if batch_ratio is None else batch_ratio
-#     end_page_id = end_page_id if end_page_id else len(dataset)
-#
-#     model_manager = ModelSingleton()
-#     custom_model: CustomPEKModel = model_manager.get_model(
-#         ocr, show_log, lang, layout_model, formula_enable, table_enable
-#     )
-#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-#
-#     model_json = []
-#
-#     # batch analyze
-#     images = []
-#     for index in range(len(dataset)):
-#         if start_page_id <= index <= end_page_id:
-#             page_data = dataset.get_page(index)
-#             img_dict = page_data.get_image()
-#             images.append(img_dict['img'])
-#     analyze_result = batch_model(images)
-#
-#     for index in range(len(dataset)):
-#         page_data = dataset.get_page(index)
-#         img_dict = page_data.get_image()
-#         page_width = img_dict['width']
-#         page_height = img_dict['height']
-#         if start_page_id <= index <= end_page_id:
-#             result = analyze_result.pop(0)
-#         else:
-#             result = []
-#
-#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-#         page_dict = {'layout_dets': result, 'page_info': page_info}
-#         model_json.append(page_dict)
-#
-#     # TODO: clean memory when gpu memory is not enough
-#     clean_memory_start_time = time.time()
-#     clean_memory(get_device())
-#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-#
-#     return InferenceResult(model_json, dataset)

+ 3 - 28
magic_pdf/model/pdf_extract_kit.py

@@ -3,11 +3,9 @@ import os
 import time
 
 import cv2
-import numpy as np
 import torch
 import yaml
 from loguru import logger
-from PIL import Image
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 
@@ -174,11 +172,6 @@ class CustomPEKModel:
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
-
-        pil_img = Image.fromarray(image)
-        width, height = pil_img.size
-        # logger.info(f'width: {width}, height: {height}')
-
         # layout检测
         layout_start = time.time()
         layout_res = []
@@ -186,24 +179,6 @@ class CustomPEKModel:
             # layoutlmv3
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
-            # doclayout_yolo
-            # if height > width:
-            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
-            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-            #     layout_res = self.layout_model.predict(new_image)
-            #     for res in layout_res:
-            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-            #         p1 = p1 - paste_x + xmin
-            #         p2 = p2 - paste_y + ymin
-            #         p3 = p3 - paste_x + xmin
-            #         p4 = p4 - paste_y + ymin
-            #         p5 = p5 - paste_x + xmin
-            #         p6 = p6 - paste_y + ymin
-            #         p7 = p7 - paste_x + xmin
-            #         p8 = p8 - paste_y + ymin
-            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            # else:
             layout_res = self.layout_model.predict(image)
 
         layout_cost = round(time.time() - layout_start, 2)
@@ -234,11 +209,11 @@ class CustomPEKModel:
         ocr_start = time.time()
         # Process each area that requires OCR processing
         for res in ocr_res_list:
-            new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+            new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
             adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
 
             # OCR recognition
-            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
             if self.apply_ocr:
                 ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
@@ -260,7 +235,7 @@ class CustomPEKModel:
         if self.apply_table:
             table_start = time.time()
             for res in table_res_list:
-                new_image, _ = crop_img(res, pil_img)
+                new_image, _ = crop_img(res, image)
                 single_table_start_time = time.time()
                 html_code = None
                 if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:

+ 2 - 4
magic_pdf/model/sub_modules/language_detection/utils.py

@@ -3,8 +3,6 @@ import os
 from pathlib import Path
 
 import yaml
-from PIL import Image
-
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 
 from magic_pdf.config.constants import MODEL_NAME
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
     )
     text_images = []
     for simple_image in simple_images:
-        image = Image.fromarray(simple_image['img'])
+        image = simple_image['img']
         layout_res = temp_layout_model.predict(image)
         # 给textblock截图
         for res in layout_res:
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
                 # 初步清洗(宽和高都小于100)
                 if x2 - x1 < 100 and y2 - y1 < 100:
                     continue
-                text_images.append(image.crop((x1, y1, x2, y2)))
+                text_images.append(image[y1:y2, x1:x2])
     return text_images
 
 

+ 20 - 10
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -3,8 +3,8 @@ import time
 from collections import Counter
 from uuid import uuid4
 
+import numpy as np
 import torch
-from PIL import Image
 from loguru import logger
 from ultralytics import YOLO
 
@@ -64,21 +64,32 @@ def split_images(image, result_images=None):
 
 def resize_images_to_224(image):
     """
-    若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
+    若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
+    Works directly with NumPy arrays.
     """
     try:
-        width, height = image.size
+        # Handle numpy array directly
+        if len(image.shape) == 3:  # Color image
+            height, width, channels = image.shape
+        else:  # Grayscale image
+            height, width = image.shape
+            image = np.stack([image] * 3, axis=2)  # Convert to RGB
+
         if width < 224 or height < 224:
-            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
+            # Create black background
+            new_image = np.zeros((224, 224, 3), dtype=np.uint8)
+            # Calculate paste position
             paste_x = (224 - width) // 2
             paste_y = (224 - height) // 2
-            new_image.paste(image, (paste_x, paste_y))
+            # Paste original image onto black background
+            new_image[paste_y:paste_y + height, paste_x:paste_x + width] = image
             image = new_image
         else:
-            image = image.resize((224, 224), Image.Resampling.LANCZOS)
+            # Resize using cv2 functionality or numpy interpolation
+            # Method 1: Using cv2 (preferred for better quality)
+            import cv2
+            image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)
 
-        # uuid = str(uuid4())
-        # image.save(f"/tmp/{uuid}.jpg")
         return image
     except Exception as e:
         logger.exception(e)
@@ -96,8 +107,7 @@ class YOLOv11LangDetModel(object):
     def do_detect(self, images: list):
         all_images = []
         for image in images:
-            width, height = image.size
-            # logger.info(f"image size: {width} x {height}")
+            height, width = image.shape[:2]
             if width < 100 and height < 100:
                 continue
             temp_images = split_images(image)

+ 2 - 42
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -4,7 +4,6 @@ import re
 
 import torch
 import unimernet.tasks as tasks
-from PIL import Image
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from unimernet.common.config import Config
@@ -100,45 +99,6 @@ class UnimernetModel(object):
             res["latex"] = latex_rm_whitespace(latex)
         return formula_list
 
-    # def batch_predict(
-    #     self, images_mfd_res: list, images: list, batch_size: int = 64
-    # ) -> list:
-    #     images_formula_list = []
-    #     mf_image_list = []
-    #     backfill_list = []
-    #     for image_index in range(len(images_mfd_res)):
-    #         mfd_res = images_mfd_res[image_index]
-    #         pil_img = Image.fromarray(images[image_index])
-    #         formula_list = []
-    #
-    #         for xyxy, conf, cla in zip(
-    #             mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
-    #         ):
-    #             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-    #             new_item = {
-    #                 "category_id": 13 + int(cla.item()),
-    #                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-    #                 "score": round(float(conf.item()), 2),
-    #                 "latex": "",
-    #             }
-    #             formula_list.append(new_item)
-    #             bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-    #             mf_image_list.append(bbox_img)
-    #
-    #         images_formula_list.append(formula_list)
-    #         backfill_list += formula_list
-    #
-    #     dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-    #     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
-    #     mfr_res = []
-    #     for mf_img in dataloader:
-    #         mf_img = mf_img.to(self.device)
-    #         with torch.no_grad():
-    #             output = self.model.generate({"image": mf_img})
-    #         mfr_res.extend(output["pred_str"])
-    #     for res, latex in zip(backfill_list, mfr_res):
-    #         res["latex"] = latex_rm_whitespace(latex)
-    #     return images_formula_list
 
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
@@ -149,7 +109,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            pil_img = Image.fromarray(images[image_index])
+            np_array_image = images[image_index]
             formula_list = []
 
             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -163,7 +123,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
                 area = (xmax - xmin) * (ymax - ymin)
 
                 curr_idx = len(mf_image_list)

+ 17 - 11
magic_pdf/model/sub_modules/model_utils.py

@@ -1,25 +1,31 @@
 import time
-
 import torch
-from PIL import Image
 from loguru import logger
-
+import numpy as np
 from magic_pdf.libs.clean_memory import clean_memory
 
 
-def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
+
     crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
     crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-    # Create a white background with an additional width and height of 50
+
+    # Calculate new dimensions
     crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
     crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
 
-    # Crop image
-    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-    cropped_img = input_pil_img.crop(crop_box)
-    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    # Create a white background array
+    return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
+
+    # Crop the original image using numpy slicing
+    cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
+
+    # Paste the cropped image onto the white background
+    return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
+    crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
+
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
+                   crop_new_height]
     return return_image, return_list