Browse Source

Cli input param new feat (#2955)

* add fast infer for semantic seg; by zzl

* fix sth

* fix sth oh seg

* add platform judge

* add platform judge

* Re-trigger CI

* add inst seg; fix sth

* fix sth

* add docs

* fix sth

* add param for seg

* fix sth

* adapt seg rst to lower version of pillow

* fix sth

* Trigger CI

* fix pil

* fix pil

* add GDINO; add SAM

* Trigger CI

* add pipelines of seg, inst seg, sod, rod

* fix sth

* fix sth

* add ovd & ovs pipeline; add modules for ovd & ovs; fix sth

* add docs

* fix sth

* fix doc images urls

* fix module docs & rename PP-YOLOE-R-L

* fix typo

* fix pipeline param bugs

* allow multi-type inputs for one param
Zhang Zelun 9 months ago
parent
commit
62faefc3c4
1 changed files with 82 additions and 3 deletions
  1. 82 3
      paddlex/utils/pipeline_arguments.py

+ 82 - 3
paddlex/utils/pipeline_arguments.py

@@ -12,6 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ast import literal_eval
+from pydantic import TypeAdapter, ValidationError
+from functools import wraps
+from typing import Dict, List, Tuple, Union, Literal, Optional
+
+
+def custom_type(cli_expected_type):
+    """Create validator for CLI input conversion and type checking"""
+
+    def validator(cli_input: str) -> cli_expected_type:
+        try:
+            parsed = literal_eval(cli_input)
+        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as exc:
+            err = f"""Malformed input:
+            - Input: {cli_input!r}
+            - Error: {exc}"""
+            raise ValueError(err) from exc
+
+        try:
+            return TypeAdapter(cli_expected_type).validate_python(parsed)
+        except ValidationError as exc:
+            err = f"""Invalid input type:
+            - Expected: {cli_expected_type}
+            - Received: {cli_input!r}
+            """
+            raise ValueError(err) from exc
+
+    return validator
+
+
 PIPELINE_ARGUMENTS = {
     "OCR": [
         {
@@ -214,9 +244,27 @@ PIPELINE_ARGUMENTS = {
             "help": "Sets the layout merge bboxes mode for layout detection.",
         },
     ],
-    "instance_segmentation": None,
-    "semantic_segmentation": None,
-    "small_object_detection": None,
+    "instance_segmentation": [
+        {
+            "name": "--threshold",
+            "type": custom_type(Optional[float]),
+            "help": "Sets the threshold for instance segmentation.",
+        },
+    ],
+    "semantic_segmentation": [
+        {
+            "name": "--target_size",
+            "type": custom_type(Optional[Union[int, Tuple[int, int], Literal[-1]]]),
+            "help": "Sets the inference image resolution for semantic segmentation.",
+        },
+    ],
+    "small_object_detection": [
+        {
+            "name": "--threshold",
+            "type": custom_type(Optional[Union[float, dict[int, float]]]),
+            "help": "Sets the threshold for small object detection.",
+        },
+    ],
     "anomaly_detection": None,
     "video_classification": [
         {
@@ -249,4 +297,35 @@ PIPELINE_ARGUMENTS = {
             "help": "Determines whether to use document unwarping.",
         },
     ],
+    "rotated_object_detection": [
+        {
+            "name": "--threshold",
+            "type": custom_type(Optional[Union[float, dict[int, float]]]),
+            "help": "Sets the threshold for rotated object detection.",
+        },
+    ],
+    "open_vocabulary_detection": [
+        {
+            "name": "--thresholds",
+            "type": custom_type(dict[str, float]),
+            "help": "Sets the thresholds for open vocabulary detection.",
+        },
+        {
+            "name": "--prompt",
+            "type": str,
+            "help": "Sets the prompt for open vocabulary detection.",
+        },
+    ],
+    "open_vocabulary_segmentation": [
+        {
+            "name": "--prompt_type",
+            "type": str,
+            "help": "Sets the prompt type for open vocabulary segmentation.",
+        },
+        {
+            "name": "--prompt",
+            "type": custom_type(list[list[float]]),
+            "help": "Sets the prompt for open vocabulary segmentation.",
+        },
+    ],
 }