# model.py
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from ...base import BaseModel
  16. from ...base.utils.arg import CLIArgument
  17. from ...base.utils.subprocess import CompletedProcess
  18. from ....utils.misc import abspath
  19. class SegModel(BaseModel):
  20. """ Semantic Segmentation Model """
  21. def train(self,
  22. batch_size: int=None,
  23. learning_rate: float=None,
  24. epochs_iters: int=None,
  25. ips: str=None,
  26. device: str='gpu',
  27. resume_path: str=None,
  28. dy2st: bool=False,
  29. amp: str='OFF',
  30. num_workers: int=None,
  31. use_vdl: bool=True,
  32. save_dir: str=None,
  33. **kwargs) -> CompletedProcess:
  34. """train self
  35. Args:
  36. batch_size (int, optional): the train batch size value. Defaults to None.
  37. learning_rate (float, optional): the train learning rate value. Defaults to None.
  38. epochs_iters (int, optional): the train epochs value. Defaults to None.
  39. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  40. device (str, optional): the running device. Defaults to 'gpu'.
  41. resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
  42. to None. Defaults to None.
  43. dy2st (bool, optional): Enable dynamic to static. Defaults to False.
  44. amp (str, optional): the amp settings. Defaults to 'OFF'.
  45. num_workers (int, optional): the workers number. Defaults to None.
  46. use_vdl (bool, optional): enable VisualDL. Defaults to True.
  47. save_dir (str, optional): the directory path to save train output. Defaults to None.
  48. Returns:
  49. CompletedProcess: the result of training subprocess execution.
  50. """
  51. config = self.config.copy()
  52. cli_args = []
  53. if batch_size is not None:
  54. cli_args.append(CLIArgument('--batch_size', batch_size))
  55. if learning_rate is not None:
  56. cli_args.append(CLIArgument('--learning_rate', learning_rate))
  57. if epochs_iters is not None:
  58. cli_args.append(CLIArgument('--iters', epochs_iters))
  59. # No need to handle `ips`
  60. if device is not None:
  61. device_type, _ = self.runner.parse_device(device)
  62. cli_args.append(CLIArgument('--device', device_type))
  63. # For compatibility
  64. resume_dir = kwargs.pop('resume_dir', None)
  65. if resume_path is None and resume_dir is not None:
  66. resume_path = os.path.join(resume_dir, 'model.pdparams')
  67. if resume_path is not None:
  68. # NOTE: We must use an absolute path here,
  69. # so we can run the scripts either inside or outside the repo dir.
  70. resume_path = abspath(resume_path)
  71. if os.path.basename(resume_path) != 'model.pdparams':
  72. raise ValueError(f"{resume_path} has an incorrect file name.")
  73. if not os.path.exists(resume_path):
  74. raise FileNotFoundError(f"{resume_path} does not exist.")
  75. resume_dir = os.path.dirname(resume_path)
  76. opts_path = os.path.join(resume_dir, 'model.pdopt')
  77. if not os.path.exists(opts_path):
  78. raise FileNotFoundError(f"{opts_path} must exist.")
  79. cli_args.append(CLIArgument('--resume_model', resume_dir))
  80. if dy2st:
  81. config.update_dy2st(dy2st)
  82. if use_vdl:
  83. cli_args.append(CLIArgument('--use_vdl'))
  84. if save_dir is not None:
  85. save_dir = abspath(save_dir)
  86. else:
  87. # `save_dir` is None
  88. save_dir = abspath(os.path.join('output', 'train'))
  89. cli_args.append(CLIArgument('--save_dir', save_dir))
  90. save_interval = kwargs.pop('save_interval', None)
  91. if save_interval is not None:
  92. cli_args.append(CLIArgument('--save_interval', save_interval))
  93. do_eval = kwargs.pop('do_eval', True)
  94. repeats = kwargs.pop('repeats', None)
  95. seed = kwargs.pop('seed', None)
  96. profile = kwargs.pop('profile', None)
  97. if profile is not None:
  98. cli_args.append(CLIArgument('--profiler_options', profile))
  99. log_iters = kwargs.pop('log_iters', None)
  100. if log_iters is not None:
  101. cli_args.append(CLIArgument('--log_iters', log_iters))
  102. # Benchmarking mode settings
  103. benchmark = kwargs.pop('benchmark', None)
  104. if benchmark is not None:
  105. envs = benchmark.get('env', None)
  106. seed = benchmark.get('seed', None)
  107. repeats = benchmark.get('repeats', None)
  108. do_eval = benchmark.get('do_eval', False)
  109. num_workers = benchmark.get('num_workers', None)
  110. config.update_log_ranks(device)
  111. amp = benchmark.get('amp', None)
  112. config.update_print_mem_info(benchmark.get('print_mem_info', True))
  113. config.update_shuffle(benchmark.get('shuffle', False))
  114. if repeats is not None:
  115. assert isinstance(repeats, int), 'repeats must be an integer.'
  116. cli_args.append(CLIArgument('--repeats', repeats))
  117. if num_workers is not None:
  118. assert isinstance(num_workers,
  119. int), 'num_workers must be an integer.'
  120. cli_args.append(CLIArgument('--num_workers', num_workers))
  121. if seed is not None:
  122. assert isinstance(seed, int), 'seed must be an integer.'
  123. cli_args.append(CLIArgument('--seed', seed))
  124. if amp in ['O1', 'O2']:
  125. cli_args.append(CLIArgument('--precision', 'fp16'))
  126. cli_args.append(CLIArgument('--amp_level', amp))
  127. if envs is not None:
  128. for env_name, env_value in envs.items():
  129. os.environ[env_name] = str(env_value)
  130. else:
  131. if amp is not None:
  132. if amp != 'OFF':
  133. cli_args.append(CLIArgument('--precision', 'fp16'))
  134. cli_args.append(CLIArgument('--amp_level', amp))
  135. if num_workers is not None:
  136. cli_args.append(CLIArgument('--num_workers', num_workers))
  137. if repeats is not None:
  138. cli_args.append(CLIArgument('--repeats', repeats))
  139. if seed is not None:
  140. cli_args.append(CLIArgument('--seed', seed))
  141. self._assert_empty_kwargs(kwargs)
  142. with self._create_new_config_file() as config_path:
  143. config.dump(config_path)
  144. return self.runner.train(
  145. config_path, cli_args, device, ips, save_dir, do_eval=do_eval)
  146. def evaluate(self,
  147. weight_path: str,
  148. batch_size: int=None,
  149. ips: str=None,
  150. device: str='gpu',
  151. amp: str='OFF',
  152. num_workers: int=None,
  153. **kwargs) -> CompletedProcess:
  154. """evaluate self using specified weight
  155. Args:
  156. weight_path (str): the path of model weight file to be evaluated.
  157. batch_size (int, optional): the batch size value in evaluating. Defaults to None.
  158. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  159. device (str, optional): the running device. Defaults to 'gpu'.
  160. amp (str, optional): the AMP setting. Defaults to 'OFF'.
  161. num_workers (int, optional): the workers number in evaluating. Defaults to None.
  162. Returns:
  163. CompletedProcess: the result of evaluating subprocess execution.
  164. """
  165. config = self.config.copy()
  166. cli_args = []
  167. weight_path = abspath(weight_path)
  168. cli_args.append(CLIArgument('--model_path', weight_path))
  169. if batch_size is not None:
  170. if batch_size != 1:
  171. raise ValueError("Batch size other than 1 is not supported.")
  172. # No need to handle `ips`
  173. if device is not None:
  174. device_type, _ = self.runner.parse_device(device)
  175. cli_args.append(CLIArgument('--device', device_type))
  176. if amp is not None:
  177. if amp != 'OFF':
  178. cli_args.append(CLIArgument('--precision', 'fp16'))
  179. cli_args.append(CLIArgument('--amp_level', amp))
  180. if num_workers is not None:
  181. cli_args.append(CLIArgument('--num_workers', num_workers))
  182. self._assert_empty_kwargs(kwargs)
  183. with self._create_new_config_file() as config_path:
  184. config.dump(config_path)
  185. cp = self.runner.evaluate(config_path, cli_args, device, ips)
  186. return cp
  187. def predict(self,
  188. weight_path: str,
  189. input_path: str,
  190. device: str='gpu',
  191. save_dir: str=None,
  192. **kwargs) -> CompletedProcess:
  193. """predict using specified weight
  194. Args:
  195. weight_path (str): the path of model weight file used to predict.
  196. input_path (str): the path of image file to be predicted.
  197. device (str, optional): the running device. Defaults to 'gpu'.
  198. save_dir (str, optional): the directory path to save predict output. Defaults to None.
  199. Returns:
  200. CompletedProcess: the result of predicting subprocess execution.
  201. """
  202. config = self.config.copy()
  203. cli_args = []
  204. weight_path = abspath(weight_path)
  205. cli_args.append(CLIArgument('--model_path', weight_path))
  206. input_path = abspath(input_path)
  207. cli_args.append(CLIArgument('--image_path', input_path))
  208. if device is not None:
  209. device_type, _ = self.runner.parse_device(device)
  210. cli_args.append(CLIArgument('--device', device_type))
  211. if save_dir is not None:
  212. save_dir = abspath(save_dir)
  213. else:
  214. # `save_dir` is None
  215. save_dir = abspath(os.path.join('output', 'predict'))
  216. cli_args.append(CLIArgument('--save_dir', save_dir))
  217. self._assert_empty_kwargs(kwargs)
  218. with self._create_new_config_file() as config_path:
  219. config.dump(config_path)
  220. return self.runner.predict(config_path, cli_args, device)
  221. def analyse(self,
  222. weight_path,
  223. ips=None,
  224. device='gpu',
  225. save_dir=None,
  226. **kwargs):
  227. """ analyse """
  228. config = self.config.copy()
  229. cli_args = []
  230. weight_path = abspath(weight_path)
  231. cli_args.append(CLIArgument('--model_path', weight_path))
  232. if device is not None:
  233. device_type, _ = self.runner.parse_device(device)
  234. cli_args.append(CLIArgument('--device', device_type))
  235. if save_dir is not None:
  236. save_dir = abspath(save_dir)
  237. else:
  238. # `save_dir` is None
  239. save_dir = abspath(os.path.join('output', 'analysis'))
  240. cli_args.append(CLIArgument('--save_dir', save_dir))
  241. self._assert_empty_kwargs(kwargs)
  242. with self._create_new_config_file() as config_path:
  243. config.dump(config_path)
  244. cp = self.runner.analyse(config_path, cli_args, device, ips)
  245. return cp
  246. def export(self, weight_path: str, save_dir: str,
  247. **kwargs) -> CompletedProcess:
  248. """export the dynamic model to static model
  249. Args:
  250. weight_path (str): the model weight file path that used to export.
  251. save_dir (str): the directory path to save export output.
  252. Returns:
  253. CompletedProcess: the result of exporting subprocess execution.
  254. """
  255. config = self.config.copy()
  256. cli_args = []
  257. weight_path = abspath(weight_path)
  258. cli_args.append(CLIArgument('--model_path', weight_path))
  259. if save_dir is not None:
  260. save_dir = abspath(save_dir)
  261. else:
  262. # `save_dir` is None
  263. save_dir = abspath(os.path.join('output', 'export'))
  264. cli_args.append(CLIArgument('--save_dir', save_dir))
  265. input_shape = kwargs.pop('input_shape', None)
  266. if input_shape is not None:
  267. cli_args.append(CLIArgument('--input_shape', *input_shape))
  268. output_op = kwargs.pop('output_op', None)
  269. if output_op is not None:
  270. assert output_op in ['softmax', 'argmax'
  271. ], "`output_op` must be 'softmax' or 'argmax'."
  272. cli_args.append(CLIArgument('--output_op', output_op))
  273. self._assert_empty_kwargs(kwargs)
  274. with self._create_new_config_file() as config_path:
  275. config.dump(config_path)
  276. return self.runner.export(config_path, cli_args, None)
  277. def infer(self,
  278. model_dir: str,
  279. input_path: str,
  280. device: str='gpu',
  281. save_dir: str=None,
  282. **kwargs) -> CompletedProcess:
  283. """predict image using infernece model
  284. Args:
  285. model_dir (str): the directory path of inference model files that would use to predict.
  286. input_path (str): the path of image that would be predict.
  287. device (str, optional): the running device. Defaults to 'gpu'.
  288. save_dir (str, optional): the directory path to save output. Defaults to None.
  289. Returns:
  290. CompletedProcess: the result of infering subprocess execution.
  291. """
  292. config = self.config.copy()
  293. cli_args = []
  294. model_dir = abspath(model_dir)
  295. input_path = abspath(input_path)
  296. cli_args.append(CLIArgument('--image_path', input_path))
  297. if device is not None:
  298. device_type, _ = self.runner.parse_device(device)
  299. cli_args.append(CLIArgument('--device', device_type))
  300. if save_dir is not None:
  301. save_dir = abspath(save_dir)
  302. else:
  303. # `save_dir` is None
  304. save_dir = abspath(os.path.join('output', 'infer'))
  305. cli_args.append(CLIArgument('--save_dir', save_dir))
  306. self._assert_empty_kwargs(kwargs)
  307. with self._create_new_config_file() as config_path:
  308. config.dump(config_path)
  309. deploy_config_path = os.path.join(model_dir, 'inference.yml')
  310. return self.runner.infer(deploy_config_path, cli_args, device)
  311. def compression(self,
  312. weight_path: str,
  313. batch_size: int=None,
  314. learning_rate: float=None,
  315. epochs_iters: int=None,
  316. device: str='gpu',
  317. use_vdl: bool=True,
  318. save_dir: str=None,
  319. **kwargs) -> CompletedProcess:
  320. """compression model
  321. Args:
  322. weight_path (str): the path to weight file of model.
  323. batch_size (int, optional): the batch size value of compression training. Defaults to None.
  324. learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
  325. epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
  326. device (str, optional): the device to run compression training. Defaults to 'gpu'.
  327. use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
  328. save_dir (str, optional): the directory to save output. Defaults to None.
  329. Returns:
  330. CompletedProcess: the result of compression subprocess execution.
  331. """
  332. # Update YAML config file
  333. # NOTE: In PaddleSeg, QAT does not use a different config file than regular training
  334. # Reusing `self.config` preserves the config items modified by the user when
  335. # `SegModel` is initialized with a `SegConfig` object.
  336. config = self.config.copy()
  337. train_cli_args = []
  338. export_cli_args = []
  339. weight_path = abspath(weight_path)
  340. train_cli_args.append(CLIArgument('--model_path', weight_path))
  341. if batch_size is not None:
  342. train_cli_args.append(CLIArgument('--batch_size', batch_size))
  343. if learning_rate is not None:
  344. train_cli_args.append(CLIArgument('--learning_rate', learning_rate))
  345. if epochs_iters is not None:
  346. train_cli_args.append(CLIArgument('--iters', epochs_iters))
  347. if device is not None:
  348. device_type, _ = self.runner.parse_device(device)
  349. train_cli_args.append(CLIArgument('--device', device_type))
  350. if use_vdl:
  351. train_cli_args.append(CLIArgument('--use_vdl'))
  352. if save_dir is not None:
  353. save_dir = abspath(save_dir)
  354. else:
  355. # `save_dir` is None
  356. save_dir = abspath(os.path.join('output', 'compress'))
  357. train_cli_args.append(CLIArgument('--save_dir', save_dir))
  358. # The exported model saved in a subdirectory named `export`
  359. export_cli_args.append(
  360. CLIArgument('--save_dir', os.path.join(save_dir, 'export')))
  361. input_shape = kwargs.pop('input_shape', None)
  362. if input_shape is not None:
  363. export_cli_args.append(CLIArgument('--input_shape', *input_shape))
  364. self._assert_empty_kwargs(kwargs)
  365. with self._create_new_config_file() as config_path:
  366. config.dump(config_path)
  367. return self.runner.compression(config_path, train_cli_args,
  368. export_cli_args, device, save_dir)