瀏覽代碼

refactor: enhance image margin cropping and processing for improved handling of PIL and NumPy images

myhloli 5 月之前
父節點
當前提交
7a22bfeebe

+ 2 - 2
mineru/model/mfr/unimernet/Unimernet.py

@@ -70,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            np_array_image = images[image_index]
+            pil_img = images[image_index]
             formula_list = []
 
             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -84,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 area = (xmax - xmin) * (ymax - ymin)
 
                 curr_idx = len(mf_image_list)

+ 75 - 33
mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py

@@ -1,8 +1,10 @@
+from PIL import Image, ImageOps
 from transformers.image_processing_utils import BaseImageProcessor
 import numpy as np
 import cv2
 import albumentations as alb
 from albumentations.pytorch import ToTensorV2
+from torchvision.transforms.functional import resize
 
 
 # TODO: dereference cv2 if possible
@@ -28,6 +30,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
         return self.transform(image=image)['image'][:1]
 
     @staticmethod
+    def crop_margin(img: Image.Image) -> Image.Image:
+        data = np.array(img.convert("L"))
+        data = data.astype(np.uint8)
+        max_val = data.max()
+        min_val = data.min()
+        if max_val == min_val:
+            return img
+        data = (data - min_val) / (max_val - min_val) * 255
+        gray = 255 * (data < 200).astype(np.uint8)
+
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        return img.crop((a, b, w + a, h + b))
+
+    @staticmethod
     def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
         """Crop margins of image using NumPy operations"""
         # Convert to grayscale if it's a color image
@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
         if img is None:
             return None
 
-        # try:
-        #     img = self.crop_margin_numpy(img)
-        # except Exception:
-        #     # might throw an error for broken files
-        #     return None
+        # Handle numpy array
+        elif isinstance(img, np.ndarray):
+            try:
+                img = self.crop_margin_numpy(img)
+            except Exception:
+                # might throw an error for broken files
+                return None
 
-        if img.shape[0] == 0 or img.shape[1] == 0:
-            return None
+            if img.shape[0] == 0 or img.shape[1] == 0:
+                return None
 
-        # Get current dimensions
-        h, w = img.shape[:2]
-        target_h, target_w = self.input_size
+            # Get current dimensions
+            h, w = img.shape[:2]
+            target_h, target_w = self.input_size
 
-        # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
-        scale = min(target_h / h, target_w / w)
+            # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
+            scale = min(target_h / h, target_w / w)
 
-        # Calculate new dimensions
-        new_h, new_w = int(h * scale), int(w * scale)
+            # Calculate new dimensions
+            new_h, new_w = int(h * scale), int(w * scale)
 
-        # Resize the image while preserving aspect ratio
-        resized_img = cv2.resize(img, (new_w, new_h))
+            # Resize the image while preserving aspect ratio
+            resized_img = cv2.resize(img, (new_w, new_h))
 
-        # Calculate padding values using the existing method
-        delta_width = target_w - new_w
-        delta_height = target_h - new_h
+            # Calculate padding values using the existing method
+            delta_width = target_w - new_w
+            delta_height = target_h - new_h
 
-        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
+            pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
 
-        # Apply padding (convert PIL padding format to OpenCV format)
-        padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
-
-        padded_img = cv2.copyMakeBorder(
-            resized_img,
-            pad_height,  # top
-            delta_height - pad_height,  # bottom
-            pad_width,  # left
-            delta_width - pad_width,  # right
-            cv2.BORDER_CONSTANT,
-            value=padding_color
-        )
+            # Apply padding (convert PIL padding format to OpenCV format)
+            padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
+
+            padded_img = cv2.copyMakeBorder(
+                resized_img,
+                pad_height,  # top
+                delta_height - pad_height,  # bottom
+                pad_width,  # left
+                delta_width - pad_width,  # right
+                cv2.BORDER_CONSTANT,
+                value=padding_color
+            )
 
-        return padded_img
+            return padded_img
+
+        # Handle PIL Image
+        elif isinstance(img, Image.Image):
+            try:
+                img = self.crop_margin(img.convert("RGB"))
+            except OSError:
+                # might throw an error for broken files
+                return None
+
+            if img.height == 0 or img.width == 0:
+                return None
+
+            # Resize while preserving aspect ratio
+            img = resize(img, min(self.input_size))
+            img.thumbnail((self.input_size[1], self.input_size[0]))
+            new_w, new_h = img.width, img.height
+
+            # Calculate and apply padding
+            padding = self._calculate_padding(new_w, new_h, random_padding)
+            return np.array(ImageOps.expand(img, padding))
+
+        else:
+            return None
 
     def _calculate_padding(self, new_w, new_h, random_padding):
         """Calculate padding values for PIL images"""