deploy.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp

import numpy as np
import yaml
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddlex.cv.transforms import build_transforms
from paddlex.utils import logging

class Predictor(object):
    def __init__(self,
                 model_dir,
                 use_gpu=True,
                 gpu_id=0,
                 cpu_thread_num=1,
                 use_mkl=True,
                 mkl_thread_num=4,
                 use_trt=False,
                 use_glog=False,
                 memory_optimize=True,
                 max_trt_batch_size=1,
                 trt_precision_mode='float32'):
  36. """ 创建Paddle Predictor
  37. Args:
  38. model_dir: 模型路径(必须是导出的部署或量化模型)
  39. use_gpu: 是否使用gpu,默认True
  40. gpu_id: 使用gpu的id,默认0
  41. cpu_thread_num=1:使用cpu进行预测时的线程数,默认为1
  42. use_mkl: 是否使用mkldnn计算库,CPU情况下使用,默认False
  43. mkl_thread_num: mkldnn计算线程数,默认为4
  44. use_trt: 是否使用TensorRT,默认False
  45. use_glog: 是否启用glog日志, 默认False
  46. memory_optimize: 是否启动内存优化,默认True
  47. max_trt_batch_size: 在使用TensorRT时配置的最大batch size,默认1
  48. trt_precision_mode:在使用TensorRT时采用的精度,默认float32
  49. """
        if not osp.isdir(model_dir):
            logging.error(
                "{} is not a valid model directory.".format(model_dir),
                exit=True)
        if trt_precision_mode == 'float32':
            trt_precision_mode = PrecisionType.Float32
        elif trt_precision_mode == 'float16':
            trt_precision_mode = PrecisionType.Float16
        else:
            logging.error(
                "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
                .format(trt_precision_mode),
                exit=True)
        # create_predictor() below reads the model files from self.model_dir.
        self.model_dir = model_dir
        self.predictor = self.create_predictor(
            use_gpu=use_gpu,
            gpu_id=gpu_id,
            cpu_thread_num=cpu_thread_num,
            use_mkl=use_mkl,
            mkl_thread_num=mkl_thread_num,
            use_trt=use_trt,
            use_glog=use_glog,
            memory_optimize=memory_optimize,
            max_trt_batch_size=max_trt_batch_size,
            trt_precision_mode=trt_precision_mode)

    def create_predictor(self,
                         use_gpu=True,
                         gpu_id=0,
                         cpu_thread_num=1,
                         use_mkl=True,
                         mkl_thread_num=4,
                         use_trt=False,
                         use_glog=False,
                         memory_optimize=True,
                         max_trt_batch_size=1,
                         trt_precision_mode=PrecisionType.Float32):
        config = Config(
            prog_file=osp.join(self.model_dir, 'model.pdmodel'),
            params_file=osp.join(self.model_dir, 'model.pdiparams'))
        if use_gpu:
            # Set the initial GPU memory pool (in MB) and the device id.
            config.enable_use_gpu(100, gpu_id)
            config.switch_ir_optim(True)
            if use_trt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 10,
                    max_batch_size=max_trt_batch_size,
                    min_subgraph_size=3,
                    precision_mode=trt_precision_mode,
                    use_static=False,
                    use_calib_mode=False)
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(cpu_thread_num)
            if use_mkl:
                try:
                    # Cache 10 different input shapes for MKL-DNN to avoid
                    # memory leaks.
                    config.set_mkldnn_cache_capacity(10)
                    config.enable_mkldnn()
                    config.set_cpu_math_library_num_threads(mkl_thread_num)
                except Exception:
                    logging.warning(
                        "The current environment does not support `mkldnn`, "
                        "so mkldnn is disabled.")
        if use_glog:
            config.enable_glog_info()
        else:
            config.disable_glog_info()
        if memory_optimize:
            config.enable_memory_optim()
        config.switch_use_feed_fetch_ops(False)
        # Calls paddle.inference.create_predictor (the module-level function
        # imported above), not this method.
        predictor = create_predictor(config)
        return predictor
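

# Usage sketch: a minimal example of constructing a Predictor, assuming an
# exported deployment model under 'inference_model' (a hypothetical path
# containing model.pdmodel / model.pdiparams).
if __name__ == '__main__':
    predictor = Predictor(
        model_dir='inference_model',  # hypothetical path; replace with yours
        use_gpu=False)                # CPU path; MKL-DNN is tried by default
    # The wrapped paddle.inference predictor exposes the low-level API.
    print(predictor.predictor.get_input_names())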