# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from ....utils import logging
from ....utils.errors import raise_unsupported_api_error
from ....utils.misc import abspath
from ...base import BaseModel
from ...base.utils.arg import CLIArgument
from ...base.utils.subprocess import CompletedProcess


class VideoDetModel(BaseModel):
    """Video Det Model"""

    def train(
        self,
        batch_size: int = None,
        learning_rate: float = None,
        epochs_iters: int = None,
        ips: str = None,
        device: str = "gpu",
        resume_path: str = None,
        dy2st: bool = False,
        amp: str = "OFF",
        num_workers: int = None,
        use_vdl: bool = True,
        save_dir: str = None,
        **kwargs,
    ) -> CompletedProcess:
        """Train the model.

        Args:
            batch_size (int, optional): the training batch size. Defaults to None.
            learning_rate (float, optional): the training learning rate. Defaults to None.
            epochs_iters (int, optional): the number of training epochs. Defaults to None.
            ips (str, optional): the IP addresses of the nodes when using distributed training. Defaults to None.
            device (str, optional): the running device. Defaults to 'gpu'.
            resume_path (str, optional): the checkpoint file path to resume training from. Train from scratch if it
                is set to None. Defaults to None.
            dy2st (bool, optional): enable dynamic-to-static training. Defaults to False.
            amp (str, optional): the AMP setting. Defaults to 'OFF'.
            num_workers (int, optional): the number of workers. Defaults to None.
            use_vdl (bool, optional): enable VisualDL. Defaults to True.
            save_dir (str, optional): the directory path to save the training output. Defaults to None.

        Returns:
            CompletedProcess: the result of the training subprocess execution.
        """
        if resume_path is not None:
            resume_path = abspath(resume_path)

        with self._create_new_config_file() as config_path:
            # Update YAML config file
            config = self.config.copy()
            config.update_device(device)
            config._update_to_static(dy2st)
            config._update_use_vdl(use_vdl)

            if batch_size is not None:
                config.update_batch_size(batch_size)
            if learning_rate is not None:
                config.update_learning_rate(learning_rate)
            if epochs_iters is not None:
                config._update_epochs(epochs_iters)
            config._update_checkpoints(resume_path)
            if save_dir is not None:
                save_dir = abspath(save_dir)
            else:
                # `save_dir` is None
                save_dir = abspath(config.get_train_save_dir())
            config._update_output_dir(save_dir)
            if num_workers is not None:
                config.update_num_workers(num_workers)

            cli_args = []
            do_eval = kwargs.pop("do_eval", True)
            profile = kwargs.pop("profile", None)
            if profile is not None:
                cli_args.append(CLIArgument("--profiler_options", profile))

            # Benchmarking mode settings
            benchmark = kwargs.pop("benchmark", None)
            if benchmark is not None:
                envs = benchmark.get("env", None)
                seed = benchmark.get("seed", None)
                do_eval = benchmark.get("do_eval", False)
                num_workers = benchmark.get("num_workers", None)
                config.update_log_ranks(device)
                config._update_amp(benchmark.get("amp", None))
                config.update_dali(benchmark.get("dali", False))
                config.update_shuffle(benchmark.get("shuffle", False))
                config.update_shared_memory(benchmark.get("shared_memory", True))
                config.update_print_mem_info(benchmark.get("print_mem_info", True))
                if num_workers is not None:
                    config.update_num_workers(num_workers)
                if seed is not None:
                    config.update_seed(seed)
                if envs is not None:
                    for env_name, env_value in envs.items():
                        os.environ[env_name] = str(env_value)
            else:
                config._update_amp(amp)

            # PDX related settings
            device_type = device.split(":")[0]
            uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
            config.update({"Global.uniform_output_enabled": uniform_output_enabled})
            config.update({"Global.pdx_model_name": self.name})
            config.dump(config_path)

            self._assert_empty_kwargs(kwargs)

            return self.runner.train(
                config_path, cli_args, device, ips, save_dir, do_eval=do_eval
            )
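
    # The `benchmark` item accepted via `train(**kwargs)` above is a dict. A
    # minimal sketch of its shape, inferred from the keys read in the method;
    # the concrete values shown are illustrative assumptions, not requirements:
    #
    #     benchmark = {
    #         "env": {"CUDA_VISIBLE_DEVICES": "0"},  # extra environment variables (hypothetical example)
    #         "seed": 42,
    #         "do_eval": False,
    #         "num_workers": 4,
    #         "amp": None,            # forwarded to config._update_amp()
    #         "dali": False,
    #         "shuffle": False,
    #         "shared_memory": True,
    #         "print_mem_info": True,
    #     }
    #     model.train(..., benchmark=benchmark)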

    def evaluate(
        self,
        weight_path: str,
        batch_size: int = None,
        ips: str = None,
        device: str = "gpu",
        amp: str = "OFF",
        num_workers: int = None,
        **kwargs,
    ) -> CompletedProcess:
        """Evaluate the model using the specified weights.

        Args:
            weight_path (str): the path of the model weight file to be evaluated.
            batch_size (int, optional): the batch size used in evaluation. Defaults to None.
            ips (str, optional): the IP addresses of the nodes when using distributed evaluation. Defaults to None.
            device (str, optional): the running device. Defaults to 'gpu'.
            amp (str, optional): the AMP setting. Defaults to 'OFF'.
            num_workers (int, optional): the number of workers used in evaluation. Defaults to None.

        Returns:
            CompletedProcess: the result of the evaluation subprocess execution.
        """
        with self._create_new_config_file() as config_path:
            # Update YAML config file
            config = self.config.copy()
            config._update_amp(amp)
            config.update_device(device)
            config.update_pretrained_weights(weight_path)
            if batch_size is not None:
                config.update_batch_size(batch_size)
            if num_workers is not None:
                config.update_num_workers(num_workers)
            config.dump(config_path)

            self._assert_empty_kwargs(kwargs)

            cp = self.runner.evaluate(config_path, [], device, ips)
            return cp

    def predict(
        self,
        weight_path: str,
        input_path: str,
        input_list_path: str = None,
        device: str = "gpu",
        save_dir: str = None,
        **kwargs,
    ) -> CompletedProcess:
        """Predict using the specified weights.

        Args:
            weight_path (str): the path of the model weight file used to predict.
            input_path (str): the path of the image file to be predicted.
            input_list_path (str, optional): the paths of the images to be predicted, if it is not None.
                Defaults to None.
            device (str, optional): the running device. Defaults to 'gpu'.
            save_dir (str, optional): the directory path to save the prediction output. Defaults to None.

        Returns:
            CompletedProcess: the result of the prediction subprocess execution.
        """
        input_path = abspath(input_path)
        if input_list_path:
            input_list_path = abspath(input_list_path)

        with self._create_new_config_file() as config_path:
            # Update YAML config file
            config = self.config.copy()
            config.update_pretrained_weights(weight_path)
            config._update_predict_img(input_path, input_list_path)
            config.update_device(device)
            config._update_save_predict_result(save_dir)
            config.dump(config_path)

            self._assert_empty_kwargs(kwargs)

            return self.runner.predict(config_path, [], device)

    def export(self, weight_path: str, save_dir: str, **kwargs) -> CompletedProcess:
        """Export the dynamic model to a static model.

        Args:
            weight_path (str): the path of the model weight file used for export.
            save_dir (str): the directory path to save the export output.

        Returns:
            CompletedProcess: the result of the export subprocess execution.
        """
        if not weight_path.startswith(("http://", "https://")):
            weight_path = abspath(weight_path)
        save_dir = abspath(save_dir)

        with self._create_new_config_file() as config_path:
            # Update YAML config file
            config = self.config.copy()
            config.update_pretrained_weights(weight_path)
            config._update_save_inference_dir(save_dir)
            device = kwargs.pop("device", None)
            if device:
                config.update_device(device)

            # PDX related settings
            uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
            config.update({"Global.uniform_output_enabled": uniform_output_enabled})
            config.update({"Global.pdx_model_name": self.name})
            config.dump(config_path)

            self._assert_empty_kwargs(kwargs)

            return self.runner.export(config_path, [], None, save_dir)

    def infer(
        self,
        model_dir: str,
        input_path: str,
        device: str = "gpu",
        save_dir: str = None,
        dict_path: str = None,
        **kwargs,
    ) -> CompletedProcess:
        """Predict using the exported inference model.

        Args:
            model_dir (str): the directory path of the inference model files used to predict.
            input_path (str): the path of the image to be predicted.
            device (str, optional): the running device. Defaults to 'gpu'.
            save_dir (str, optional): the directory path to save the output. Defaults to None.
            dict_path (str, optional): the label dict file path. Defaults to None.

        Returns:
            CompletedProcess: the result of the inference subprocess execution.
        """
        model_dir = abspath(model_dir)
        input_path = abspath(input_path)
        if save_dir is not None:
            logging.warning("`save_dir` will not be used.")

        config_path = os.path.join(model_dir, "inference.yml")
        config = self.config.copy()
        config.load(config_path)
        config._update_inference_model_dir(model_dir)
        config._update_infer_img(input_path)
        config._update_infer_device(device)
        if dict_path is not None:
            dict_path = abspath(dict_path)
            config.update_label_dict_path(dict_path)
        if "enable_mkldnn" in kwargs:
            config._update_enable_mkldnn(kwargs.pop("enable_mkldnn"))

        with self._create_new_config_file() as config_path:
            config.dump(config_path)

            self._assert_empty_kwargs(kwargs)

            return self.runner.infer(config_path, [], device)

    def compression(
        self,
        weight_path: str,
        batch_size: int = None,
        learning_rate: float = None,
        epochs_iters: int = None,
        device: str = "gpu",
        use_vdl: bool = True,
        save_dir: str = None,
        **kwargs,
    ):
        """Compress the model (not supported for this model)."""
        raise_unsupported_api_error("compression", self.__class__)
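
# ---------------------------------------------------------------------------
# Usage sketch (comments only, for orientation; nothing here runs on import).
# How a `VideoDetModel` instance is obtained depends on the surrounding PaddleX
# factory/registry code outside this file, so the construction is left as a
# placeholder assumption; all paths below are illustrative.
# ---------------------------------------------------------------------------
# model = ...  # a VideoDetModel instance provided by the PaddleX framework
#
# # Train, then evaluate the resulting weights:
# model.train(batch_size=4, learning_rate=1e-3, epochs_iters=10,
#             device="gpu:0", save_dir="output/video_det")
# model.evaluate(weight_path="output/video_det/best_model.pdparams",
#                batch_size=4, device="gpu:0")
#
# # Export to a static inference model, then run inference with it:
# model.export(weight_path="output/video_det/best_model.pdparams",
#              save_dir="output/video_det/infer")
# model.infer(model_dir="output/video_det/infer",
#             input_path="data/example_video.mp4", device="gpu:0")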