# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from ....utils import logging
from ....utils.misc import abspath
from ...base import BaseModel
from ...base.utils.arg import CLIArgument
from ...base.utils.subprocess import CompletedProcess
  20. class ClsModel(BaseModel):
  21. """Image Classification Model"""
  22. def train(
  23. self,
  24. batch_size: int = None,
  25. learning_rate: float = None,
  26. epochs_iters: int = None,
  27. ips: str = None,
  28. device: str = "gpu",
  29. resume_path: str = None,
  30. dy2st: bool = False,
  31. amp: str = "OFF",
  32. num_workers: int = None,
  33. use_vdl: bool = True,
  34. save_dir: str = None,
  35. **kwargs,
  36. ) -> CompletedProcess:
  37. """train self
  38. Args:
  39. batch_size (int, optional): the train batch size value. Defaults to None.
  40. learning_rate (float, optional): the train learning rate value. Defaults to None.
  41. epochs_iters (int, optional): the train epochs value. Defaults to None.
  42. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  43. device (str, optional): the running device. Defaults to 'gpu'.
  44. resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
  45. to None. Defaults to None.
  46. dy2st (bool, optional): Enable dynamic to static. Defaults to False.
  47. amp (str, optional): the amp settings. Defaults to 'OFF'.
  48. num_workers (int, optional): the workers number. Defaults to None.
  49. use_vdl (bool, optional): enable VisualDL. Defaults to True.
  50. save_dir (str, optional): the directory path to save train output. Defaults to None.
  51. Returns:
  52. CompletedProcess: the result of training subprocess execution.
  53. """
  54. if resume_path is not None:
  55. resume_path = abspath(resume_path)
  56. with self._create_new_config_file() as config_path:
  57. # Update YAML config file
  58. config = self.config.copy()
  59. config.update_device(device)
  60. config._update_to_static(dy2st)
  61. config._update_use_vdl(use_vdl)
  62. if batch_size is not None:
  63. config.update_batch_size(batch_size)
  64. if learning_rate is not None:
  65. config.update_learning_rate(learning_rate)
  66. if epochs_iters is not None:
  67. config._update_epochs(epochs_iters)
  68. config._update_checkpoints(resume_path)
  69. if save_dir is not None:
  70. save_dir = abspath(save_dir)
  71. else:
  72. # `save_dir` is None
  73. save_dir = abspath(config.get_train_save_dir())
  74. config._update_output_dir(save_dir)
  75. if num_workers is not None:
  76. config.update_num_workers(num_workers)
  77. cli_args = []
  78. do_eval = kwargs.pop("do_eval", True)
  79. profile = kwargs.pop("profile", None)
  80. if profile is not None:
  81. cli_args.append(CLIArgument("--profiler_options", profile))
  82. # Benchmarking mode settings
  83. benchmark = kwargs.pop("benchmark", None)
  84. if benchmark is not None:
  85. envs = benchmark.get("env", None)
  86. seed = benchmark.get("seed", None)
  87. do_eval = benchmark.get("do_eval", False)
  88. num_workers = benchmark.get("num_workers", None)
  89. config.update_log_ranks(device)
  90. config._update_amp(benchmark.get("amp", None))
  91. config.update_dali(benchmark.get("dali", False))
  92. config.update_shuffle(benchmark.get("shuffle", False))
  93. config.update_shared_memory(benchmark.get("shared_memory", True))
  94. config.update_print_mem_info(benchmark.get("print_mem_info", True))
  95. if num_workers is not None:
  96. config.update_num_workers(num_workers)
  97. if seed is not None:
  98. config.update_seed(seed)
  99. if envs is not None:
  100. for env_name, env_value in envs.items():
  101. os.environ[env_name] = str(env_value)
  102. else:
  103. config._update_amp(amp)
  104. # PDX related settings
  105. device_type = device.split(":")[0]
  106. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  107. export_with_pir = kwargs.pop("export_with_pir", False)
  108. config.update([f"Global.uniform_output_enabled={uniform_output_enabled}"])
  109. config.update([f"Global.pdx_model_name={self.name}"])
  110. if export_with_pir:
  111. config.update([f"Global.export_with_pir={export_with_pir}"])
  112. config.dump(config_path)
  113. self._assert_empty_kwargs(kwargs)
  114. return self.runner.train(
  115. config_path, cli_args, device, ips, save_dir, do_eval=do_eval
  116. )
  117. def evaluate(
  118. self,
  119. weight_path: str,
  120. batch_size: int = None,
  121. ips: str = None,
  122. device: str = "gpu",
  123. amp: str = "OFF",
  124. num_workers: int = None,
  125. **kwargs,
  126. ) -> CompletedProcess:
  127. """evaluate self using specified weight
  128. Args:
  129. weight_path (str): the path of model weight file to be evaluated.
  130. batch_size (int, optional): the batch size value in evaluating. Defaults to None.
  131. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  132. device (str, optional): the running device. Defaults to 'gpu'.
  133. amp (str, optional): the AMP setting. Defaults to 'OFF'.
  134. num_workers (int, optional): the workers number in evaluating. Defaults to None.
  135. Returns:
  136. CompletedProcess: the result of evaluating subprocess execution.
  137. """
  138. weight_path = abspath(weight_path)
  139. with self._create_new_config_file() as config_path:
  140. # Update YAML config file
  141. config = self.config.copy()
  142. config._update_amp(amp)
  143. config.update_device(device)
  144. config.update_pretrained_weights(weight_path)
  145. if batch_size is not None:
  146. config.update_batch_size(batch_size)
  147. if num_workers is not None:
  148. config.update_num_workers(num_workers)
  149. config.dump(config_path)
  150. self._assert_empty_kwargs(kwargs)
  151. cp = self.runner.evaluate(config_path, [], device, ips)
  152. return cp
  153. def predict(
  154. self,
  155. weight_path: str,
  156. input_path: str,
  157. input_list_path: str = None,
  158. device: str = "gpu",
  159. save_dir: str = None,
  160. **kwargs,
  161. ) -> CompletedProcess:
  162. """predict using specified weight
  163. Args:
  164. weight_path (str): the path of model weight file used to predict.
  165. input_path (str): the path of image file to be predicted.
  166. input_list_path (str, optional): the paths of images to be predicted if is not None. Defaults to None.
  167. device (str, optional): the running device. Defaults to 'gpu'.
  168. save_dir (str, optional): the directory path to save predict output. Defaults to None.
  169. Returns:
  170. CompletedProcess: the result of predicting subprocess execution.
  171. """
  172. weight_path = abspath(weight_path)
  173. input_path = abspath(input_path)
  174. if input_list_path:
  175. input_list_path = abspath(input_list_path)
  176. with self._create_new_config_file() as config_path:
  177. # Update YAML config file
  178. config = self.config.copy()
  179. config.update_pretrained_weights(weight_path)
  180. config._update_predict_img(input_path, input_list_path)
  181. config.update_device(device)
  182. config._update_save_predict_result(save_dir)
  183. config.dump(config_path)
  184. self._assert_empty_kwargs(kwargs)
  185. return self.runner.predict(config_path, [], device)
  186. def export(self, weight_path: str, save_dir: str, **kwargs) -> CompletedProcess:
  187. """export the dynamic model to static model
  188. Args:
  189. weight_path (str): the model weight file path that used to export.
  190. save_dir (str): the directory path to save export output.
  191. Returns:
  192. CompletedProcess: the result of exporting subprocess execution.
  193. """
  194. if not weight_path.startswith("http"):
  195. weight_path = abspath(weight_path)
  196. save_dir = abspath(save_dir)
  197. with self._create_new_config_file() as config_path:
  198. # Update YAML config file
  199. config = self.config.copy()
  200. config.update_pretrained_weights(weight_path)
  201. config._update_save_inference_dir(save_dir)
  202. device = kwargs.pop("device", None)
  203. if device:
  204. config.update_device(device)
  205. # PDX related settings
  206. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  207. export_with_pir = kwargs.pop("export_with_pir", False)
  208. config.update([f"Global.uniform_output_enabled={uniform_output_enabled}"])
  209. config.update([f"Global.pdx_model_name={self.name}"])
  210. if export_with_pir:
  211. config.update([f"Global.export_with_pir={export_with_pir}"])
  212. config.dump(config_path)
  213. self._assert_empty_kwargs(kwargs)
  214. return self.runner.export(config_path, [], None, save_dir)
  215. def infer(
  216. self,
  217. model_dir: str,
  218. input_path: str,
  219. device: str = "gpu",
  220. save_dir: str = None,
  221. dict_path: str = None,
  222. **kwargs,
  223. ) -> CompletedProcess:
  224. """predict image using infernece model
  225. Args:
  226. model_dir (str): the directory path of inference model files that would use to predict.
  227. input_path (str): the path of image that would be predict.
  228. device (str, optional): the running device. Defaults to 'gpu'.
  229. save_dir (str, optional): the directory path to save output. Defaults to None.
  230. dict_path (str, optional): the label dict file path. Defaults to None.
  231. Returns:
  232. CompletedProcess: the result of inferring subprocess execution.
  233. """
  234. model_dir = abspath(model_dir)
  235. input_path = abspath(input_path)
  236. if save_dir is not None:
  237. logging.warning("`save_dir` will not be used.")
  238. config_path = os.path.join(model_dir, "inference.yml")
  239. config = self.config.copy()
  240. config.load(config_path)
  241. config._update_inference_model_dir(model_dir)
  242. config._update_infer_img(input_path)
  243. config._update_infer_device(device)
  244. if dict_path is not None:
  245. dict_path = abspath(dict_path)
  246. config.update_label_dict_path(dict_path)
  247. if "enable_mkldnn" in kwargs:
  248. config._update_enable_mkldnn(kwargs.pop("enable_mkldnn"))
  249. with self._create_new_config_file() as config_path:
  250. config.dump(config_path)
  251. self._assert_empty_kwargs(kwargs)
  252. return self.runner.infer(config_path, [], device)
  253. def compression(
  254. self,
  255. weight_path: str,
  256. batch_size: int = None,
  257. learning_rate: float = None,
  258. epochs_iters: int = None,
  259. device: str = "gpu",
  260. use_vdl: bool = True,
  261. save_dir: str = None,
  262. **kwargs,
  263. ) -> CompletedProcess:
  264. """compression model
  265. Args:
  266. weight_path (str): the path to weight file of model.
  267. batch_size (int, optional): the batch size value of compression training. Defaults to None.
  268. learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
  269. epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
  270. device (str, optional): the device to run compression training. Defaults to 'gpu'.
  271. use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
  272. save_dir (str, optional): the directory to save output. Defaults to None.
  273. Returns:
  274. CompletedProcess: the result of compression subprocess execution.
  275. """
  276. weight_path = abspath(weight_path)
  277. with self._create_new_config_file() as config_path:
  278. # Update YAML config file
  279. config = self.config.copy()
  280. config._update_amp(None)
  281. config.update_device(device)
  282. config._update_use_vdl(use_vdl)
  283. config._update_slim_config(self.model_info["auto_compression_config_path"])
  284. config.update_pretrained_weights(weight_path)
  285. if batch_size is not None:
  286. config.update_batch_size(batch_size)
  287. if learning_rate is not None:
  288. config.update_learning_rate(learning_rate)
  289. if epochs_iters is not None:
  290. config._update_epochs(epochs_iters)
  291. if save_dir is not None:
  292. save_dir = abspath(save_dir)
  293. else:
  294. # `save_dir` is None
  295. save_dir = abspath(config.get_train_save_dir())
  296. config._update_output_dir(save_dir)
  297. config.dump(config_path)
  298. export_cli_args = []
  299. export_cli_args.append(
  300. CLIArgument(
  301. "-o",
  302. f"Global.save_inference_dir={os.path.join(save_dir, 'export')}",
  303. )
  304. )
  305. self._assert_empty_kwargs(kwargs)
  306. return self.runner.compression(
  307. config_path, [], export_cli_args, device, save_dir
  308. )