瀏覽代碼

disable trt acceleration for rcnn and segmentation models

will-jl944 4 年之前
父節點
當前提交
34b2ca4873
共有 1 個文件被更改,包括 16 次插入7 次删除
  1. 16 7
      paddlex/deploy.py

+ 16 - 7
paddlex/deploy.py

@@ -94,13 +94,22 @@ class Predictor(object):
             config.enable_use_gpu(100, gpu_id)
             config.switch_ir_optim(True)
             if use_trt:
-                config.enable_tensorrt_engine(
-                    workspace_size=1 << 10,
-                    max_batch_size=max_trt_batch_size,
-                    min_subgraph_size=3,
-                    precision_mode=trt_precision_mode,
-                    use_static=False,
-                    use_calib_mode=False)
+                if self._model.model_type == 'segmenter':
+                    logging.warning(
+                        "Semantic segmentation models do not support TensorRT acceleration, "
+                        "TensorRT is forcibly disabled.")
+                elif 'RCNN' in self._model.__class__.__name__:
+                    logging.warning(
+                        "RCNN models do not support TensorRT acceleration, "
+                        "TensorRT is forcibly disabled.")
+                else:
+                    config.enable_tensorrt_engine(
+                        workspace_size=1 << 10,
+                        max_batch_size=max_trt_batch_size,
+                        min_subgraph_size=3,
+                        precision_mode=trt_precision_mode,
+                        use_static=False,
+                        use_calib_mode=False)
         else:
             config.disable_gpu()
             config.set_cpu_math_library_num_threads(cpu_thread_num)