
Support passing batch_size when creating a model (#2901)

Tingquan Gao, 10 months ago
commit 2b72192e97

+4 -0
paddlex/inference/models_new/base/predictor/basic_predictor.py

@@ -39,6 +39,7 @@ class BasicPredictor(
         model_dir: str,
         config: Dict[str, Any] = None,
         device: str = None,
+        batch_size: int = 1,
         pp_option: PaddlePredictorOption = None,
     ) -> None:
         """Initializes the BasicPredictor.
@@ -47,6 +48,7 @@ class BasicPredictor(
             model_dir (str): The directory where the model files are stored.
             config (Dict[str, Any], optional): The configuration dictionary. Defaults to None.
             device (str, optional): The device to run the inference engine on. Defaults to None.
+            batch_size (int, optional): The batch size used for prediction. Defaults to 1.
             pp_option (PaddlePredictorOption, optional): The inference engine options. Defaults to None.
         """
         super().__init__(model_dir=model_dir, config=config)
@@ -63,6 +65,8 @@ class BasicPredictor(
         if trt_dynamic_shapes:
             pp_option.trt_dynamic_shapes = trt_dynamic_shapes
         self.pp_option = pp_option
+        self.pp_option.batch_size = batch_size
+        self.batch_sampler.batch_size = batch_size
 
         logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
         self.benchmark = benchmark
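
For context, a minimal usage sketch of the new keyword follows. It assumes a create_model style factory (as referenced by the commit title) that forwards keyword arguments to BasicPredictor.__init__; the factory call, model name, and input path are illustrative assumptions, not part of this diff. The diff itself only adds the batch_size parameter and propagates it to self.pp_option and self.batch_sampler.

# Illustrative sketch only: ``create_model`` and the names below are assumptions
# based on the commit title; the diff only changes BasicPredictor.__init__.
from paddlex import create_model  # assumed factory that forwards kwargs to the predictor

# Pass the new ``batch_size`` keyword at model-creation time; per the hunk above it is
# forwarded to both the Paddle predictor option and the batch sampler.
model = create_model(
    "PP-LCNet_x1_0",   # hypothetical model name
    device="gpu:0",
    batch_size=8,      # newly supported keyword; defaults to 1
)

# Inputs are then grouped into batches of 8 before being fed to the inference engine.
for result in model.predict("images/"):
    print(result)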