Ver Fonte

rm kwargs to set predictor

gaotingquan há 1 ano atrás
pai
commit
f54e835f37

+ 4 - 3
paddlex/inference/models/__init__.py

@@ -14,8 +14,9 @@
 
 
 from pathlib import Path
-from ..utils.official_models import official_models
+from typing import Any, Dict, Optional
 
+from ..utils.official_models import official_models
 from .base import BasePredictor, BasicPredictor
 from .image_classification import ClasPredictor
 from .text_detection import TextDetPredictor
@@ -64,8 +65,8 @@ def create_predictor(
     model: str,
     device=None,
     pp_option=None,
-    use_hpip=False,
-    hpi_params=None,
+    use_hpip: bool = False,
+    hpi_params: Optional[Dict[str, Any]] = None,
     *args,
     **kwargs,
 ) -> BasePredictor:

+ 2 - 4
paddlex/inference/models/base/basic_predictor.py

@@ -34,9 +34,7 @@ class BasicPredictor(
 
     __is_base = True
 
-    def __init__(
-        self, model_dir, config=None, device=None, pp_option=None, **option_kwargs
-    ):
+    def __init__(self, model_dir, config=None, device=None, pp_option=None):
         super().__init__(model_dir=model_dir, config=config)
         self._pred_set_func_map = {}
         self._pred_set_register = FuncRegister(self._pred_set_func_map)
@@ -47,7 +45,7 @@ class BasicPredictor(
         self.pp_option = (
             pp_option
             if pp_option
-            else PaddlePredictorOption(model_name=self.model_name, **option_kwargs)
+            else PaddlePredictorOption(model_name=self.model_name)
         )
         self.pp_option.set_device(device)
         self.components = {}

+ 19 - 10
paddlex/inference/pipelines/__init__.py

@@ -62,16 +62,25 @@ def create_pipeline(
     config = parse_config(pipeline)
     pipeline_name = config["Global"]["pipeline_name"]
     pipeline_setting = config["Pipeline"]
-    pipeline_setting.update(kwargs)
-    if device:
-        pipeline_setting["device"] = device
-    if pp_option:
-        pipeline_setting["pp_option"] = pp_option
-    if use_hpip:
-        pipeline_setting["use_hpip"] = use_hpip
 
-    if hpi_params:
-        pipeline_setting["hpi_params"] = hpi_params
+    predictor_kwargs = {"use_hpip": use_hpip}
+    if "use_hpip" in pipeline_setting:
+        predictor_kwargs["use_hpip"] = use_hpip
+    if hpi_params is not None:
+        predictor_kwargs["hpi_params"] = hpi_params
+    elif "hpi_params" in pipeline_setting:
+        predictor_kwargs["hpi_params"] = pipeline_setting.pop("hpi_params")
+    if device is not None:
+        predictor_kwargs["device"] = device
+    elif "device" in pipeline_setting:
+        predictor_kwargs["device"] = pipeline_setting.pop("device")
+    if pp_option is not None:
+        predictor_kwargs["pp_option"] = pp_option
+    elif "pp_option" in pipeline_setting:
+        predictor_kwargs["pp_option"] = pipeline_setting.pop("pp_option")
 
-    pipeline = BasePipeline.get(pipeline_name)(*args, **pipeline_setting)
+    pipeline_setting.update(kwargs)
+    pipeline = BasePipeline.get(pipeline_name)(
+        predictor_kwargs=predictor_kwargs, *args, **pipeline_setting
+    )
     return pipeline

+ 2 - 2
paddlex/inference/pipelines/base.py

@@ -24,9 +24,9 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     __is_base = True
 
-    def __init__(self, **predictor_kwargs) -> None:
+    def __init__(self, predictor_kwargs) -> None:
         super().__init__()
-        self._predictor_kwargs = predictor_kwargs
+        self._predictor_kwargs = {} if predictor_kwargs is None else predictor_kwargs
 
     # alias the __call__() to predict()
     def __call__(self, *args, **kwargs):

+ 2 - 2
paddlex/inference/pipelines/ocr.py

@@ -22,8 +22,8 @@ class OCRPipeline(BasePipeline):
 
     entities = "OCR"
 
-    def __init__(self, det_model, rec_model, batch_size=1, **predictor_kwargs):
-        super().__init__(**predictor_kwargs)
+    def __init__(self, det_model, rec_model, batch_size=1, predictor_kwargs=None):
+        super().__init__(predictor_kwargs=predictor_kwargs)
         self._build_predictor(det_model, rec_model)
         self.set_predictor(batch_size)
 

+ 2 - 2
paddlex/inference/pipelines/single_model_pipeline.py

@@ -17,8 +17,8 @@ from .base import BasePipeline
 
 class _SingleModelPipeline(BasePipeline):
 
-    def __init__(self, model, batch_size=1, **predictor_kwargs):
-        super().__init__(**predictor_kwargs)
+    def __init__(self, model, batch_size=1, predictor_kwargs=None):
+        super().__init__(predictor_kwargs=predictor_kwargs)
         self._build_predictor(model)
         self.set_predictor(batch_size)
 

+ 11 - 4
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -34,14 +34,21 @@ class TableRecPipeline(BasePipeline):
         layout_batch_size=1,
         text_rec_batch_size=1,
         table_batch_size=1,
-        **predictor_kwargs,
+        predictor_kwargs=None,
     ):
-        super().__init__(**predictor_kwargs)
-        self._build_predictor(layout_model, text_det_model, text_rec_model, table_model)
+        super().__init__(predictor_kwargs=predictor_kwargs)
+        self._build_predictor(
+            layout_model, text_det_model, text_rec_model, table_model, predictor_kwargs
+        )
         self.set_predictor(layout_batch_size, text_rec_batch_size, table_batch_size)
 
     def _build_predictor(
-        self, layout_model, text_det_model, text_rec_model, table_model
+        self,
+        layout_model,
+        text_det_model,
+        text_rec_model,
+        table_model,
+        predictor_kwargs,
     ):
         self.layout_predictor = self._create_model(model=layout_model)
         self.ocr_pipeline = OCRPipeline(

+ 3 - 1
paddlex/pipelines/table_recognition.yaml

@@ -10,7 +10,9 @@ Pipeline:
   table_model: SLANet
   text_det_model: PP-OCRv4_mobile_det
   text_rec_model: PP-OCRv4_mobile_rec
-  batch_size: 1
+  layout_batch_size: 1
+  text_rec_batch_size: 1
+  table_batch_size: 1
   device: "gpu:0"
 
 ######################################## Support ########################################