Эх сурвалжийг харах

support to specify inference model directory in CLI

gaotingquan 1 жил өмнө
parent
commit
99ff4d31f5

+ 13 - 6
paddlex/paddlex_cli.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import argparse
 import textwrap
@@ -32,6 +31,11 @@ def args_cfg():
         """
         return v.lower() in ("true", "t", "1")
 
+    def str2None(s):
+        """convert to None type if it is "None"
+        """
+        return None if s.lower() == 'none' else s
+
     parser = argparse.ArgumentParser()
 
     ################# install pdx #################
@@ -52,6 +56,7 @@ def args_cfg():
     parser.add_argument('--predict', action='store_true', default=True, help="")
     parser.add_argument('--pipeline', type=str, help="")
     parser.add_argument('--model', nargs='+', help="")
+    parser.add_argument('--model_dir', nargs='+', type=str2None, help="")
     parser.add_argument('--input', type=str, help="")
     parser.add_argument('--output', type=str, default="./", help="")
     parser.add_argument('--device', type=str, default='gpu:0', help="")
@@ -79,10 +84,12 @@ def install(args):
     return
 
 
-def pipeline_predict(pipeline, model_name_list, input_path, output, device):
+def pipeline_predict(pipeline, model_name_list, model_dir_list, input_path,
+                     output, device):
     """pipeline predict
     """
-    pipeline = build_pipeline(pipeline, model_name_list, output, device)
+    pipeline = build_pipeline(pipeline, model_name_list, model_dir_list, output,
+                              device)
     pipeline.predict({"input_path": input_path})
 
 
@@ -94,5 +101,5 @@ def main():
     if args.install:
         install(args)
     else:
-        return pipeline_predict(args.pipeline, args.model, args.input,
-                                args.output, args.device)
+        return pipeline_predict(args.pipeline, args.model, args.model_dir,
+                                args.input, args.output, args.device)

+ 8 - 4
paddlex/pipelines/OCR/pipeline.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import cv2
 
@@ -110,15 +109,20 @@ Only support: {text_rec_models}."
 
         return result
 
-    def update_model_name(self, model_name_list):
-        """update model name and re
+    def update_model(self, model_name_list, model_dir_list):
+        """update model
 
         Args:
-            model_list (list): list of model name.
+            model_name_list (list): list of model name.
+            model_dir_list (list): list of model directory.
         """
         assert len(model_name_list) == 2
         self.text_det_model_name = model_name_list[0]
         self.text_rec_model_name = model_name_list[1]
+        if model_dir_list:
+            assert len(model_dir_list) == 2
+            self.text_det_model_dir = model_dir_list[0]
+            self.text_rec_model_dir = model_dir_list[1]
 
     def get_kernel_option(self):
         """get kernel option

+ 6 - 4
paddlex/pipelines/base/pipeline.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from abc import ABC, abstractmethod
 
 from ...utils.misc import AutoRegisterABCMetaClass
@@ -21,6 +20,7 @@ from ...utils.misc import AutoRegisterABCMetaClass
 def build_pipeline(
         pipeline_name: str,
         model_list: list,
+        model_dir_list: list,
         output: str,
         device: str, ) -> "BasePipeline":
     """build model evaluater
@@ -33,6 +33,7 @@ def build_pipeline(
     """
     pipeline = BasePipeline.get(pipeline_name)(output=output, device=device)
     pipeline.update_model_name(model_list)
+    pipeline.update_model(model_list, model_dir_list)
     pipeline.load_model()
     return pipeline
 
@@ -52,11 +53,12 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         raise NotImplementedError
 
     @abstractmethod
-    def update_model_name(self, model_list: list) -> dict:
-        """update model name and re
+    def update_model(self, model_name_list, model_dir_list):
+        """update model
 
         Args:
-            model_list (list): list of model name.
+            model_name_list (list): list of model name.
+            model_dir_list (list): list of model directory.
         """
         raise NotImplementedError
 

+ 7 - 4
paddlex/pipelines/image_classification/pipeline.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from ..base import BasePipeline
 from ...modules.image_classification.model_list import MODELS
 from ...modules import create_model, PaddleInferenceOption
@@ -69,14 +68,18 @@ class ClsPipeline(BasePipeline):
         kernel_option.set_device(self.device)
         return kernel_option
 
-    def update_model_name(self, model_name_list):
-        """update model name and re
+    def update_model(self, model_name_list, model_dir_list):
+        """update model
 
         Args:
-            model_list (list): list of model name.
+            model_name_list (list): list of model name.
+            model_dir_list (list): list of model directory.
         """
         assert len(model_name_list) == 1
         self.model_name = model_name_list[0]
+        if model_dir_list:
+            assert len(model_dir_list) == 1
+            self.model_dir = model_dir_list[0]
 
     def get_input_keys(self):
         """get dict keys of input argument input

+ 7 - 4
paddlex/pipelines/instance_segmentation/pipeline.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from ..base import BasePipeline
 from ...modules.instance_segmentation.model_list import MODELS
 from ...modules import create_model, PaddleInferenceOption
@@ -68,14 +67,18 @@ class InstanceSegPipeline(BasePipeline):
         kernel_option.set_device(self.device)
         return kernel_option
 
-    def update_model_name(self, model_name_list):
-        """update model name and re
+    def update_model(self, model_name_list, model_dir_list):
+        """update model
 
         Args:
-            model_list (list): list of model name.
+            model_name_list (list): list of model name.
+            model_dir_list (list): list of model directory.
         """
         assert len(model_name_list) == 1
         self.model_name = model_name_list[0]
+        if model_dir_list:
+            assert len(model_dir_list) == 1
+            self.model_dir = model_dir_list[0]
 
     def get_input_keys(self):
         """get dict keys of input argument input

+ 7 - 4
paddlex/pipelines/object_detection/pipeline.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from ..base import BasePipeline
 from ...modules.object_detection.model_list import MODELS
 from ...modules import create_model, PaddleInferenceOption
@@ -67,14 +66,18 @@ class DetPipeline(BasePipeline):
         kernel_option = PaddleInferenceOption()
         kernel_option.set_device(self.device)
 
-    def update_model_name(self, model_name_list):
-        """update model name and re
+    def update_model(self, model_name_list, model_dir_list):
+        """update model
 
         Args:
-            model_list (list): list of model name.
+            model_name_list (list): list of model name.
+            model_dir_list (list): list of model directory.
         """
         assert len(model_name_list) == 1
         self.model_name = model_name_list[0]
+        if model_dir_list:
+            assert len(model_dir_list) == 1
+            self.model_dir = model_dir_list[0]
 
     def get_input_keys(self):
         """get dict keys of input argument input

+ 7 - 4
paddlex/pipelines/semantic_segmentation/pipeline.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from ..base import BasePipeline
 from ...modules.semantic_segmentation.model_list import MODELS
 from ...modules import create_model, PaddleInferenceOption
@@ -68,14 +67,18 @@ class SegPipeline(BasePipeline):
         kernel_option.set_device(self.device)
         return kernel_option
 
-    def update_model_name(self, model_name_list):
-        """update model name and re
+    def update_model(self, model_name_list, model_dir_list):
+        """update model
 
         Args:
-            model_list (list): list of model name.
+            model_name_list (list): list of model name.
+            model_dir_list (list): list of model directory.
         """
         assert len(model_name_list) == 1
         self.model_name = model_name_list[0]
+        if model_dir_list:
+            assert len(model_dir_list) == 1
+            self.model_dir = model_dir_list[0]
 
     def get_input_keys(self):
         """get dict keys of input argument input