_config.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from pathlib import Path
  16. from typing import Any, Dict, Literal, List, Mapping, Optional, Tuple, Type, Union
  17. import ultra_infer as ui
  18. from paddlex.utils import logging
  19. from pydantic import BaseModel, ConfigDict, Field, field_validator
  20. from typing_extensions import Annotated, TypeAlias, TypedDict, assert_never
  21. from paddlex_hpi._model_info import get_model_info
  22. from paddlex_hpi._utils.typing import Backend, DeviceType
  23. class _BackendConfig(BaseModel):
  24. def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
  25. raise NotImplementedError
  26. class PaddleInferConfig(_BackendConfig):
  27. cpu_num_threads: int = 8
  28. enable_mkldnn: bool = True
  29. enable_trt: bool = False
  30. trt_dynamic_shapes: Optional[Dict[str, List[List[int]]]] = None
  31. trt_dynamic_shape_input_data: Optional[Dict[str, List[List[float]]]] = None
  32. trt_precision: Literal["FP32", "FP16"] = "FP32"
  33. enable_log_info: bool = False
  34. def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
  35. option.use_paddle_infer_backend()
  36. option.set_cpu_thread_num(self.cpu_num_threads)
  37. option.paddle_infer_option.enable_mkldnn = self.enable_mkldnn
  38. option.paddle_infer_option.enable_trt = self.enable_trt
  39. if self.enable_trt:
  40. option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
  41. option.paddle_infer_option.collect_trt_shape = True
  42. option.paddle_infer_option.collect_trt_shape_by_device = True
  43. if self.trt_dynamic_shapes is not None:
  44. for name, shapes in self.trt_dynamic_shapes.items():
  45. option.trt_option.set_shape(name, *shapes)
  46. if self.trt_dynamic_shape_input_data is not None:
  47. for name, data in self.trt_dynamic_shape_input_data.items():
  48. option.trt_option.set_input_data(name, *data)
  49. if self.trt_precision == "FP16":
  50. option.trt_option.enable_fp16 = True
  51. option.paddle_infer_option.enable_log_info = self.enable_log_info
  52. class OpenVINOConfig(_BackendConfig):
  53. cpu_num_threads: int = 8
  54. def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
  55. option.use_openvino_backend()
  56. option.set_cpu_thread_num(self.cpu_num_threads)
  57. class ONNXRuntimeConfig(_BackendConfig):
  58. cpu_num_threads: int = 8
  59. def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
  60. option.use_ort_backend()
  61. option.set_cpu_thread_num(self.cpu_num_threads)
  62. class TensorRTConfig(_BackendConfig):
  63. precision: Literal["FP32", "FP16"] = "FP32"
  64. dynamic_shapes: Optional[Dict[str, List[List[int]]]] = None
  65. def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
  66. option.use_trt_backend()
  67. option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
  68. if self.precision == "FP16":
  69. option.trt_option.enable_fp16 = True
  70. if self.dynamic_shapes is not None:
  71. for name, shapes in self.dynamic_shapes.items():
  72. option.trt_option.set_shape(name, *shapes)
  73. class PaddleTensorRTConfig(_BackendConfig):
  74. dynamic_shapes: Dict[str, List[List[int]]]
  75. dynamic_shape_input_data: Optional[Dict[str, List[List[float]]]] = None
  76. enable_log_info: bool = False
  77. def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
  78. option.use_paddle_infer_backend()
  79. option.paddle_infer_option.enable_trt = True
  80. option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
  81. if self.dynamic_shapes is not None:
  82. option.paddle_infer_option.collect_trt_shape = True
  83. # TODO: Support setting collect_trt_shape_by_device
  84. for name, shapes in self.dynamic_shapes.items():
  85. option.trt_option.set_shape(name, *shapes)
  86. if self.dynamic_shape_input_data is not None:
  87. for name, data in self.dynamic_shape_input_data.items():
  88. option.trt_option.set_input_data(name, *data)
  89. option.paddle_infer_option.enable_log_info = self.enable_log_info
