Ver Fonte

separate Paddle Predictor to copy2gpu, infer, copy2cpu

gaotingquan há 1 ano atrás
pai
commit
b73c5036e3

+ 31 - 27
docs/module_usage/instructions/benchmark.md

@@ -26,40 +26,44 @@ python main.py \
 在开启 Benchmark 后,将自动打印 benchmark 指标:
 
 ```
-+-------------------+-----------------+------+---------------+
-|       Stage       | Total Time (ms) | Nums | Avg Time (ms) |
-+-------------------+-----------------+------+---------------+
-|      ReadCmp      |   49.95107651   |  10  |   4.99510765  |
-|       Resize      |    8.48054886   |  10  |   0.84805489  |
-|     Normalize     |   23.08964729   |  10  |   2.30896473  |
-|     ToCHWImage    |    0.02717972   |  10  |   0.00271797  |
-| ImageDetPredictor |   75.94108582   |  10  |   7.59410858  |
-|   DetPostProcess  |    0.26535988   |  10  |   0.02653599  |
-+-------------------+-----------------+------+---------------+
++----------------+-----------------+------+---------------+
+|     Stage      | Total Time (ms) | Nums | Avg Time (ms) |
++----------------+-----------------+------+---------------+
+|    ReadCmp     |   185.48870087  |  10  |  18.54887009  |
+|     Resize     |   16.95227623   |  30  |   0.56507587  |
+|   Normalize    |   41.12100601   |  30  |   1.37070020  |
+|   ToCHWImage   |    0.05745888   |  30  |   0.00191530  |
+|    Copy2GPU    |   14.58549500   |  10  |   1.45854950  |
+|     Infer      |   100.14462471  |  10  |  10.01446247  |
+|    Copy2CPU    |    9.54508781   |  10  |   0.95450878  |
+| DetPostProcess |    0.56767464   |  30  |   0.01892249  |
++----------------+-----------------+------+---------------+
 +-------------+-----------------+------+---------------+
 |    Stage    | Total Time (ms) | Nums | Avg Time (ms) |
 +-------------+-----------------+------+---------------+
-|  PreProcess |   81.54845238   |  10  |   8.15484524  |
-|  Inference  |   75.94108582   |  10  |   7.59410858  |
-| PostProcess |    0.26535988   |  10  |   0.02653599  |
-|   End2End   |   161.07797623  |  10  |  16.10779762  |
-|    WarmUp   |  5496.41847610  |  5   | 1099.28369522 |
+|  PreProcess |   243.61944199  |  30  |   8.12064807  |
+|  Inference  |   124.27520752  |  30  |   4.14250692  |
+| PostProcess |    0.56767464   |  30  |   0.01892249  |
+|   End2End   |   379.70948219  |  30  |  12.65698274  |
+|    WarmUp   |  9465.68179131  |  5   | 1893.13635826 |
 +-------------+-----------------+------+---------------+
 ```
 
-在 Benchmark 结果中,会统计该模型全部组件(`Component`)的总耗时(`Total Time`,单位为“毫秒”)、调用次数(`Nums`)、调用平均执行耗时(`Avg Time`,单位为“毫秒”),以及按预热(`WarmUp`)、预处理(`PreProcess`)、模型推理(`Inference`)、后处理(`PostProcess`)和端到端(`End2End`)进行划分的耗时统计,包括每个阶段的总耗时(`Total Time`,单位为“毫秒”)、样本数(`Nums`)和单样本平均执行耗时(`Avg Time`,单位为“毫秒”),同时,保存相关指标会到本地 `./benchmark.csv` 文件中:
+在 Benchmark 结果中,会统计该模型全部组件(`Component`)的总耗时(`Total Time`,单位为“毫秒”)、**调用次数**(`Nums`)、**调用**平均执行耗时(`Avg Time`,单位为“毫秒”),以及按预热(`WarmUp`)、预处理(`PreProcess`)、模型推理(`Inference`)、后处理(`PostProcess`)和端到端(`End2End`)进行划分的耗时统计,包括每个阶段的总耗时(`Total Time`,单位为“毫秒”)、**样本数**(`Nums`)和**单样本**平均执行耗时(`Avg Time`,单位为“毫秒”),同时,保存相关指标会到本地 `./benchmark.csv` 文件中:
 
 ```csv
 Stage,Total Time (ms),Nums,Avg Time (ms)
-ReadCmp,0.04995107650756836,10,0.004995107650756836
-Resize,0.008480548858642578,10,0.0008480548858642578
-Normalize,0.02308964729309082,10,0.002308964729309082
-ToCHWImage,2.7179718017578125e-05,10,2.7179718017578126e-06
-ImageDetPredictor,0.07594108581542969,10,0.007594108581542969
-DetPostProcess,0.00026535987854003906,10,2.6535987854003906e-05
-PreProcess,0.08154845237731934,10,0.008154845237731934
-Inference,0.07594108581542969,10,0.007594108581542969
-PostProcess,0.00026535987854003906,10,2.6535987854003906e-05
-End2End,0.16107797622680664,10,0.016107797622680664
-WarmUp,5.496418476104736,5,1.0992836952209473
+ReadCmp,0.18548870086669922,10,0.018548870086669923
+Resize,0.0169522762298584,30,0.0005650758743286133
+Normalize,0.04112100601196289,30,0.001370700200398763
+ToCHWImage,5.745887756347656e-05,30,1.915295918782552e-06
+Copy2GPU,0.014585494995117188,10,0.0014585494995117188
+Infer,0.10014462471008301,10,0.0100144624710083
+Copy2CPU,0.009545087814331055,10,0.0009545087814331055
+DetPostProcess,0.0005676746368408203,30,1.892248789469401e-05
+PreProcess,0.24361944198608398,30,0.0081206480662028
+Inference,0.12427520751953125,30,0.0041425069173177086
+PostProcess,0.0005676746368408203,30,1.892248789469401e-05
+End2End,0.37970948219299316,30,0.012656982739766438
+WarmUp,9.465681791305542,5,1.8931363582611085
 ```

