ソースを参照

Support Pad preprocessing for YOLOX

gaotingquan 1 年間 前
コミット
0b8eab2a0e

+ 39 - 0
paddlex/modules/object_detection/predictor/transforms.py

@@ -270,6 +270,45 @@ class PadStride(BaseTransform):
         return [K.IMAGE]
 
 
+class Pad(BaseTransform):
+    def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
+        """
+        Pad image to a specified size.
+        Args:
+            size (list[int]): image target size
+            fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
+        """
+        super(Pad, self).__init__()
+        if isinstance(size, int):
+            size = [size, size]
+        self.size = size
+        self.fill_value = fill_value
+
+    def apply(self, data):
+        im = data[K.IMAGE]
+        im_h, im_w = im.shape[:2]
+        h, w = self.size
+        if h == im_h and w == im_w:
+            # im = im.astype(np.float32)
+            return data
+
+        canvas = np.ones((h, w, 3), dtype=np.float32)
+        canvas *= np.array(self.fill_value, dtype=np.float32)
+        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
+        data[K.IMAGE] = canvas
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """ get input keys """
+        return [K.IMAGE]
+
+    @classmethod
+    def get_output_keys(cls):
+        """ get output keys """
+        return [K.IMAGE]
+
+
 class DetResize(_BaseResize):
     """
     Resize the image.

+ 5 - 3
paddlex/modules/object_detection/predictor/utils.py

@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
 import codecs
 
 import yaml
 
 from ....utils import logging
 from ...base.predictor.transforms import image_common
-from .transforms import SaveDetResults, PadStride, DetResize
+from .transforms import SaveDetResults, PadStride, DetResize, Pad
 
 
 class InnerConfig(object):
@@ -71,6 +69,10 @@ class InnerConfig(object):
             elif cfg['type'] == 'PadStride':
                 stride = cfg.get('stride', 32)
                 tf = PadStride(stride=stride)
+            elif cfg['type'] == 'Pad':
+                fill_value = cfg.get('fill_value', [114.0, 114.0, 114.0])
+                size = cfg.get('size', [640, 640])
+                tf = Pad(size=size, fill_value=fill_value)
             else:
                 raise RuntimeError(f"Unsupported type: {cfg['type']}")
             tfs.append(tf)