Sfoglia il codice sorgente

refactor: enhance image margin cropping and processing for improved handling of PIL and NumPy images

myhloli 5 mesi fa
parent
commit
7a22bfeebe

+ 2 - 2
mineru/model/mfr/unimernet/Unimernet.py

@@ -70,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
            mfd_res = images_mfd_res[image_index]
-            np_array_image = images[image_index]
+            pil_img = images[image_index]
            formula_list = []

            for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -84,7 +84,7 @@ class UnimernetModel(object):
                    "latex": "",
                }
                formula_list.append(new_item)
-                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                area = (xmax - xmin) * (ymax - ymin)

                curr_idx = len(mf_image_list)

+ 75 - 33
mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py

@@ -1,8 +1,10 @@
+from PIL import Image, ImageOps
 from transformers.image_processing_utils import BaseImageProcessor
 import numpy as np
 import cv2
 import albumentations as alb
 from albumentations.pytorch import ToTensorV2
+from torchvision.transforms.functional import resize


 # TODO: dereference cv2 if possible
@@ -28,6 +30,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
         return self.transform(image=image)['image'][:1]

     @staticmethod
+    def crop_margin(img: Image.Image) -> Image.Image:
+        data = np.array(img.convert("L"))
+        data = data.astype(np.uint8)
+        max_val = data.max()
+        min_val = data.min()
+        if max_val == min_val:
+            return img
+        data = (data - min_val) / (max_val - min_val) * 255
+        gray = 255 * (data < 200).astype(np.uint8)
+
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        return img.crop((a, b, w + a, h + b))
+
+    @staticmethod
     def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
         """Crop margins of image using NumPy operations"""
         # Convert to grayscale if it's a color image
@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
         if img is None:
             return None

-        # try:
-        #     img = self.crop_margin_numpy(img)
-        # except Exception:
-        #     # might throw an error for broken files
-        #     return None
+        # Handle numpy array
+        elif isinstance(img, np.ndarray):
+            try:
+                img = self.crop_margin_numpy(img)
+            except Exception:
+                # might throw an error for broken files
+                return None

-        if img.shape[0] == 0 or img.shape[1] == 0:
-            return None
+            if img.shape[0] == 0 or img.shape[1] == 0:
+                return None

-        # Get current dimensions
-        h, w = img.shape[:2]
-        target_h, target_w = self.input_size
+            # Get current dimensions
+            h, w = img.shape[:2]
+            target_h, target_w = self.input_size

-        # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
-        scale = min(target_h / h, target_w / w)
+            # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
+            scale = min(target_h / h, target_w / w)

-        # Calculate new dimensions
-        new_h, new_w = int(h * scale), int(w * scale)
+            # Calculate new dimensions
+            new_h, new_w = int(h * scale), int(w * scale)

-        # Resize the image while preserving aspect ratio
-        resized_img = cv2.resize(img, (new_w, new_h))
+            # Resize the image while preserving aspect ratio
+            resized_img = cv2.resize(img, (new_w, new_h))

-        # Calculate padding values using the existing method
-        delta_width = target_w - new_w
-        delta_height = target_h - new_h
+            # Calculate padding values using the existing method
+            delta_width = target_w - new_w
+            delta_height = target_h - new_h

-        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
+            pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)

-        # Apply padding (convert PIL padding format to OpenCV format)
-        padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
-
-        padded_img = cv2.copyMakeBorder(
-            resized_img,
-            pad_height,  # top
-            delta_height - pad_height,  # bottom
-            pad_width,  # left
-            delta_width - pad_width,  # right
-            cv2.BORDER_CONSTANT,
-            value=padding_color
-        )
+            # Apply padding (convert PIL padding format to OpenCV format)
+            padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
+
+            padded_img = cv2.copyMakeBorder(
+                resized_img,
+                pad_height,  # top
+                delta_height - pad_height,  # bottom
+                pad_width,  # left
+                delta_width - pad_width,  # right
+                cv2.BORDER_CONSTANT,
+                value=padding_color
+            )

-        return padded_img
+            return padded_img
+
+        # Handle PIL Image
+        elif isinstance(img, Image.Image):
+            try:
+                img = self.crop_margin(img.convert("RGB"))
+            except OSError:
+                # might throw an error for broken files
+                return None
+
+            if img.height == 0 or img.width == 0:
+                return None
+
+            # Resize while preserving aspect ratio
+            img = resize(img, min(self.input_size))
+            img.thumbnail((self.input_size[1], self.input_size[0]))
+            new_w, new_h = img.width, img.height
+
+            # Calculate and apply padding
+            padding = self._calculate_padding(new_w, new_h, random_padding)
+            return np.array(ImageOps.expand(img, padding))
+
+        else:
+            return None

     def _calculate_padding(self, new_w, new_h, random_padding):
         """Calculate padding values for PIL images"""