
update Paddle Predictor

gaotingquan 10 months ago
parent
commit
f1ba03b2e9

+ 2 - 1
paddlex/configs/modules/text_recognition/PP-OCRv4_mobile_rec.yaml

@@ -36,4 +36,5 @@ Predict:
   model_dir: "output/best_accuracy/inference"
   input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_001.png"
   kernel_option:
-    run_mode: paddle
+    run_mode: trt_fp32
+    enable_new_ir: True
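The `kernel_option` block above is consumed by `PaddlePredictorOption` and lowered into Paddle Inference `Config` calls in `static_infer.py` (see the diff below). As a minimal sketch, here is the rough standalone equivalent of `run_mode: trt_fp32` with `enable_new_ir: True`, using plain `paddle.inference` rather than the repo's `lazy_paddle` wrapper; the model paths, batch size, and subgraph size are placeholders, not values from this commit:

    # Hedged sketch: approximate Paddle Inference equivalent of the YAML above.
    from paddle.inference import Config, create_predictor

    config = Config("inference.pdmodel", "inference.pdiparams")
    config.enable_use_gpu(100, 0)                 # 100 MB initial pool, GPU 0
    if hasattr(config, "enable_new_ir"):
        config.enable_new_ir(True)                # enable_new_ir: True
    config.enable_tensorrt_engine(                # run_mode: trt_fp32
        workspace_size=1 << 25,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=Config.Precision.Float32,
        use_static=False,
        use_calib_mode=False,
    )
    predictor = create_predictor(config)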

+ 93 - 58
paddlex/inference/models/common/static_infer.py

@@ -14,12 +14,12 @@
 
 from typing import Union, Tuple, List, Dict, Any, Iterator
 import os
-import inspect
-from abc import abstractmethod
+import shutil
+from pathlib import Path
 import lazy_paddle as paddle
 import numpy as np
 
-from ....utils.flags import FLAGS_json_format_model
+from ....utils.flags import DEBUG, FLAGS_json_format_model
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
 
@@ -47,6 +47,40 @@ def collect_trt_shapes(
         predictor.run()
 
 
+def convert_trt(mode, pp_model_path, trt_dynamic_shapes):
+    from lazy_paddle.tensorrt.export import (
+        Input,
+        TensorRTConfig,
+        convert,
+        PrecisionMode,
+    )
+
+    trt_save_dir = str(Path(pp_model_path) / "trt" / "inference")
+
+    precision_map = {
+        "trt_int8": PrecisionMode.INT8,
+        "trt_fp32": PrecisionMode.FP32,
+        "trt_fp16": PrecisionMode.FP16,
+    }
+    trt_inputs = []
+    for name, candidate_shapes in trt_dynamic_shapes.items():
+        min_shape, opt_shape, max_shape = candidate_shapes
+        trt_input = Input(
+            min_input_shape=min_shape,
+            optim_input_shape=opt_shape,
+            max_input_shape=max_shape,
+        )
+        trt_inputs.append(trt_input)
+
+    # Create TensorRTConfig
+    trt_config = TensorRTConfig(inputs=trt_inputs)
+    trt_config.precision_mode = precision_map[mode]
+    trt_config.save_model_dir = trt_save_dir
+    convert(str(Path(pp_model_path) / "inference"), trt_config)
+    # copy inference.yml to the new model dir
+    shutil.copy(str(Path(pp_model_path) / "inference.yml"), trt_save_dir + ".yml")
+
+
 class Copy2GPU:
 
     def __init__(self, input_handlers):
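The new `convert_trt` helper added above performs an ahead-of-time TensorRT conversion through `lazy_paddle.tensorrt.export`, writing the converted model under `<pp_model_path>/trt/inference` and copying `inference.yml` beside it. A hedged usage sketch; the input name "x" and the shape ranges are illustrative assumptions, not taken from this commit:

    # Each entry maps an input name to [min_shape, opt_shape, max_shape].
    trt_dynamic_shapes = {
        "x": [
            [1, 3, 48, 320],
            [4, 3, 48, 320],
            [8, 3, 48, 320],
        ],
    }
    convert_trt("trt_fp32", "output/best_accuracy/inference", trt_dynamic_shapes)
    # Result: output/best_accuracy/inference/trt/inference* plus the config
    # copied to output/best_accuracy/inference/trt/inference.yml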
@@ -144,48 +178,53 @@ class StaticInfer:
                     self.model_dir / f"{self.model_prefix}.pdmodel"
                 ).as_posix()
         params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
-        config = Config(model_file, params_file)
 
-        config.enable_memory_optim()
-        if self.option.device in ("gpu", "dcu"):
+        config = Config(model_file, params_file)
+        if self.option.device == "gpu":
             if self.option.device == "gpu":
                 config.exp_disable_mixed_precision_ops({"feed", "fetch"})
             config.enable_use_gpu(100, self.option.device_id)
-            if self.option.device == "gpu":
-                # NOTE: The pptrt settings are not aligned with those of FD.
-                precision_map = {
-                    "trt_int8": Config.Precision.Int8,
-                    "trt_fp32": Config.Precision.Float32,
-                    "trt_fp16": Config.Precision.Half,
-                }
-                if self.option.run_mode in precision_map.keys():
-                    config.enable_tensorrt_engine(
-                        workspace_size=(1 << 25) * self.option.batch_size,
-                        max_batch_size=self.option.batch_size,
-                        min_subgraph_size=self.option.min_subgraph_size,
-                        precision_mode=precision_map[self.option.run_mode],
-                        use_static=self.option.trt_use_static,
-                        use_calib_mode=self.option.trt_calib_mode,
-                    )
 
-                    if not os.path.exists(self.option.shape_info_filename):
-                        logging.info(
-                            f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
-                        )
-                        collect_trt_shapes(
-                            model_file,
-                            params_file,
-                            self.option.device_id,
-                            self.option.shape_info_filename,
-                            self.option.trt_dynamic_shapes,
-                        )
-                    else:
-                        logging.info(
-                            f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. No need to collect again."
-                        )
-                    config.enable_tuned_tensorrt_dynamic_shape(
-                        self.option.shape_info_filename, True
+            if hasattr(config, "enable_new_ir"):
+                config.enable_new_ir(self.option.enable_new_ir)
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor()
+            config.set_optimization_level(3)
+
+            # NOTE: The pptrt settings are not aligned with those of FD.
+            precision_map = {
+                "trt_int8": Config.Precision.Int8,
+                "trt_fp32": Config.Precision.Float32,
+                "trt_fp16": Config.Precision.Half,
+            }
+            if self.option.run_mode in precision_map.keys():
+                config.enable_tensorrt_engine(
+                    workspace_size=(1 << 25) * self.option.batch_size,
+                    max_batch_size=self.option.batch_size,
+                    min_subgraph_size=self.option.min_subgraph_size,
+                    precision_mode=precision_map[self.option.run_mode],
+                    use_static=self.option.trt_use_static,
+                    use_calib_mode=self.option.trt_calib_mode,
+                )
+
+                if not os.path.exists(self.option.shape_info_filename):
+                    logging.info(
+                        f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
+                    )
+                    collect_trt_shapes(
+                        model_file,
+                        params_file,
+                        self.option.device_id,
+                        self.option.shape_info_filename,
+                        self.option.trt_dynamic_shapes,
+                    )
+                else:
+                    logging.info(
+                        f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. No need to collect again."
                     )
+                config.enable_tuned_tensorrt_dynamic_shape(
+                    self.option.shape_info_filename, True
+                )
 
         elif self.option.device == "npu":
             config.enable_custom_device("npu")
@@ -193,6 +232,11 @@ class StaticInfer:
             pass
         elif self.option.device == "mlu":
             config.enable_custom_device("mlu")
+        elif self.option.device == "dcu":
+            if paddle.is_compiled_with_rocm():
+                # Delete unsupported passes in dcu
+                config.delete_pass("conv2d_add_act_fuse_pass")
+                config.delete_pass("conv2d_add_fuse_pass")
         else:
             assert self.option.device == "cpu"
             config.disable_gpu()
@@ -209,30 +253,21 @@ class StaticInfer:
             else:
                 if hasattr(config, "disable_mkldnn"):
                     config.disable_mkldnn()
+            config.set_cpu_math_library_num_threads(self.option.cpu_threads)
 
-        # Disable paddle inference logging
-        config.disable_glog_info()
-
-        config.set_cpu_math_library_num_threads(self.option.cpu_threads)
-
-        if self.option.device in ("cpu", "gpu"):
-            if not (
-                self.option.device == "gpu" and self.option.run_mode.startswith("trt")
-            ):
-                if hasattr(config, "enable_new_ir"):
-                    config.enable_new_ir(self.option.enable_new_ir)
-                if hasattr(config, "enable_new_executor"):
-                    config.enable_new_executor()
-                config.set_optimization_level(3)
+            if hasattr(config, "enable_new_ir"):
+                config.enable_new_ir(self.option.enable_new_ir)
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor()
+            config.set_optimization_level(3)
 
+        config.enable_memory_optim()
         for del_p in self.option.delete_pass:
             config.delete_pass(del_p)
 
-        if self.option.device in ("gpu", "dcu"):
-            if paddle.is_compiled_with_rocm():
-                # Delete unsupported passes in dcu
-                config.delete_pass("conv2d_add_act_fuse_pass")
-                config.delete_pass("conv2d_add_fuse_pass")
+        # Disable paddle inference logging
+        if not DEBUG:
+            config.disable_glog_info()
 
         predictor = create_predictor(config)
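For the runtime TensorRT path above, dynamic-shape info is collected once into `shape_info_filename` and reused on later runs via `enable_tuned_tensorrt_dynamic_shape`. A hedged sketch of that flow in isolation, with a placeholder file name and the same `{name: [min, opt, max]}` mapping as in the `convert_trt` example:

    shape_file = "shape_range_info.pbtxt"  # placeholder name
    if not os.path.exists(shape_file):
        # Runs the model over the min/opt/max shapes to record tensor ranges.
        collect_trt_shapes(
            "inference.pdmodel",
            "inference.pdiparams",
            0,               # device_id
            shape_file,
            trt_dynamic_shapes,
        )
    config.enable_tuned_tensorrt_dynamic_shape(shape_file, True)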