+ 1 - 1
paddlex/configs/object_detection/PicoDet-S.yaml

@@ -33,7 +33,7 @@ Export:
   weight_path: https://paddledet.bj.bcebos.com/models/picodet_s_320_coco_lcnet.pdparams
 
 Predict:
-  batch_size: 1
+  batch_size: 3
   model_dir: "output/best_model/inference"
   input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
   kernel_option:

+ 7 - 0
paddlex/inference/components/base.py

@@ -107,6 +107,9 @@ class BaseComponent(ABC):
                         f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"
                     )
 
+        if self.inputs is None:
+            return [({}, None)]
+
         if self.need_batch_input:
             args = {}
             for input_ in input_list:
@@ -266,6 +269,10 @@ class BaseComponent(ABC):
     def name(self):
         return getattr(self, "NAME", self.__class__.__name__)
 
+    @property
+    def sub_cmps(self):
+        return None
+
     @abstractmethod
     def apply(self, input):
         raise NotImplementedError

+ 69 - 24
paddlex/inference/components/paddle_predictor/predictor.py

@@ -23,6 +23,42 @@ from ...utils.pp_option import PaddlePredictorOption
 from ..base import BaseComponent
 
 
+class Copy2GPU(BaseComponent):
+
+    def __init__(self, input_handlers):
+        super().__init__()
+        self.input_handlers = input_handlers
+
+    def apply(self, x):
+        for idx in range(len(x)):
+            self.input_handlers[idx].reshape(x[idx].shape)
+            self.input_handlers[idx].copy_from_cpu(x[idx])
+
+
+class Copy2CPU(BaseComponent):
+
+    def __init__(self, output_handlers):
+        super().__init__()
+        self.output_handlers = output_handlers
+
+    def apply(self):
+        output = []
+        for out_tensor in self.output_handlers:
+            batch = out_tensor.copy_to_cpu()
+            output.append(batch)
+        return output
+
+
+class Infer(BaseComponent):
+
+    def __init__(self, predictor):
+        super().__init__()
+        self.predictor = predictor
+
+    def apply(self):
+        self.predictor.run()
+
+
 class BasePaddlePredictor(BaseComponent):
     """Predictor based on Paddle Inference"""
 
@@ -56,12 +92,13 @@ class BasePaddlePredictor(BaseComponent):
             self.option = PaddlePredictorOption()
         logging.debug(f"Env: {self.option}")
         (
-            self.predictor,
-            self.inference_config,
-            self.input_names,
-            self.input_handlers,
-            self.output_handlers,
+            predictor,
+            input_handlers,
+            output_handlers,
         ) = self._create()
+        self.copy2gpu = Copy2GPU(input_handlers)
+        self.copy2cpu = Copy2CPU(output_handlers)
+        self.infer = Infer(predictor)
         self.option.changed = False
 
     def _create(self):
@@ -169,43 +206,46 @@ class BasePaddlePredictor(BaseComponent):
         for output_name in output_names:
             output_handler = predictor.get_output_handle(output_name)
             output_handlers.append(output_handler)
