Browse Source

change default device to None in CLI

gaotingquan 1 year ago
parent
commit
852e59ed12

+ 2 - 2
paddlex/inference/components/utils/mixin.py

@@ -44,10 +44,10 @@ class PPEngineMixin:
 
     @option.setter
     def option(self, value):
-        if value != self.option:
+        if value is not None and value != self.option:
             self._option = value
             self._reset()
 
     @abstractmethod
     def _reset(self):
-        raise NotImplementedError
+        raise NotImplementedError

+ 9 - 6
paddlex/inference/pipelines/__init__.py

@@ -63,12 +63,15 @@ def create_pipeline(
     pipeline_name = config["Global"]["pipeline_name"]
     pipeline_setting = config["Pipeline"]
     pipeline_setting.update(kwargs)
+    if device:
+        pipeline_setting["device"] = device
+    if pp_option:
+        pipeline_setting["pp_option"] = pp_option
+    if use_hpip:
+        pipeline_setting["use_hpip"] = use_hpip
 
-    predictor_kwargs = {"device": device, "pp_option": pp_option, "use_hpip": use_hpip}
-    if hpi_params is not None:
-        predictor_kwargs["hpi_params"] = hpi_params
+    if hpi_params:
+        pipeline_setting["hpi_params"] = hpi_params
 
-    pipeline = BasePipeline.get(pipeline_name)(
-        predictor_kwargs=predictor_kwargs, *args, **pipeline_setting
-    )
+    pipeline = BasePipeline.get(pipeline_name)(*args, **pipeline_setting)
     return pipeline

+ 1 - 3
paddlex/inference/pipelines/base.py

@@ -24,10 +24,8 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     __is_base = True
 
-    def __init__(self, predictor_kwargs: Optional[Dict[str, Any]]) -> None:
+    def __init__(self, **predictor_kwargs) -> None:
         super().__init__()
-        if predictor_kwargs is None:
-            predictor_kwargs = {}
         self._predictor_kwargs = predictor_kwargs
 
     # alias the __call__() to predict()

+ 2 - 8
paddlex/inference/pipelines/ocr.py

@@ -22,14 +22,8 @@ class OCRPipeline(BasePipeline):
 
     entities = "OCR"
 
-    def __init__(
-        self,
-        det_model,
-        rec_model,
-        batch_size=1,
-        predictor_kwargs=None,
-    ):
-        super().__init__(predictor_kwargs)
+    def __init__(self, det_model, rec_model, batch_size=1, **predictor_kwargs):
+        super().__init__(**predictor_kwargs)
         self._build_predictor(det_model, rec_model)
         self.set_predictor(batch_size)
 

+ 2 - 2
paddlex/inference/pipelines/single_model_pipeline.py

@@ -17,8 +17,8 @@ from .base import BasePipeline
 
 class _SingleModelPipeline(BasePipeline):
 
-    def __init__(self, model, batch_size=1, predictor_kwargs=None):
-        super().__init__(predictor_kwargs)
+    def __init__(self, model, batch_size=1, **predictor_kwargs):
+        super().__init__(**predictor_kwargs)
         self._build_predictor(model)
         self.set_predictor(batch_size)
 

+ 2 - 2
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -34,9 +34,9 @@ class TableRecPipeline(BasePipeline):
         layout_batch_size=1,
         text_rec_batch_size=1,
         table_batch_size=1,
-        predictor_kwargs=None,
+        **predictor_kwargs,
     ):
-        super().__init__(predictor_kwargs)
+        super().__init__(**predictor_kwargs)
         self._build_predictor(layout_model, text_det_model, text_rec_model, table_model)
         self.set_predictor(layout_batch_size, text_rec_batch_size, table_batch_size)
 

+ 1 - 0
paddlex/inference/utils/pp_option.py

@@ -77,6 +77,7 @@ class PaddlePredictorOption(object):
             )
         self._cfg["run_mode"] = run_mode
 
+    # TODO(gaotingquan): setter
     @register("device")
     def set_device(self, device: str):
         """set device"""

+ 1 - 1
paddlex/paddlex_cli.py

@@ -66,7 +66,7 @@ def args_cfg():
     parser.add_argument("--model_dir", nargs="+", type=parse_str, help="")
     parser.add_argument("--input", type=str, help="")
     parser.add_argument("--save_dir", type=str, default="./", help="")
-    parser.add_argument("--device", type=str, default="gpu:0", help="")
+    parser.add_argument("--device", type=str, default=None, help="")
 
     return parser.parse_args()
 

+ 1 - 0
paddlex/pipelines/image_classification.yaml

@@ -7,6 +7,7 @@ Global:
 
 Pipeline:
   model: PP-LCNet_x0_5
+  device: cpu
 
 ######################################## Support ########################################
 NOTE: