model.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from ...base import BaseModel
  16. from ...base.utils.arg import CLIArgument
  17. from ...base.utils.subprocess import CompletedProcess
  18. from ....utils.misc import abspath
  19. class SegModel(BaseModel):
  20. """ Semantic Segmentation Model """
  21. def train(self,
  22. batch_size: int=None,
  23. learning_rate: float=None,
  24. epochs_iters: int=None,
  25. ips: str=None,
  26. device: str='gpu',
  27. resume_path: str=None,
  28. dy2st: bool=False,
  29. amp: str='OFF',
  30. num_workers: int=None,
  31. use_vdl: bool=True,
  32. save_dir: str=None,
  33. **kwargs) -> CompletedProcess:
  34. """train self
  35. Args:
  36. batch_size (int, optional): the train batch size value. Defaults to None.
  37. learning_rate (float, optional): the train learning rate value. Defaults to None.
  38. epochs_iters (int, optional): the train epochs value. Defaults to None.
  39. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  40. device (str, optional): the running device. Defaults to 'gpu'.
  41. resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
  42. to None. Defaults to None.
  43. dy2st (bool, optional): Enable dynamic to static. Defaults to False.
  44. amp (str, optional): the amp settings. Defaults to 'OFF'.
  45. num_workers (int, optional): the workers number. Defaults to None.
  46. use_vdl (bool, optional): enable VisualDL. Defaults to True.
  47. save_dir (str, optional): the directory path to save train output. Defaults to None.
  48. Returns:
  49. CompletedProcess: the result of training subprocess execution.
  50. """
  51. config = self.config.copy()
  52. cli_args = []
  53. if batch_size is not None:
  54. cli_args.append(CLIArgument('--batch_size', batch_size))
  55. if learning_rate is not None:
  56. cli_args.append(CLIArgument('--learning_rate', learning_rate))
  57. if epochs_iters is not None:
  58. cli_args.append(CLIArgument('--iters', epochs_iters))
  59. # No need to handle `ips`
  60. if device is not None:
  61. device_type, _ = self.runner.parse_device(device)
  62. cli_args.append(CLIArgument('--device', device_type))
  63. # For compatibility
  64. resume_dir = kwargs.pop('resume_dir', None)
  65. if resume_path is None and resume_dir is not None:
  66. resume_path = os.path.join(resume_dir, 'model.pdparams')
  67. if resume_path is not None:
  68. # NOTE: We must use an absolute path here,
  69. # so we can run the scripts either inside or outside the repo dir.
  70. resume_path = abspath(resume_path)
  71. if os.path.basename(resume_path) != 'model.pdparams':
  72. raise ValueError(f"{resume_path} has an incorrect file name.")
  73. if not os.path.exists(resume_path):
  74. raise FileNotFoundError(f"{resume_path} does not exist.")
  75. resume_dir = os.path.dirname(resume_path)
  76. opts_path = os.path.join(resume_dir, 'model.pdopt')
  77. if not os.path.exists(opts_path):
  78. raise FileNotFoundError(f"{opts_path} must exist.")
  79. cli_args.append(CLIArgument('--resume_model', resume_dir))
  80. if dy2st:
  81. config.update_dy2st(dy2st)
  82. if use_vdl:
  83. cli_args.append(CLIArgument('--use_vdl'))
  84. if save_dir is not None:
  85. save_dir = abspath(save_dir)
  86. else:
  87. # `save_dir` is None
  88. save_dir = abspath(os.path.join('output', 'train'))
  89. cli_args.append(CLIArgument('--save_dir', save_dir))
  90. save_interval = kwargs.pop('save_interval', None)
  91. if save_interval is not None:
  92. cli_args.append(CLIArgument('--save_interval', save_interval))
  93. do_eval = kwargs.pop('do_eval', True)
  94. repeats = kwargs.pop('repeats', None)
  95. seed = kwargs.pop('seed', None)
  96. profile = kwargs.pop('profile', None)
  97. if profile is not None:
  98. cli_args.append(CLIArgument('--profiler_options', profile))
  99. log_iters = kwargs.pop('log_iters', None)
  100. if log_iters is not None:
  101. cli_args.append(CLIArgument('--log_iters', log_iters))
  102. # Benchmarking mode settings
  103. benchmark = kwargs.pop('benchmark', None)
  104. if benchmark is not None:
  105. envs = benchmark.get('env', None)
  106. seed = benchmark.get('seed', None)
  107. repeats = benchmark.get('repeats', None)
  108. do_eval = benchmark.get('do_eval', False)
  109. num_workers = benchmark.get('num_workers', None)
  110. config.update_log_ranks(device)
  111. amp = benchmark.get('amp', None)
  112. config.update_print_mem_info(benchmark.get('print_mem_info', True))
  113. if repeats is not None:
  114. assert isinstance(repeats, int), 'repeats must be an integer.'
  115. cli_args.append(CLIArgument('--repeats', repeats))
  116. if num_workers is not None:
  117. assert isinstance(num_workers,
  118. int), 'num_workers must be an integer.'
  119. cli_args.append(CLIArgument('--num_workers', num_workers))
  120. if seed is not None:
  121. assert isinstance(seed, int), 'seed must be an integer.'
  122. cli_args.append(CLIArgument('--seed', seed))
  123. if amp in ['O1', 'O2']:
  124. cli_args.append(CLIArgument('--precision', 'fp16'))
  125. cli_args.append(CLIArgument('--amp_level', amp))
  126. if envs is not None:
  127. for env_name, env_value in envs.items():
  128. os.environ[env_name] = str(env_value)
  129. else:
  130. if amp is not None:
  131. if amp != 'OFF':
  132. cli_args.append(CLIArgument('--precision', 'fp16'))
  133. cli_args.append(CLIArgument('--amp_level', amp))
  134. if num_workers is not None:
  135. cli_args.append(CLIArgument('--num_workers', num_workers))
  136. if repeats is not None:
  137. cli_args.append(CLIArgument('--repeats', repeats))
  138. if seed is not None:
  139. cli_args.append(CLIArgument('--seed', seed))
  140. self._assert_empty_kwargs(kwargs)
  141. with self._create_new_config_file() as config_path:
  142. config.dump(config_path)
  143. return self.runner.train(
  144. config_path, cli_args, device, ips, save_dir, do_eval=do_eval)
  145. def evaluate(self,
  146. weight_path: str,
  147. batch_size: int=None,
  148. ips: str=None,
  149. device: str='gpu',
  150. amp: str='OFF',
  151. num_workers: int=None,
  152. **kwargs) -> CompletedProcess:
  153. """evaluate self using specified weight
  154. Args:
  155. weight_path (str): the path of model weight file to be evaluated.
  156. batch_size (int, optional): the batch size value in evaluating. Defaults to None.
  157. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  158. device (str, optional): the running device. Defaults to 'gpu'.
  159. amp (str, optional): the AMP setting. Defaults to 'OFF'.
  160. num_workers (int, optional): the workers number in evaluating. Defaults to None.
  161. Returns:
  162. CompletedProcess: the result of evaluating subprocess execution.
  163. """
  164. config = self.config.copy()
  165. cli_args = []
  166. weight_path = abspath(weight_path)
  167. cli_args.append(CLIArgument('--model_path', weight_path))
  168. if batch_size is not None:
  169. if batch_size != 1:
  170. raise ValueError("Batch size other than 1 is not supported.")
  171. # No need to handle `ips`
  172. if device is not None:
  173. device_type, _ = self.runner.parse_device(device)
  174. cli_args.append(CLIArgument('--device', device_type))
  175. if amp is not None:
  176. if amp != 'OFF':
  177. cli_args.append(CLIArgument('--precision', 'fp16'))
  178. cli_args.append(CLIArgument('--amp_level', amp))
  179. if num_workers is not None:
  180. cli_args.append(CLIArgument('--num_workers', num_workers))
  181. self._assert_empty_kwargs(kwargs)
  182. with self._create_new_config_file() as config_path:
  183. config.dump(config_path)
  184. cp = self.runner.evaluate(config_path, cli_args, device, ips)
  185. return cp
  186. def predict(self,
  187. weight_path: str,
  188. input_path: str,
  189. device: str='gpu',
  190. save_dir: str=None,
  191. **kwargs) -> CompletedProcess:
  192. """predict using specified weight
  193. Args:
  194. weight_path (str): the path of model weight file used to predict.
  195. input_path (str): the path of image file to be predicted.
  196. device (str, optional): the running device. Defaults to 'gpu'.
  197. save_dir (str, optional): the directory path to save predict output. Defaults to None.
  198. Returns:
  199. CompletedProcess: the result of predicting subprocess execution.
  200. """
  201. config = self.config.copy()
  202. cli_args = []
  203. weight_path = abspath(weight_path)
  204. cli_args.append(CLIArgument('--model_path', weight_path))
  205. input_path = abspath(input_path)
  206. cli_args.append(CLIArgument('--image_path', input_path))
  207. if device is not None:
  208. device_type, _ = self.runner.parse_device(device)
  209. cli_args.append(CLIArgument('--device', device_type))
  210. if save_dir is not None:
  211. save_dir = abspath(save_dir)
  212. else:
  213. # `save_dir` is None
  214. save_dir = abspath(os.path.join('output', 'predict'))
  215. cli_args.append(CLIArgument('--save_dir', save_dir))
  216. self._assert_empty_kwargs(kwargs)
  217. with self._create_new_config_file() as config_path:
  218. config.dump(config_path)
  219. return self.runner.predict(config_path, cli_args, device)
  220. def analyse(self,
  221. weight_path,
  222. ips=None,
  223. device='gpu',
  224. save_dir=None,
  225. **kwargs):
  226. """ analyse """
  227. config = self.config.copy()
  228. cli_args = []
  229. weight_path = abspath(weight_path)
  230. cli_args.append(CLIArgument('--model_path', weight_path))
  231. if device is not None:
  232. device_type, _ = self.runner.parse_device(device)
  233. cli_args.append(CLIArgument('--device', device_type))
  234. if save_dir is not None:
  235. save_dir = abspath(save_dir)
  236. else:
  237. # `save_dir` is None
  238. save_dir = abspath(os.path.join('output', 'analysis'))
  239. cli_args.append(CLIArgument('--save_dir', save_dir))
  240. self._assert_empty_kwargs(kwargs)
  241. with self._create_new_config_file() as config_path:
  242. config.dump(config_path)
  243. cp = self.runner.analyse(config_path, cli_args, device, ips)
  244. return cp
  245. def export(self, weight_path: str, save_dir: str,
  246. **kwargs) -> CompletedProcess:
  247. """export the dynamic model to static model
  248. Args:
  249. weight_path (str): the model weight file path that used to export.
  250. save_dir (str): the directory path to save export output.
  251. Returns:
  252. CompletedProcess: the result of exporting subprocess execution.
  253. """
  254. config = self.config.copy()
  255. cli_args = []
  256. weight_path = abspath(weight_path)
  257. cli_args.append(CLIArgument('--model_path', weight_path))
  258. if save_dir is not None:
  259. save_dir = abspath(save_dir)
  260. else:
  261. # `save_dir` is None
  262. save_dir = abspath(os.path.join('output', 'export'))
  263. cli_args.append(CLIArgument('--save_dir', save_dir))
  264. input_shape = kwargs.pop('input_shape', None)
  265. if input_shape is not None:
  266. cli_args.append(CLIArgument('--input_shape', *input_shape))
  267. output_op = kwargs.pop('output_op', None)
  268. if output_op is not None:
  269. assert output_op in ['softmax', 'argmax'
  270. ], "`output_op` must be 'softmax' or 'argmax'."
  271. cli_args.append(CLIArgument('--output_op', output_op))
  272. self._assert_empty_kwargs(kwargs)
  273. with self._create_new_config_file() as config_path:
  274. config.dump(config_path)
  275. return self.runner.export(config_path, cli_args, None)
  276. def infer(self,
  277. model_dir: str,
  278. input_path: str,
  279. device: str='gpu',
  280. save_dir: str=None,
  281. **kwargs) -> CompletedProcess:
  282. """predict image using infernece model
  283. Args:
  284. model_dir (str): the directory path of inference model files that would use to predict.
  285. input_path (str): the path of image that would be predict.
  286. device (str, optional): the running device. Defaults to 'gpu'.
  287. save_dir (str, optional): the directory path to save output. Defaults to None.
  288. Returns:
  289. CompletedProcess: the result of infering subprocess execution.
  290. """
  291. config = self.config.copy()
  292. cli_args = []
  293. model_dir = abspath(model_dir)
  294. input_path = abspath(input_path)
  295. cli_args.append(CLIArgument('--image_path', input_path))
  296. if device is not None:
  297. device_type, _ = self.runner.parse_device(device)
  298. cli_args.append(CLIArgument('--device', device_type))
  299. if save_dir is not None:
  300. save_dir = abspath(save_dir)
  301. else:
  302. # `save_dir` is None
  303. save_dir = abspath(os.path.join('output', 'infer'))
  304. cli_args.append(CLIArgument('--save_dir', save_dir))
  305. self._assert_empty_kwargs(kwargs)
  306. with self._create_new_config_file() as config_path:
  307. config.dump(config_path)
  308. deploy_config_path = os.path.join(model_dir, 'inference.yml')
  309. return self.runner.infer(deploy_config_path, cli_args, device)
  310. def compression(self,
  311. weight_path: str,
  312. batch_size: int=None,
  313. learning_rate: float=None,
  314. epochs_iters: int=None,
  315. device: str='gpu',
  316. use_vdl: bool=True,
  317. save_dir: str=None,
  318. **kwargs) -> CompletedProcess:
  319. """compression model
  320. Args:
  321. weight_path (str): the path to weight file of model.
  322. batch_size (int, optional): the batch size value of compression training. Defaults to None.
  323. learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
  324. epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
  325. device (str, optional): the device to run compression training. Defaults to 'gpu'.
  326. use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
  327. save_dir (str, optional): the directory to save output. Defaults to None.
  328. Returns:
  329. CompletedProcess: the result of compression subprocess execution.
  330. """
  331. # Update YAML config file
  332. # NOTE: In PaddleSeg, QAT does not use a different config file than regular training
  333. # Reusing `self.config` preserves the config items modified by the user when
  334. # `SegModel` is initialized with a `SegConfig` object.
  335. config = self.config.copy()
  336. train_cli_args = []
  337. export_cli_args = []
  338. weight_path = abspath(weight_path)
  339. train_cli_args.append(CLIArgument('--model_path', weight_path))
  340. if batch_size is not None:
  341. train_cli_args.append(CLIArgument('--batch_size', batch_size))
  342. if learning_rate is not None:
  343. train_cli_args.append(CLIArgument('--learning_rate', learning_rate))
  344. if epochs_iters is not None:
  345. train_cli_args.append(CLIArgument('--iters', epochs_iters))
  346. if device is not None:
  347. device_type, _ = self.runner.parse_device(device)
  348. train_cli_args.append(CLIArgument('--device', device_type))
  349. if use_vdl:
  350. train_cli_args.append(CLIArgument('--use_vdl'))
  351. if save_dir is not None:
  352. save_dir = abspath(save_dir)
  353. else:
  354. # `save_dir` is None
  355. save_dir = abspath(os.path.join('output', 'compress'))
  356. train_cli_args.append(CLIArgument('--save_dir', save_dir))
  357. # The exported model saved in a subdirectory named `export`
  358. export_cli_args.append(
  359. CLIArgument('--save_dir', os.path.join(save_dir, 'export')))
  360. input_shape = kwargs.pop('input_shape', None)
  361. if input_shape is not None:
  362. export_cli_args.append(CLIArgument('--input_shape', *input_shape))
  363. self._assert_empty_kwargs(kwargs)
  364. with self._create_new_config_file() as config_path:
  365. config.dump(config_path)
  366. return self.runner.compression(config_path, train_cli_args,
  367. export_cli_args, device, save_dir)