-        return predictor, config, input_names, input_handlers, output_handlers
-
-    def get_input_names(self):
-        """get input names"""
-        return self.input_names
+        return predictor, input_handlers, output_handlers
 
     def apply(self, **kwargs):
         if self.option.changed:
             self._reset()
-        x = self.to_batch(**kwargs)
-        for idx in range(len(x)):
-            self.input_handlers[idx].reshape(x[idx].shape)
-            self.input_handlers[idx].copy_from_cpu(x[idx])
-
-        self.predictor.run()
-        output = []
-        for out_tensor in self.output_handlers:
-            batch = out_tensor.copy_to_cpu()
-            output.append(batch)
-        return self.format_output(output)
+        batches = self.to_batch(**kwargs)
+        self.copy2gpu.apply(batches)
+        self.infer.apply()
+        pred = self.copy2cpu.apply()
+        return self.format_output(pred)
 
-    def format_output(self, pred):
-        return [{"pred": res} for res in zip(*pred)]
+    @property
+    def sub_cmps(self):
+        return {
+            "Copy2GPU": self.copy2gpu,
+            "Infer": self.infer,
+            "Copy2CPU": self.copy2cpu,
+        }
 
     @abstractmethod
     def to_batch(self):
         raise NotImplementedError
 
+    @abstractmethod
+    def format_output(self, pred):
+        return [{"pred": res} for res in zip(*pred)]
 
-class ImagePredictor(BasePaddlePredictor):
 
+class ImagePredictor(BasePaddlePredictor):
     INPUT_KEYS = "img"
+    OUTPUT_KEYS = "pred"
     DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"pred": "pred"}
 
     def to_batch(self, img):
         return [np.stack(img, axis=0).astype(dtype=np.float32, copy=False)]
 
+    def format_output(self, pred):
+        return [{"pred": res} for res in zip(*pred)]
+
 
 class ImageDetPredictor(BasePaddlePredictor):
 
@@ -276,9 +316,14 @@ class ImageDetPredictor(BasePaddlePredictor):
 class TSPPPredictor(BasePaddlePredictor):
 
     INPUT_KEYS = "ts"
+    OUTPUT_KEYS = "pred"
     DEAULT_INPUTS = {"ts": "ts"}
+    DEAULT_OUTPUTS = {"pred": "pred"}
 
     def to_batch(self, ts):
         n = len(ts[0])
         x = [np.stack([lst[i] for lst in ts], axis=0) for i in range(n)]
         return x
+
+    def format_output(self, pred):
+        return [{"pred": res} for res in zip(*pred)]

+ 23 - 10
paddlex/inference/utils/benchmark.py

@@ -42,11 +42,18 @@ class Benchmark:
         self._reset()
 
     def _reset(self):
-        for name in self._components:
-            cmp = self._components[name]
+        for name, cmp in self.iterate_cmp(self._components):
             cmp.timer.reset()
         self._e2e_tic = time.time()
 
+    def iterate_cmp(self, cmps):
+        if cmps is None:
+            return
+        for name, cmp in cmps.items():
+            if cmp.sub_cmps is not None:
+                yield from self.iterate_cmp(cmp.sub_cmps)
+            yield name, cmp
+
     def gather(self, e2e_num):
         # lazy import for avoiding circular import
         from ..components.paddle_predictor import BasePaddlePredictor
@@ -54,17 +61,23 @@ class Benchmark:
         detail = []
         summary = {"preprocess": 0, "inference": 0, "postprocess": 0}
         op_tag = "preprocess"
-        for name in self._components:
-            cmp = self._components[name]
-            times = cmp.timer.logs
-            counts = len(times)
-            avg = np.mean(times)
-            total = np.sum(times)
-            detail.append((name, total, counts, avg))
+        for name, cmp in self._components.items():
             if isinstance(cmp, BasePaddlePredictor):
-                summary["inference"] += total
+                # TODO(gaotingquan): show by hierarchy. Now dont show xxxPredictor benchmark info to ensure mutual exclusivity between components.
+                for name, sub_cmp in cmp.sub_cmps.items():
+                    times = sub_cmp.timer.logs
+                    counts = len(times)
+                    avg = np.mean(times)
+                    total = np.sum(times)
+                    detail.append((name, total, counts, avg))
+                    summary["inference"] += total
                 op_tag = "postprocess"
             else:
+                times = cmp.timer.logs
+                counts = len(times)
+                avg = np.mean(times)
+                total = np.sum(times)
+                detail.append((name, total, counts, avg))
                 summary[op_tag] += total
 
         summary = [