# Should we use tagged unions?
# NOTE: `PaddleTensorRTConfig` is deliberately absent — it is the deprecated
# backend (see the validators in `HPIConfig` that translate it away).
BackendConfig: TypeAlias = Union[
    PaddleInferConfig,
    OpenVINOConfig,
    ONNXRuntimeConfig,
    TensorRTConfig,
]
  97. def get_backend_config_type(backend: Backend, /) -> Type[BackendConfig]:
  98. backend_config_type: Type[BackendConfig]
  99. if backend == "paddle_infer":
  100. backend_config_type = PaddleInferConfig
  101. elif backend == "openvino":
  102. backend_config_type = OpenVINOConfig
  103. elif backend == "onnx_runtime":
  104. backend_config_type = ONNXRuntimeConfig
  105. elif backend == "tensorrt":
  106. backend_config_type = TensorRTConfig
  107. else:
  108. assert_never(backend)
  109. return backend_config_type
# Can I create this dynamically and automatically?
# Maps backend names (as they appear in user config) to their config models.
# Includes the deprecated `paddle_tensorrt` key for backward compatibility.
class BackendConfigs(TypedDict, total=False):
    paddle_infer: PaddleInferConfig
    openvino: OpenVINOConfig
    onnx_runtime: ONNXRuntimeConfig
    tensorrt: TensorRTConfig
    paddle_tensorrt: PaddleTensorRTConfig
  117. class HPIConfig(BaseModel):
  118. model_config = ConfigDict(populate_by_name=True)
  119. selected_backends: Optional[Dict[DeviceType, Backend]] = None
  120. # For backward compatilibity
  121. backend_configs: Annotated[
  122. Optional[BackendConfigs], Field(validation_alias="backend_config")
  123. ] = None
  124. def get_backend_and_config(
  125. self, model_name: str, device_type: DeviceType
  126. ) -> Tuple[Backend, BackendConfig]:
  127. # Do we need an extensible selector?
  128. model_info = get_model_info(model_name, device_type)
  129. if model_info:
  130. backend_config_pairs = model_info["backend_config_pairs"]
  131. else:
  132. backend_config_pairs = []
  133. config_dict: Dict[str, Any] = {}
  134. if self.selected_backends and device_type in self.selected_backends:
  135. backend = self.selected_backends[device_type]
  136. for pair in backend_config_pairs:
  137. # Use the first one
  138. if pair[0] == self.selected_backends[device_type]:
  139. config_dict.update(pair[1])
  140. break
  141. else:
  142. if backend_config_pairs:
  143. # Currently we select the first one
  144. backend = backend_config_pairs[0][0]
  145. config_dict.update(backend_config_pairs[0][1])
  146. else:
  147. backend = "paddle_infer"
  148. if self.backend_configs and backend in self.backend_configs:
  149. config_dict.update(
  150. self.backend_configs[backend].model_dump(exclude_unset=True)
  151. )
  152. backend_config_type = get_backend_config_type(backend)
  153. backend_config = backend_config_type.model_validate(config_dict)
  154. return backend, backend_config
  155. # XXX: For backward compatilibity
  156. @field_validator("selected_backends", mode="before")
  157. @classmethod
  158. def _hack_selected_backends(cls, data: Any) -> Any:
  159. if isinstance(data, Mapping):
  160. new_data = dict(data)
  161. for device_type in new_data:
  162. if new_data[device_type] == "paddle_tensorrt":
  163. warnings.warn(
  164. "`paddle_tensorrt` is deprecated. Please use `paddle_infer` instead.",
  165. FutureWarning,
  166. )
  167. new_data[device_type] = "paddle_infer"
  168. return new_data
  169. @field_validator("backend_configs", mode="before")
  170. @classmethod
  171. def _hack_backend_configs(cls, data: Any) -> Any:
  172. if isinstance(data, Mapping):
  173. new_data = dict(data)
  174. if new_data and "paddle_tensorrt" in new_data:
  175. warnings.warn(
  176. "`paddle_tensorrt` is deprecated. Please use `paddle_infer` instead.",
  177. FutureWarning,
  178. )
  179. if "paddle_infer" not in new_data:
  180. new_data["paddle_infer"] = {}
  181. pptrt_cfg = new_data["paddle_tensorrt"]
  182. logging.warning("`paddle_infer.enable_trt` will be set to `True`.")
  183. new_data["paddle_infer"]["enable_trt"] = True
  184. new_data["paddle_infer"]["trt_dynamic_shapes"] = pptrt_cfg[
  185. "dynamic_shapes"
  186. ]
  187. if "dynamic_shape_input_data" in pptrt_cfg:
  188. new_data["paddle_infer"]["trt_dynamic_shape_input_data"] = (
  189. pptrt_cfg["dynamic_shape_input_data"]
  190. )
  191. logging.warning("`paddle_tensorrt.enable_log_info` will be ignored.")
  192. return new_data