@@ -1,17 +1,15 @@
 from transformers.image_processing_utils import BaseImageProcessor
-from PIL import Image, ImageOps
 import numpy as np
 import cv2
 import albumentations as alb
 from albumentations.pytorch import ToTensorV2
-from torchvision.transforms.functional import resize


 # TODO: dereference cv2 if possible
 class UnimerSwinImageProcessor(BaseImageProcessor):
     def __init__(
         self,
-        image_size = [192, 672],
+        image_size = (192, 672),
     ):
         self.input_size = [int(_) for _ in image_size]
         assert len(self.input_size) == 2
@@ -27,56 +25,90 @@ class UnimerSwinImageProcessor(BaseImageProcessor):

     def __call__(self, item):
         image = self.prepare_input(item)
-        return self.transform(image=np.array(image))['image'][:1]
+        return self.transform(image=image)['image'][:1]

     @staticmethod
-    def crop_margin(img: Image.Image) -> Image.Image:
-        data = np.array(img.convert("L"))
-        data = data.astype(np.uint8)
-        max_val = data.max()
-        min_val = data.min()
-        if max_val == min_val:
+    def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
+        """Crop margins of image using NumPy operations"""
+        # Convert to grayscale if it's a color image
+        if len(img.shape) == 3 and img.shape[2] == 3:
+            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+        else:
+            gray = img.copy()
+
+        # Normalize and threshold
+        if gray.max() == gray.min():
             return img
-        data = (data - min_val) / (max_val - min_val) * 255
-        gray = 255 * (data < 200).astype(np.uint8)

-        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
-        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
-        return img.crop((a, b, w + a, h + b))
+        normalized = (((gray - gray.min()) / (gray.max() - gray.min())) * 255).astype(np.uint8)
+        binary = 255 * (normalized < 200).astype(np.uint8)
+
+        # Find bounding box
+        coords = cv2.findNonZero(binary)  # Find all non-zero points (text)
+        x, y, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box

-    def prepare_input(self, img: Image.Image, random_padding: bool = False):
+        # Return cropped image
+        return img[y:y + h, x:x + w]
+
+    def prepare_input(self, img, random_padding: bool = False):
         """
-        Convert PIL Image to tensor according to specified input_size after following steps below:
-            - resize
-            - rotate (if align_long_axis is True and image is not aligned longer axis with canvas)
-            - pad
+        Convert PIL Image or numpy array to properly sized and padded image after:
+            - crop margins
+            - resize while maintaining aspect ratio
+            - pad to target size
         """
         if img is None:
-            return
-        # crop margins
+            return None
+
         try:
-            img = self.crop_margin(img.convert("RGB"))
-        except OSError:
+            img = self.crop_margin_numpy(img)
+        except Exception:
             # might throw an error for broken files
-            return
+            return None
+
+        if img.shape[0] == 0 or img.shape[1] == 0:
+            return None
+
+        # Resize while preserving aspect ratio
+        h, w = img.shape[:2]
+        scale = min(self.input_size[0] / h, self.input_size[1] / w)
+        new_h, new_w = int(h * scale), int(w * scale)
+        resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+
+        # Calculate padding
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
+
+        # Create and apply padding
+        channels = 3 if len(img.shape) == 3 else 1
+        padded_img = np.full((self.input_size[0], self.input_size[1], channels), 255, dtype=np.uint8)
+        padded_img[pad_height:pad_height + new_h, pad_width:pad_width + new_w] = resized_img
+
+        return padded_img
+
+    def _calculate_padding(self, new_w, new_h, random_padding):
+        """Calculate padding values for PIL images"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h
+
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)

-        if img.height == 0 or img.width == 0:
-            return
+        return (
+            pad_width,
+            pad_height,
+            delta_width - pad_width,
+            delta_height - pad_height,
+        )
+
+    def _get_padding_values(self, new_w, new_h, random_padding):
+        """Get padding values based on image dimensions and padding strategy"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h

-        img = resize(img, min(self.input_size))
-        img.thumbnail((self.input_size[1], self.input_size[0]))
-        delta_width = self.input_size[1] - img.width
-        delta_height = self.input_size[0] - img.height
         if random_padding:
             pad_width = np.random.randint(low=0, high=delta_width + 1)
             pad_height = np.random.randint(low=0, high=delta_height + 1)
         else:
             pad_width = delta_width // 2
             pad_height = delta_height // 2
-        padding = (
-            pad_width,
-            pad_height,
-            delta_width - pad_width,
-            delta_height - pad_height,
-        )
-        return ImageOps.expand(img, padding)
+
+        return pad_width, pad_height
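
For reviewers, here is a minimal, self-contained sketch of the resize-and-pad geometry introduced in `prepare_input` (with `random_padding=False`). The dummy image, the `input_size` value, and the variable names are illustrative only and are not part of the change; it simply mirrors the cv2/NumPy path added in the diff to show that the output canvas always has shape `(input_size[0], input_size[1], channels)` with white padding.

```python
import numpy as np
import cv2

# Target canvas (height, width), matching the default image_size in the diff.
input_size = (192, 672)

# Dummy "document" image: a white page with a dark band of text-like pixels.
img = np.full((300, 500, 3), 255, dtype=np.uint8)
img[100:140, 50:450] = 30

# Resize while preserving aspect ratio, as prepare_input does.
h, w = img.shape[:2]
scale = min(input_size[0] / h, input_size[1] / w)
new_h, new_w = int(h * scale), int(w * scale)
resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

# Center the resized image on a white canvas (the random_padding=False branch).
pad_w = (input_size[1] - new_w) // 2
pad_h = (input_size[0] - new_h) // 2
canvas = np.full((input_size[0], input_size[1], 3), 255, dtype=np.uint8)
canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

print(canvas.shape)  # (192, 672, 3), regardless of the input aspect ratio
```

The albumentations pipeline in `self.transform` (defined outside this diff) then consumes the padded array directly, which is why `__call__` no longer wraps the input in `np.array(...)`.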