gaotingquan 1 年之前
父节点
当前提交
cafe2082e2

+ 1 - 1
paddlex/inference/components/paddle_predictor/predictor.py

@@ -125,7 +125,7 @@ class BasePaddlePredictor(BaseComponent):
                 max_batch_size=self.option.batch_size,
                 min_subgraph_size=self.option.min_subgraph_size,
                 precision_mode=precision_map[self.option.run_mode],
-                trt_use_static=self.option.trt_use_static,
+                use_static=self.option.trt_use_static,
                 use_calib_mode=self.option.trt_calib_mode,
             )
 

+ 0 - 1
paddlex/inference/models/base/base_predictor.py

@@ -18,7 +18,6 @@ from pathlib import Path
 from abc import abstractmethod
 
 from ...components.base import BaseComponent
-from ...utils.pp_option import PaddlePredictorOption
 from ...utils.process_hook import generatorable_method
 
 

+ 2 - 2
paddlex/inference/models/base/basic_predictor.py

@@ -72,12 +72,12 @@ class BasicPredictor(
     def set_predictor(self, batch_size=None, device=None, pp_option=None):
         if batch_size:
             self.components["ReadCmp"].batch_size = batch_size
+
+            self.pp_option.batch_size = batch_size
         if device and device != self.pp_option.device:
             self.pp_option.device = device
-            self.components["PPEngineCmp"].reset()
         if pp_option and pp_option != self.pp_option:
             self.pp_option = pp_option
-            self.components["PPEngineCmp"].reset()
 
     def _has_setter(self, attr):
         prop = getattr(self.__class__, attr, None)

+ 10 - 1
paddlex/inference/utils/pp_option.py

@@ -34,8 +34,8 @@ class PaddlePredictorOption(object):
         super().__init__()
         self.model_name = model_name
         self._cfg = {}
-        self._init_option(**kwargs)
         self._observers = []
+        self._init_option(**kwargs)
 
     def _init_option(self, **kwargs):
         for k, v in kwargs.items():
@@ -62,6 +62,7 @@ class PaddlePredictorOption(object):
             "trt_use_static": False,
             "delete_pass": [],
             "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
+            "batch_size": 1,  # only for trt
         }
 
     def _update(self, k, v):
@@ -179,6 +180,14 @@ class PaddlePredictorOption(object):
         """set run mode"""
         self._update("enable_new_ir", enable_new_ir)
 
+    @property
+    def batch_size(self):
+        return self._cfg["batch_size"]
+
+    @batch_size.setter
+    def batch_size(self, batch_size):
+        self._update("batch_size", batch_size)
+
     def get_support_run_mode(self):
         """get supported run mode"""
         return self.SUPPORT_RUN_MODE