浏览代码

support to pass model arguments when creating model in pipeline

gaotingquan 10 月之前
父节点
当前提交
4231b42bb6

+ 5 - 3
paddlex/inference/pipelines_new/base.py

@@ -67,12 +67,13 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         """
         raise NotImplementedError("The method `predict` has not been implemented yet.")
 
-    def create_model(self, config: Dict) -> BasePredictor:
+    def create_model(self, config: Dict, **kwargs) -> BasePredictor:
         """
         Create a model instance based on the given configuration.
 
         Args:
             config (Dict): A dictionary containing configuration settings.
+            **kwargs: The model arguments that needed to be pass.
 
         Returns:
             BasePredictor: An instance of the model.
@@ -82,14 +83,15 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         if model_dir == None:
             model_dir = config["model_name"]
 
-        from ...model import create_model
+        from .. import create_predictor
 
-        model = create_model(
+        model = create_predictor(
             model=model_dir,
             device=self.device,
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
             hpi_params=self.hpi_params,
+            **kwargs,
         )
 
         # [TODO] Support initializing with additional parameters

+ 6 - 4
paddlex/inference/pipelines_new/image_classification/pipeline.py

@@ -51,13 +51,15 @@ class ImageClassificationPipeline(BasePipeline):
         )
 
         image_classification_model_config = config["SubModules"]["ImageClassification"]
+        model_kwargs = {}
+        if (topk := image_classification_model_config.get("topk", None)) is not None:
+            model_kwargs = {"topk": topk}
         self.image_classification_model = self.create_model(
-            image_classification_model_config
+            image_classification_model_config, **model_kwargs
         )
-        self.topk = image_classification_model_config["topk"]
 
     def predict(
-        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+        self, input: str | list[str] | np.ndarray | list[np.ndarray], topk=None
     ) -> TopkResult:
         """Predicts image classification results for the given input.
 
@@ -68,4 +70,4 @@ class ImageClassificationPipeline(BasePipeline):
         Returns:
             TopkResult: The predicted top k results.
         """
-        yield from self.image_classification_model(input, topk=self.topk)
+        yield from self.image_classification_model(input, topk=topk)