Pārlūkot izejas kodu

rename imgsz as img_size

leo-q8 10 mēneši atpakaļ
vecāks
revīzija
27b5314c0a

+ 1 - 1
paddlex/configs/pipelines/human_keypoint_detection.yaml

@@ -7,7 +7,7 @@ SubModules:
     model_dir: null
     batch_size: 1
     threshold: null
-    imgsz: null
+    img_size: null
   KeypointDetection:
     module_name: keypoint_detection
     model_name: PP-TinyPose_128x96

+ 1 - 1
paddlex/configs/pipelines/object_detection.yaml

@@ -6,5 +6,5 @@ SubModules:
     model_name: PicoDet-S
     model_dir: null
     batch_size: 1
-    imgsz: null
+    img_size: null
     threshold: null

+ 11 - 11
paddlex/inference/models_new/object_detection/predictor.py

@@ -47,33 +47,33 @@ class DetPredictor(BasicPredictor):
     def __init__(
         self,
         *args,
-        imgsz: Optional[Union[int, Tuple[int, int]]] = None,
+        img_size: Optional[Union[int, Tuple[int, int]]] = None,
         threshold: Optional[float] = None,
         **kwargs,
     ):
         """Initializes DetPredictor.
         Args:
             *args: Arbitrary positional arguments passed to the superclass.
-            imgsz (Optional[Union[int, Tuple[int, int]]], optional): The input image size (w, h). Defaults to None.
+            img_size (Optional[Union[int, Tuple[int, int]]], optional): The input image size (w, h). Defaults to None.
             threshold (Optional[float], optional): The threshold for filtering out low-confidence predictions.
                 Defaults to None.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
         super().__init__(*args, **kwargs)
 
-        if imgsz is not None:
+        if img_size is not None:
             assert (
                 self.model_name not in STATIC_SHAPE_MODEL_LIST
             ), f"The model {self.model_name} is not supported set input shape"
-            if isinstance(imgsz, int):
-                imgsz = (imgsz, imgsz)
-            elif isinstance(imgsz, (tuple, list)):
-                assert len(imgsz) == 2, f"The length of `imgsz` should be 2."
+            if isinstance(img_size, int):
+                img_size = (img_size, img_size)
+            elif isinstance(img_size, (tuple, list)):
+                assert len(img_size) == 2, f"The length of `img_size` should be 2."
             else:
                 raise ValueError(
-                    f"The type of `imgsz` must be int or Tuple[int, int], but got {type(imgsz)}."
+                    f"The type of `img_size` must be int or Tuple[int, int], but got {type(img_size)}."
                 )
-        self.imgsz = imgsz
+        self.img_size = img_size
         self.threshold = threshold
         self.pre_ops, self.infer, self.post_op = self._build()
 
@@ -100,10 +100,10 @@ class DetPredictor(BasicPredictor):
             if op:
                 pre_ops.append(op)
         pre_ops.append(self.build_to_batch())
-        if self.imgsz is not None:
+        if self.img_size is not None:
             if isinstance(pre_ops[1], Resize):
                 pre_ops.pop(1)
-            pre_ops.insert(1, self.build_resize(self.imgsz, False, 2))
+            pre_ops.insert(1, self.build_resize(self.img_size, False, 2))
 
         # build infer
         infer = StaticInfer(

+ 2 - 2
paddlex/inference/pipelines_new/keypoint_detection/pipeline.py

@@ -55,8 +55,8 @@ class KeypointDetectionPipeline(BasePipeline):
         model_kwargs = {}
         if "threshold" in model_cfg:
             model_kwargs["threshold"] = model_cfg["threshold"]
-        if "imgsz" in model_cfg:
-            model_kwargs["imgsz"] = model_cfg["imgsz"]
+        if "img_size" in model_cfg:
+            model_kwargs["img_size"] = model_cfg["img_size"]
         self.det_model = self.create_model(model_cfg, **model_kwargs)
 
         # create keypoint detection model

+ 2 - 2
paddlex/inference/pipelines_new/object_detection/pipeline.py

@@ -52,8 +52,8 @@ class ObjectDetectionPipeline(BasePipeline):
         model_kwargs = {}
         if "threshold" in model_cfg:
             model_kwargs["threshold"] = model_cfg["threshold"]
-        if "imgsz" in model_cfg:
-            model_kwargs["imgsz"] = model_cfg["imgsz"]
+        if "img_size" in model_cfg:
+            model_kwargs["img_size"] = model_cfg["img_size"]
         self.det_model = self.create_model(model_cfg, **model_kwargs)
 
     def predict(