Browse Source

fix image cls process (#2731)

zhangyubo0722 10 months ago
parent
commit
597c80cfe3

+ 1 - 1
paddlex/inference/components/transforms/image/common.py

@@ -278,7 +278,7 @@ class Crop(BaseComponent):
         x2 = min(w, x1 + cw)
         y2 = min(h, y1 + ch)
         coords = (x1, y1, x2, y2)
-        if coords == (0, 0, w, h):
+        if w < cw or h < ch:
             raise ValueError(
                 f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
             )

+ 26 - 2
paddlex/inference/models_new/common/vision/funcs.py

@@ -13,6 +13,9 @@
 # limitations under the License.
 
 import cv2
+import numpy as np
+from PIL import Image
+from .....utils import logging
 
 
 def check_image_size(input_):
@@ -26,13 +29,34 @@ def check_image_size(input_):
         raise TypeError(f"{input_} cannot represent a valid image size.")
 
 
-def resize(im, target_size, interp):
+def resize(im, target_size, interp, backend="cv2"):
     """resize image to target size"""
     w, h = target_size
-    im = cv2.resize(im, (w, h), interpolation=interp)
+    if backend.lower() == "pil":
+        resize_function = _pil_resize
+    else:
+        resize_function = _cv2_resize
+        if backend.lower() != "cv2":
+            logging.warning(
+                f"Unknown backend {backend}. Defaulting to cv2 for resizing."
+            )
+    im = resize_function(im, (w, h), interp)
     return im
 
 
+def _cv2_resize(src, size, resample):
+    return cv2.resize(src, size, interpolation=resample)
+
+
+def _pil_resize(src, size, resample):
+    if isinstance(src, np.ndarray):
+        pil_img = Image.fromarray(src)
+    else:
+        pil_img = src
+    pil_img = pil_img.resize(size, resample)
+    return np.asarray(pil_img)
+
+
 def flip_h(im):
     """flip image horizontally"""
     if len(im.shape) == 3:

+ 49 - 14
paddlex/inference/models_new/common/vision/processors.py

@@ -20,20 +20,28 @@ from copy import deepcopy
 
 import numpy as np
 import cv2
+from PIL import Image
 
 from . import funcs as F
 
 
 class _BaseResize:
-    _INTERP_DICT = {
+    _CV2_INTERP_DICT = {
         "NEAREST": cv2.INTER_NEAREST,
         "LINEAR": cv2.INTER_LINEAR,
-        "CUBIC": cv2.INTER_CUBIC,
+        "BICUBIC": cv2.INTER_CUBIC,
         "AREA": cv2.INTER_AREA,
         "LANCZOS4": cv2.INTER_LANCZOS4,
     }
+    _PIL_INTERP_DICT = {
+        "NEAREST": Image.NEAREST,
+        "BILINEAR": Image.BILINEAR,
+        "BICUBIC": Image.BICUBIC,
+        "BOX": Image.BOX,
+        "LANCZOS4": Image.LANCZOS,
+    }
 
-    def __init__(self, size_divisor, interp):
+    def __init__(self, size_divisor, interp, backend="cv2"):
         super().__init__()
 
         if size_divisor is not None:
@@ -43,12 +51,26 @@ class _BaseResize:
         self.size_divisor = size_divisor
 
         try:
-            interp = self._INTERP_DICT[interp]
+            interp = interp.upper()
+            if backend == "cv2":
+                interp = self._CV2_INTERP_DICT[interp]
+            elif backend == "pil":
+                interp = self._PIL_INTERP_DICT[interp]
+            else:
+                raise ValueError("backend must be `cv2` or `pil`")
         except KeyError:
             raise ValueError(
-                "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
+                "For backend '{}', `interp` should be one of {}. Please ensure the interpolation method matches the selected backend.".format(
+                    backend,
+                    (
+                        self._CV2_INTERP_DICT.keys()
+                        if backend == "cv2"
+                        else self._PIL_INTERP_DICT.keys()
+                    ),
+                )
             )
         self.interp = interp
+        self.backend = backend
 
     @staticmethod
     def _rescale_size(img_size, target_size):
@@ -62,7 +84,12 @@ class Resize(_BaseResize):
     """Resize the image."""
 
     def __init__(
-        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
+        self,
+        target_size,
+        keep_ratio=False,
+        size_divisor=None,
+        interp="LINEAR",
+        backend="cv2",
     ):
         """
         Initialize the instance.
@@ -76,7 +103,7 @@ class Resize(_BaseResize):
             interp (str, optional): Interpolation method. Choices are 'NEAREST',
                 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
         """
-        super().__init__(size_divisor=size_divisor, interp=interp)
+        super().__init__(size_divisor=size_divisor, interp=interp, backend=backend)
 
         if isinstance(target_size, int):
             target_size = [target_size, target_size]
@@ -102,7 +129,7 @@ class Resize(_BaseResize):
                 math.ceil(i / self.size_divisor) * self.size_divisor
                 for i in target_size
             ]
-        img = F.resize(img, target_size, interp=self.interp)
+        img = F.resize(img, target_size, interp=self.interp, backend=self.backend)
         return img
 
 
@@ -112,7 +139,9 @@ class ResizeByLong(_BaseResize):
     longest side.
     """
 
-    def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
+    def __init__(
+        self, target_long_edge, size_divisor=None, interp="LINEAR", backend="cv2"
+    ):
         """
         Initialize the instance.
 
@@ -123,7 +152,7 @@ class ResizeByLong(_BaseResize):
             interp (str, optional): Interpolation method. Choices are 'NEAREST',
                 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
         """
-        super().__init__(size_divisor=size_divisor, interp=interp)
+        super().__init__(size_divisor=size_divisor, interp=interp, backend=backend)
         self.target_long_edge = target_long_edge
 
     def __call__(self, imgs):
@@ -139,7 +168,9 @@ class ResizeByLong(_BaseResize):
             h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
             w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
 
-        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
+        img = F.resize(
+            img, (w_resize, h_resize), interp=self.interp, backend=self.backend
+        )
         return img
 
 
@@ -149,7 +180,9 @@ class ResizeByShort(_BaseResize):
     shortest side.
     """
 
-    def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
+    def __init__(
+        self, target_short_edge, size_divisor=None, interp="LINEAR", backend="cv2"
+    ):
         """
         Initialize the instance.
 
@@ -160,7 +193,7 @@ class ResizeByShort(_BaseResize):
             interp (str, optional): Interpolation method. Choices are 'NEAREST',
                 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
         """
-        super().__init__(size_divisor=size_divisor, interp=interp)
+        super().__init__(size_divisor=size_divisor, interp=interp, backend=backend)
         self.target_short_edge = target_short_edge
 
     def __call__(self, imgs):
@@ -176,7 +209,9 @@ class ResizeByShort(_BaseResize):
             h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
             w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
 
-        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
+        img = F.resize(
+            img, (w_resize, h_resize), interp=self.interp, backend=self.backend
+        )
         return img
 
 

+ 10 - 2
paddlex/inference/models_new/image_classification/predictor.py

@@ -139,10 +139,18 @@ class ClasPredictor(BasicPredictor):
         assert resize_short or size
         if resize_short:
             op = ResizeByShort(
-                target_short_edge=resize_short, size_divisor=None, interp="LINEAR"
+                target_short_edge=resize_short,
+                size_divisor=None,
+                interp=interpolation,
+                backend=backend,
             )
         else:
-            op = Resize(target_size=size)
+            op = Resize(
+                target_size=size,
+                size_divisor=None,
+                interp=interpolation,
+                backend=backend,
+            )
         return "Resize", op
 
     @register("CropImage")

+ 1 - 1
paddlex/inference/models_new/image_classification/processors.py

@@ -56,7 +56,7 @@ class Crop:
         x2 = min(w, x1 + cw)
         y2 = min(h, y1 + ch)
         coords = (x1, y1, x2, y2)
-        if coords == (0, 0, w, h):
+        if w < cw or h < ch:
             raise ValueError(
                 f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
             )

+ 1 - 1
paddlex/inference/models_new/object_detection/predictor.py

@@ -213,7 +213,7 @@ class DetPredictor(BasicPredictor):
             interp = {
                 0: "NEAREST",
                 1: "LINEAR",
-                2: "CUBIC",
+                2: "BICUBIC",
                 3: "AREA",
                 4: "LANCZOS4",
             }[interp]