|
|
@@ -39,6 +39,7 @@ class BasicPredictor(
|
|
|
model_dir: str,
|
|
|
config: Dict[str, Any] = None,
|
|
|
device: str = None,
|
|
|
+ batch_size: int = 1,
|
|
|
pp_option: PaddlePredictorOption = None,
|
|
|
) -> None:
|
|
|
"""Initializes the BasicPredictor.
|
|
|
@@ -47,6 +48,7 @@ class BasicPredictor(
|
|
|
model_dir (str): The directory where the model files are stored.
|
|
|
config (Dict[str, Any], optional): The configuration dictionary. Defaults to None.
|
|
|
device (str, optional): The device to run the inference engine on. Defaults to None.
|
|
|
+ batch_size (int, optional): The batch size to predict. Defaults to 1.
|
|
|
pp_option (PaddlePredictorOption, optional): The inference engine options. Defaults to None.
|
|
|
"""
|
|
|
super().__init__(model_dir=model_dir, config=config)
|
|
|
@@ -63,6 +65,8 @@ class BasicPredictor(
|
|
|
if trt_dynamic_shapes:
|
|
|
pp_option.trt_dynamic_shapes = trt_dynamic_shapes
|
|
|
self.pp_option = pp_option
|
|
|
+ self.pp_option.batch_size = batch_size
|
|
|
+ self.batch_sampler.batch_size = batch_size
|
|
|
|
|
|
logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
|
|
|
self.benchmark = benchmark
|