|
@@ -1,8 +1,10 @@
|
|
|
|
|
+from PIL import Image, ImageOps
|
|
|
from transformers.image_processing_utils import BaseImageProcessor
|
|
from transformers.image_processing_utils import BaseImageProcessor
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import cv2
|
|
import cv2
|
|
|
import albumentations as alb
|
|
import albumentations as alb
|
|
|
from albumentations.pytorch import ToTensorV2
|
|
from albumentations.pytorch import ToTensorV2
|
|
|
|
|
+from torchvision.transforms.functional import resize
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: dereference cv2 if possible
|
|
# TODO: dereference cv2 if possible
|
|
@@ -28,6 +30,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
|
|
|
return self.transform(image=image)['image'][:1]
|
|
return self.transform(image=image)['image'][:1]
|
|
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
|
|
+ def crop_margin(img: Image.Image) -> Image.Image:
|
|
|
|
|
+ data = np.array(img.convert("L"))
|
|
|
|
|
+ data = data.astype(np.uint8)
|
|
|
|
|
+ max_val = data.max()
|
|
|
|
|
+ min_val = data.min()
|
|
|
|
|
+ if max_val == min_val:
|
|
|
|
|
+ return img
|
|
|
|
|
+ data = (data - min_val) / (max_val - min_val) * 255
|
|
|
|
|
+ gray = 255 * (data < 200).astype(np.uint8)
|
|
|
|
|
+
|
|
|
|
|
+ coords = cv2.findNonZero(gray) # Find all non-zero points (text)
|
|
|
|
|
+ a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
|
|
|
|
|
+ return img.crop((a, b, w + a, h + b))
|
|
|
|
|
+
|
|
|
|
|
+ @staticmethod
|
|
|
def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
|
|
def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
|
|
|
"""Crop margins of image using NumPy operations"""
|
|
"""Crop margins of image using NumPy operations"""
|
|
|
# Convert to grayscale if it's a color image
|
|
# Convert to grayscale if it's a color image
|
|
@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
|
|
|
if img is None:
|
|
if img is None:
|
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
- # try:
|
|
|
|
|
- # img = self.crop_margin_numpy(img)
|
|
|
|
|
- # except Exception:
|
|
|
|
|
- # # might throw an error for broken files
|
|
|
|
|
- # return None
|
|
|
|
|
|
|
+ # Handle numpy array
|
|
|
|
|
+ elif isinstance(img, np.ndarray):
|
|
|
|
|
+ try:
|
|
|
|
|
+ img = self.crop_margin_numpy(img)
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ # might throw an error for broken files
|
|
|
|
|
+ return None
|
|
|
|
|
|
|
|
- if img.shape[0] == 0 or img.shape[1] == 0:
|
|
|
|
|
- return None
|
|
|
|
|
|
|
+ if img.shape[0] == 0 or img.shape[1] == 0:
|
|
|
|
|
+ return None
|
|
|
|
|
|
|
|
- # Get current dimensions
|
|
|
|
|
- h, w = img.shape[:2]
|
|
|
|
|
- target_h, target_w = self.input_size
|
|
|
|
|
|
|
+ # Get current dimensions
|
|
|
|
|
+ h, w = img.shape[:2]
|
|
|
|
|
+ target_h, target_w = self.input_size
|
|
|
|
|
|
|
|
- # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
|
|
|
|
|
- scale = min(target_h / h, target_w / w)
|
|
|
|
|
|
|
+ # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
|
|
|
|
|
+ scale = min(target_h / h, target_w / w)
|
|
|
|
|
|
|
|
- # Calculate new dimensions
|
|
|
|
|
- new_h, new_w = int(h * scale), int(w * scale)
|
|
|
|
|
|
|
+ # Calculate new dimensions
|
|
|
|
|
+ new_h, new_w = int(h * scale), int(w * scale)
|
|
|
|
|
|
|
|
- # Resize the image while preserving aspect ratio
|
|
|
|
|
- resized_img = cv2.resize(img, (new_w, new_h))
|
|
|
|
|
|
|
+ # Resize the image while preserving aspect ratio
|
|
|
|
|
+ resized_img = cv2.resize(img, (new_w, new_h))
|
|
|
|
|
|
|
|
- # Calculate padding values using the existing method
|
|
|
|
|
- delta_width = target_w - new_w
|
|
|
|
|
- delta_height = target_h - new_h
|
|
|
|
|
|
|
+ # Calculate padding values using the existing method
|
|
|
|
|
+ delta_width = target_w - new_w
|
|
|
|
|
+ delta_height = target_h - new_h
|
|
|
|
|
|
|
|
- pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
|
|
|
|
|
|
|
+ pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
|
|
|
|
|
|
|
|
- # Apply padding (convert PIL padding format to OpenCV format)
|
|
|
|
|
- padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
|
|
|
|
|
-
|
|
|
|
|
- padded_img = cv2.copyMakeBorder(
|
|
|
|
|
- resized_img,
|
|
|
|
|
- pad_height, # top
|
|
|
|
|
- delta_height - pad_height, # bottom
|
|
|
|
|
- pad_width, # left
|
|
|
|
|
- delta_width - pad_width, # right
|
|
|
|
|
- cv2.BORDER_CONSTANT,
|
|
|
|
|
- value=padding_color
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # Apply padding (convert PIL padding format to OpenCV format)
|
|
|
|
|
+ padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
|
|
|
|
|
+
|
|
|
|
|
+ padded_img = cv2.copyMakeBorder(
|
|
|
|
|
+ resized_img,
|
|
|
|
|
+ pad_height, # top
|
|
|
|
|
+ delta_height - pad_height, # bottom
|
|
|
|
|
+ pad_width, # left
|
|
|
|
|
+ delta_width - pad_width, # right
|
|
|
|
|
+ cv2.BORDER_CONSTANT,
|
|
|
|
|
+ value=padding_color
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- return padded_img
|
|
|
|
|
|
|
+ return padded_img
|
|
|
|
|
+
|
|
|
|
|
+ # Handle PIL Image
|
|
|
|
|
+ elif isinstance(img, Image.Image):
|
|
|
|
|
+ try:
|
|
|
|
|
+ img = self.crop_margin(img.convert("RGB"))
|
|
|
|
|
+ except OSError:
|
|
|
|
|
+ # might throw an error for broken files
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ if img.height == 0 or img.width == 0:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ # Resize while preserving aspect ratio
|
|
|
|
|
+ img = resize(img, min(self.input_size))
|
|
|
|
|
+ img.thumbnail((self.input_size[1], self.input_size[0]))
|
|
|
|
|
+ new_w, new_h = img.width, img.height
|
|
|
|
|
+
|
|
|
|
|
+ # Calculate and apply padding
|
|
|
|
|
+ padding = self._calculate_padding(new_w, new_h, random_padding)
|
|
|
|
|
+ return np.array(ImageOps.expand(img, padding))
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ return None
|
|
|
|
|
|
|
|
def _calculate_padding(self, new_w, new_h, random_padding):
|
|
def _calculate_padding(self, new_w, new_h, random_padding):
|
|
|
"""Calculate padding values for PIL images"""
|
|
"""Calculate padding values for PIL images"""
|