model.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from ...base import BaseModel
  16. from ...base.utils.arg import CLIArgument
  17. from ...base.utils.subprocess import CompletedProcess
  18. from ....utils.misc import abspath
  19. class SegModel(BaseModel):
  20. """ Semantic Segmentation Model """
  21. def train(self,
  22. batch_size: int=None,
  23. learning_rate: float=None,
  24. epochs_iters: int=None,
  25. ips: str=None,
  26. device: str='gpu',
  27. resume_path: str=None,
  28. dy2st: bool=False,
  29. amp: str='OFF',
  30. num_workers: int=None,
  31. use_vdl: bool=True,
  32. save_dir: str=None,
  33. **kwargs) -> CompletedProcess:
  34. """train self
  35. Args:
  36. batch_size (int, optional): the train batch size value. Defaults to None.
  37. learning_rate (float, optional): the train learning rate value. Defaults to None.
  38. epochs_iters (int, optional): the train epochs value. Defaults to None.
  39. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  40. device (str, optional): the running device. Defaults to 'gpu'.
  41. resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
  42. to None. Defaults to None.
  43. dy2st (bool, optional): Enable dynamic to static. Defaults to False.
  44. amp (str, optional): the amp settings. Defaults to 'OFF'.
  45. num_workers (int, optional): the workers number. Defaults to None.
  46. use_vdl (bool, optional): enable VisualDL. Defaults to True.
  47. save_dir (str, optional): the directory path to save train output. Defaults to None.
  48. Returns:
  49. CompletedProcess: the result of training subprocess execution.
  50. """
  51. config = self.config.copy()
  52. cli_args = []
  53. if batch_size is not None:
  54. cli_args.append(CLIArgument('--batch_size', batch_size))
  55. if learning_rate is not None:
  56. cli_args.append(CLIArgument('--learning_rate', learning_rate))
  57. if epochs_iters is not None:
  58. cli_args.append(CLIArgument('--iters', epochs_iters))
  59. # No need to handle `ips`
  60. if device is not None:
  61. device_type, _ = self.runner.parse_device(device)
  62. cli_args.append(CLIArgument('--device', device_type))
  63. # For compatibility
  64. resume_dir = kwargs.pop('resume_dir', None)
  65. if resume_path is None and resume_dir is not None:
  66. resume_path = os.path.join(resume_dir, 'model.pdparams')
  67. if resume_path is not None:
  68. # NOTE: We must use an absolute path here,
  69. # so we can run the scripts either inside or outside the repo dir.
  70. resume_path = abspath(resume_path)
  71. if os.path.basename(resume_path) != 'model.pdparams':
  72. raise ValueError(f"{resume_path} has an incorrect file name.")
  73. if not os.path.exists(resume_path):
  74. raise FileNotFoundError(f"{resume_path} does not exist.")
  75. resume_dir = os.path.dirname(resume_path)
  76. opts_path = os.path.join(resume_dir, 'model.pdopt')
  77. if not os.path.exists(opts_path):
  78. raise FileNotFoundError(f"{opts_path} must exist.")
  79. cli_args.append(CLIArgument('--resume_model', resume_dir))
  80. if dy2st:
  81. config.update_dy2st(dy2st)
  82. if amp is not None:
  83. if amp != 'OFF':
  84. cli_args.append(CLIArgument('--precision', 'fp16'))
  85. cli_args.append(CLIArgument('--amp_level', amp))
  86. if num_workers is not None:
  87. cli_args.append(CLIArgument('--num_workers', num_workers))
  88. if use_vdl:
  89. cli_args.append(CLIArgument('--use_vdl'))
  90. if save_dir is not None:
  91. save_dir = abspath(save_dir)
  92. else:
  93. # `save_dir` is None
  94. save_dir = abspath(os.path.join('output', 'train'))
  95. cli_args.append(CLIArgument('--save_dir', save_dir))
  96. save_interval = kwargs.pop('save_interval', None)
  97. if save_interval is not None:
  98. cli_args.append(CLIArgument('--save_interval', save_interval))
  99. do_eval = kwargs.pop('do_eval', True)
  100. profile = kwargs.pop('profile', None)
  101. if profile is not None:
  102. cli_args.append(CLIArgument('--profiler_options', profile))
  103. log_iters = kwargs.pop('log_iters', None)
  104. if log_iters is not None:
  105. cli_args.append(CLIArgument('--log_iters', log_iters))
  106. repeats = kwargs.pop('repeats', None)
  107. if repeats is not None:
  108. cli_args.append(CLIArgument('--repeats', repeats))
  109. seed = kwargs.pop('seed', None)
  110. if seed is not None:
  111. cli_args.append(CLIArgument('--seed', seed))
  112. self._assert_empty_kwargs(kwargs)
  113. with self._create_new_config_file() as config_path:
  114. config.dump(config_path)
  115. return self.runner.train(
  116. config_path, cli_args, device, ips, save_dir, do_eval=do_eval)
  117. def evaluate(self,
  118. weight_path: str,
  119. batch_size: int=None,
  120. ips: str=None,
  121. device: str='gpu',
  122. amp: str='OFF',
  123. num_workers: int=None,
  124. **kwargs) -> CompletedProcess:
  125. """evaluate self using specified weight
  126. Args:
  127. weight_path (str): the path of model weight file to be evaluated.
  128. batch_size (int, optional): the batch size value in evaluating. Defaults to None.
  129. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  130. device (str, optional): the running device. Defaults to 'gpu'.
  131. amp (str, optional): the AMP setting. Defaults to 'OFF'.
  132. num_workers (int, optional): the workers number in evaluating. Defaults to None.
  133. Returns:
  134. CompletedProcess: the result of evaluating subprocess execution.
  135. """
  136. config = self.config.copy()
  137. cli_args = []
  138. weight_path = abspath(weight_path)
  139. cli_args.append(CLIArgument('--model_path', weight_path))
  140. if batch_size is not None:
  141. if batch_size != 1:
  142. raise ValueError("Batch size other than 1 is not supported.")
  143. # No need to handle `ips`
  144. if device is not None:
  145. device_type, _ = self.runner.parse_device(device)
  146. cli_args.append(CLIArgument('--device', device_type))
  147. if amp is not None:
  148. if amp != 'OFF':
  149. cli_args.append(CLIArgument('--precision', 'fp16'))
  150. cli_args.append(CLIArgument('--amp_level', amp))
  151. if num_workers is not None:
  152. cli_args.append(CLIArgument('--num_workers', num_workers))
  153. self._assert_empty_kwargs(kwargs)
  154. with self._create_new_config_file() as config_path:
  155. config.dump(config_path)
  156. cp = self.runner.evaluate(config_path, cli_args, device, ips)
  157. return cp
  158. def predict(self,
  159. weight_path: str,
  160. input_path: str,
  161. device: str='gpu',
  162. save_dir: str=None,
  163. **kwargs) -> CompletedProcess:
  164. """predict using specified weight
  165. Args:
  166. weight_path (str): the path of model weight file used to predict.
  167. input_path (str): the path of image file to be predicted.
  168. device (str, optional): the running device. Defaults to 'gpu'.
  169. save_dir (str, optional): the directory path to save predict output. Defaults to None.
  170. Returns:
  171. CompletedProcess: the result of predicting subprocess execution.
  172. """
  173. config = self.config.copy()
  174. cli_args = []
  175. weight_path = abspath(weight_path)
  176. cli_args.append(CLIArgument('--model_path', weight_path))
  177. input_path = abspath(input_path)
  178. cli_args.append(CLIArgument('--image_path', input_path))
  179. if device is not None:
  180. device_type, _ = self.runner.parse_device(device)
  181. cli_args.append(CLIArgument('--device', device_type))
  182. if save_dir is not None:
  183. save_dir = abspath(save_dir)
  184. else:
  185. # `save_dir` is None
  186. save_dir = abspath(os.path.join('output', 'predict'))
  187. cli_args.append(CLIArgument('--save_dir', save_dir))
  188. self._assert_empty_kwargs(kwargs)
  189. with self._create_new_config_file() as config_path:
  190. config.dump(config_path)
  191. return self.runner.predict(config_path, cli_args, device)
  192. def analyse(self,
  193. weight_path,
  194. ips=None,
  195. device='gpu',
  196. save_dir=None,
  197. **kwargs):
  198. """ analyse """
  199. config = self.config.copy()
  200. cli_args = []
  201. weight_path = abspath(weight_path)
  202. cli_args.append(CLIArgument('--model_path', weight_path))
  203. if device is not None:
  204. device_type, _ = self.runner.parse_device(device)
  205. cli_args.append(CLIArgument('--device', device_type))
  206. if save_dir is not None:
  207. save_dir = abspath(save_dir)
  208. else:
  209. # `save_dir` is None
  210. save_dir = abspath(os.path.join('output', 'analysis'))
  211. cli_args.append(CLIArgument('--save_dir', save_dir))
  212. self._assert_empty_kwargs(kwargs)
  213. with self._create_new_config_file() as config_path:
  214. config.dump(config_path)
  215. cp = self.runner.analyse(config_path, cli_args, device, ips)
  216. return cp
  217. def export(self, weight_path: str, save_dir: str,
  218. **kwargs) -> CompletedProcess:
  219. """export the dynamic model to static model
  220. Args:
  221. weight_path (str): the model weight file path that used to export.
  222. save_dir (str): the directory path to save export output.
  223. Returns:
  224. CompletedProcess: the result of exporting subprocess execution.
  225. """
  226. config = self.config.copy()
  227. cli_args = []
  228. weight_path = abspath(weight_path)
  229. cli_args.append(CLIArgument('--model_path', weight_path))
  230. if save_dir is not None:
  231. save_dir = abspath(save_dir)
  232. else:
  233. # `save_dir` is None
  234. save_dir = abspath(os.path.join('output', 'export'))
  235. cli_args.append(CLIArgument('--save_dir', save_dir))
  236. input_shape = kwargs.pop('input_shape', None)
  237. if input_shape is not None:
  238. cli_args.append(CLIArgument('--input_shape', *input_shape))
  239. output_op = kwargs.pop('output_op', None)
  240. if output_op is not None:
  241. assert output_op in ['softmax', 'argmax'
  242. ], "`output_op` must be 'softmax' or 'argmax'."
  243. cli_args.append(CLIArgument('--output_op', output_op))
  244. self._assert_empty_kwargs(kwargs)
  245. with self._create_new_config_file() as config_path:
  246. config.dump(config_path)
  247. return self.runner.export(config_path, cli_args, None)
  248. def infer(self,
  249. model_dir: str,
  250. input_path: str,
  251. device: str='gpu',
  252. save_dir: str=None,
  253. **kwargs) -> CompletedProcess:
  254. """predict image using infernece model
  255. Args:
  256. model_dir (str): the directory path of inference model files that would use to predict.
  257. input_path (str): the path of image that would be predict.
  258. device (str, optional): the running device. Defaults to 'gpu'.
  259. save_dir (str, optional): the directory path to save output. Defaults to None.
  260. Returns:
  261. CompletedProcess: the result of infering subprocess execution.
  262. """
  263. config = self.config.copy()
  264. cli_args = []
  265. model_dir = abspath(model_dir)
  266. input_path = abspath(input_path)
  267. cli_args.append(CLIArgument('--image_path', input_path))
  268. if device is not None:
  269. device_type, _ = self.runner.parse_device(device)
  270. cli_args.append(CLIArgument('--device', device_type))
  271. if save_dir is not None:
  272. save_dir = abspath(save_dir)
  273. else:
  274. # `save_dir` is None
  275. save_dir = abspath(os.path.join('output', 'infer'))
  276. cli_args.append(CLIArgument('--save_dir', save_dir))
  277. self._assert_empty_kwargs(kwargs)
  278. with self._create_new_config_file() as config_path:
  279. config.dump(config_path)
  280. deploy_config_path = os.path.join(model_dir, 'inference.yml')
  281. return self.runner.infer(deploy_config_path, cli_args, device)
  282. def compression(self,
  283. weight_path: str,
  284. batch_size: int=None,
  285. learning_rate: float=None,
  286. epochs_iters: int=None,
  287. device: str='gpu',
  288. use_vdl: bool=True,
  289. save_dir: str=None,
  290. **kwargs) -> CompletedProcess:
  291. """compression model
  292. Args:
  293. weight_path (str): the path to weight file of model.
  294. batch_size (int, optional): the batch size value of compression training. Defaults to None.
  295. learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
  296. epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
  297. device (str, optional): the device to run compression training. Defaults to 'gpu'.
  298. use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
  299. save_dir (str, optional): the directory to save output. Defaults to None.
  300. Returns:
  301. CompletedProcess: the result of compression subprocess execution.
  302. """
  303. # Update YAML config file
  304. # NOTE: In PaddleSeg, QAT does not use a different config file than regular training
  305. # Reusing `self.config` preserves the config items modified by the user when
  306. # `SegModel` is initialized with a `SegConfig` object.
  307. config = self.config.copy()
  308. train_cli_args = []
  309. export_cli_args = []
  310. weight_path = abspath(weight_path)
  311. train_cli_args.append(CLIArgument('--model_path', weight_path))
  312. if batch_size is not None:
  313. train_cli_args.append(CLIArgument('--batch_size', batch_size))
  314. if learning_rate is not None:
  315. train_cli_args.append(CLIArgument('--learning_rate', learning_rate))
  316. if epochs_iters is not None:
  317. train_cli_args.append(CLIArgument('--iters', epochs_iters))
  318. if device is not None:
  319. device_type, _ = self.runner.parse_device(device)
  320. train_cli_args.append(CLIArgument('--device', device_type))
  321. if use_vdl:
  322. train_cli_args.append(CLIArgument('--use_vdl'))
  323. if save_dir is not None:
  324. save_dir = abspath(save_dir)
  325. else:
  326. # `save_dir` is None
  327. save_dir = abspath(os.path.join('output', 'compress'))
  328. train_cli_args.append(CLIArgument('--save_dir', save_dir))
  329. # The exported model saved in a subdirectory named `export`
  330. export_cli_args.append(
  331. CLIArgument('--save_dir', os.path.join(save_dir, 'export')))
  332. input_shape = kwargs.pop('input_shape', None)
  333. if input_shape is not None:
  334. export_cli_args.append(CLIArgument('--input_shape', *input_shape))
  335. self._assert_empty_kwargs(kwargs)
  336. with self._create_new_config_file() as config_path:
  337. config.dump(config_path)
  338. return self.runner.compression(config_path, train_cli_args,
  339. export_cli_args, device, save_dir)