# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from ...base import BaseModel
  16. from ...base.utils.arg import CLIArgument
  17. from ...base.utils.subprocess import CompletedProcess
  18. from ....utils.misc import abspath
  19. from ....utils.download import download
  20. from ....utils.cache import DEFAULT_CACHE_DIR
class SegModel(BaseModel):
    """Semantic Segmentation Model."""

    def train(
        self,
        batch_size: int = None,
        learning_rate: float = None,
        epochs_iters: int = None,
        ips: str = None,
        device: str = "gpu",
        resume_path: str = None,
        dy2st: bool = False,
        amp: str = "OFF",
        num_workers: int = None,
        use_vdl: bool = True,
        save_dir: str = None,
        **kwargs,
    ) -> CompletedProcess:
        """Train self.

        Args:
            batch_size (int, optional): the train batch size value. Defaults to None.
            learning_rate (float, optional): the train learning rate value. Defaults to None.
            epochs_iters (int, optional): the train epochs value. Defaults to None.
            ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
            device (str, optional): the running device. Defaults to 'gpu'.
            resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
                to None. Defaults to None.
            dy2st (bool, optional): Enable dynamic to static. Defaults to False.
            amp (str, optional): the amp settings. Defaults to 'OFF'.
            num_workers (int, optional): the workers number. Defaults to None.
            use_vdl (bool, optional): enable VisualDL. Defaults to True.
            save_dir (str, optional): the directory path to save train output. Defaults to None.

        Returns:
            CompletedProcess: the result of training subprocess execution.
        """
        config = self.config.copy()
        cli_args = []
        if batch_size is not None:
            cli_args.append(CLIArgument("--batch_size", batch_size))
        if learning_rate is not None:
            cli_args.append(CLIArgument("--learning_rate", learning_rate))
        if epochs_iters is not None:
            # The underlying trainer counts in iterations, hence `--iters`.
            cli_args.append(CLIArgument("--iters", epochs_iters))
        # No need to handle `ips`
        if device is not None:
            device_type, _ = self.runner.parse_device(device)
            cli_args.append(CLIArgument("--device", device_type))
        # For compatibility
        resume_dir = kwargs.pop("resume_dir", None)
        if resume_path is None and resume_dir is not None:
            resume_path = os.path.join(resume_dir, "model.pdparams")
        if resume_path is not None:
            # NOTE: We must use an absolute path here,
            # so we can run the scripts either inside or outside the repo dir.
            resume_path = abspath(resume_path)
            if os.path.basename(resume_path) != "model.pdparams":
                raise ValueError(f"{resume_path} has an incorrect file name.")
            if not os.path.exists(resume_path):
                raise FileNotFoundError(f"{resume_path} does not exist.")
            resume_dir = os.path.dirname(resume_path)
            # The optimizer state file must sit next to the weights;
            # resuming is impossible without it.
            opts_path = os.path.join(resume_dir, "model.pdopt")
            if not os.path.exists(opts_path):
                raise FileNotFoundError(f"{opts_path} must exist.")
            cli_args.append(CLIArgument("--resume_model", resume_dir))
        if dy2st:
            config.update_dy2st(dy2st)
        if use_vdl:
            cli_args.append(CLIArgument("--use_vdl"))
        if save_dir is not None:
            save_dir = abspath(save_dir)
        else:
            # `save_dir` is None
            save_dir = abspath(os.path.join("output", "train"))
        cli_args.append(CLIArgument("--save_dir", save_dir))
        save_interval = kwargs.pop("save_interval", None)
        if save_interval is not None:
            cli_args.append(CLIArgument("--save_interval", save_interval))
        do_eval = kwargs.pop("do_eval", True)
        repeats = kwargs.pop("repeats", None)
        seed = kwargs.pop("seed", None)
        profile = kwargs.pop("profile", None)
        if profile is not None:
            cli_args.append(CLIArgument("--profiler_options", profile))
        log_iters = kwargs.pop("log_iters", None)
        if log_iters is not None:
            cli_args.append(CLIArgument("--log_iters", log_iters))
        # Benchmarking mode settings
        # NOTE: when `benchmark` is given, its entries override the values of
        # `seed`, `repeats`, `do_eval`, `num_workers` and `amp` popped/passed above.
        benchmark = kwargs.pop("benchmark", None)
        if benchmark is not None:
            envs = benchmark.get("env", None)
            seed = benchmark.get("seed", None)
            repeats = benchmark.get("repeats", None)
            do_eval = benchmark.get("do_eval", False)
            num_workers = benchmark.get("num_workers", None)
            config.update_log_ranks(device)
            amp = benchmark.get("amp", None)
            config.update_print_mem_info(benchmark.get("print_mem_info", True))
            config.update_shuffle(benchmark.get("shuffle", False))
            if repeats is not None:
                assert isinstance(repeats, int), "repeats must be an integer."
                cli_args.append(CLIArgument("--repeats", repeats))
            if num_workers is not None:
                assert isinstance(num_workers, int), "num_workers must be an integer."
                cli_args.append(CLIArgument("--num_workers", num_workers))
            if seed is not None:
                assert isinstance(seed, int), "seed must be an integer."
                cli_args.append(CLIArgument("--seed", seed))
            if amp in ["O1", "O2"]:
                # Only O1/O2 turn on fp16; any other value leaves precision alone.
                cli_args.append(CLIArgument("--precision", "fp16"))
                cli_args.append(CLIArgument("--amp_level", amp))
            if envs is not None:
                # Export benchmark-specified environment variables for the subprocess.
                for env_name, env_value in envs.items():
                    os.environ[env_name] = str(env_value)
        else:
            if amp is not None:
                if amp != "OFF":
                    cli_args.append(CLIArgument("--precision", "fp16"))
                    cli_args.append(CLIArgument("--amp_level", amp))
            if num_workers is not None:
                cli_args.append(CLIArgument("--num_workers", num_workers))
            if repeats is not None:
                cli_args.append(CLIArgument("--repeats", repeats))
            if seed is not None:
                cli_args.append(CLIArgument("--seed", seed))
        # PDX related settings
        config.set_val("pdx_model_name", self.name)
        hpi_config_path = self.model_info.get("hpi_config_path", None)
        if hpi_config_path:
            # `hpi_config_path` exposes `.as_posix()` (presumably a
            # `pathlib.Path` — TODO confirm); the config stores a plain string.
            hpi_config_path = hpi_config_path.as_posix()
        config.set_val("hpi_config_path", hpi_config_path)
        # Any kwargs not popped above are unsupported — fail loudly.
        self._assert_empty_kwargs(kwargs)
        with self._create_new_config_file() as config_path:
            config.dump(config_path)
            return self.runner.train(
                config_path, cli_args, device, ips, save_dir, do_eval=do_eval
            )
  156. def evaluate(
  157. self,
  158. weight_path: str,
  159. batch_size: int = None,
  160. ips: str = None,
  161. device: str = "gpu",
  162. amp: str = "OFF",
  163. num_workers: int = None,
  164. **kwargs,
  165. ) -> CompletedProcess:
  166. """evaluate self using specified weight
  167. Args:
  168. weight_path (str): the path of model weight file to be evaluated.
  169. batch_size (int, optional): the batch size value in evaluating. Defaults to None.
  170. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  171. device (str, optional): the running device. Defaults to 'gpu'.
  172. amp (str, optional): the AMP setting. Defaults to 'OFF'.
  173. num_workers (int, optional): the workers number in evaluating. Defaults to None.
  174. Returns:
  175. CompletedProcess: the result of evaluating subprocess execution.
  176. """
  177. config = self.config.copy()
  178. cli_args = []
  179. weight_path = abspath(weight_path)
  180. cli_args.append(CLIArgument("--model_path", weight_path))
  181. if batch_size is not None:
  182. if batch_size != 1:
  183. raise ValueError("Batch size other than 1 is not supported.")
  184. # No need to handle `ips`
  185. if device is not None:
  186. device_type, _ = self.runner.parse_device(device)
  187. cli_args.append(CLIArgument("--device", device_type))
  188. if amp is not None:
  189. if amp != "OFF":
  190. cli_args.append(CLIArgument("--precision", "fp16"))
  191. cli_args.append(CLIArgument("--amp_level", amp))
  192. if num_workers is not None:
  193. cli_args.append(CLIArgument("--num_workers", num_workers))
  194. self._assert_empty_kwargs(kwargs)
  195. with self._create_new_config_file() as config_path:
  196. config.dump(config_path)
  197. cp = self.runner.evaluate(config_path, cli_args, device, ips)
  198. return cp
  199. def predict(
  200. self,
  201. weight_path: str,
  202. input_path: str,
  203. device: str = "gpu",
  204. save_dir: str = None,
  205. **kwargs,
  206. ) -> CompletedProcess:
  207. """predict using specified weight
  208. Args:
  209. weight_path (str): the path of model weight file used to predict.
  210. input_path (str): the path of image file to be predicted.
  211. device (str, optional): the running device. Defaults to 'gpu'.
  212. save_dir (str, optional): the directory path to save predict output. Defaults to None.
  213. Returns:
  214. CompletedProcess: the result of predicting subprocess execution.
  215. """
  216. config = self.config.copy()
  217. cli_args = []
  218. weight_path = abspath(weight_path)
  219. cli_args.append(CLIArgument("--model_path", weight_path))
  220. input_path = abspath(input_path)
  221. cli_args.append(CLIArgument("--image_path", input_path))
  222. if device is not None:
  223. device_type, _ = self.runner.parse_device(device)
  224. cli_args.append(CLIArgument("--device", device_type))
  225. if save_dir is not None:
  226. save_dir = abspath(save_dir)
  227. else:
  228. # `save_dir` is None
  229. save_dir = abspath(os.path.join("output", "predict"))
  230. cli_args.append(CLIArgument("--save_dir", save_dir))
  231. self._assert_empty_kwargs(kwargs)
  232. with self._create_new_config_file() as config_path:
  233. config.dump(config_path)
  234. return self.runner.predict(config_path, cli_args, device)
  235. def analyse(self, weight_path, ips=None, device="gpu", save_dir=None, **kwargs):
  236. """analyse"""
  237. config = self.config.copy()
  238. cli_args = []
  239. weight_path = abspath(weight_path)
  240. cli_args.append(CLIArgument("--model_path", weight_path))
  241. if device is not None:
  242. device_type, _ = self.runner.parse_device(device)
  243. cli_args.append(CLIArgument("--device", device_type))
  244. if save_dir is not None:
  245. save_dir = abspath(save_dir)
  246. else:
  247. # `save_dir` is None
  248. save_dir = abspath(os.path.join("output", "analysis"))
  249. cli_args.append(CLIArgument("--save_dir", save_dir))
  250. self._assert_empty_kwargs(kwargs)
  251. with self._create_new_config_file() as config_path:
  252. config.dump(config_path)
  253. cp = self.runner.analyse(config_path, cli_args, device, ips)
  254. return cp
  255. def export(self, weight_path: str, save_dir: str, **kwargs) -> CompletedProcess:
  256. """export the dynamic model to static model
  257. Args:
  258. weight_path (str): the model weight file path that used to export.
  259. save_dir (str): the directory path to save export output.
  260. Returns:
  261. CompletedProcess: the result of exporting subprocess execution.
  262. """
  263. config = self.config.copy()
  264. cli_args = []
  265. if not weight_path.startswith("http"):
  266. weight_path = abspath(weight_path)
  267. else:
  268. filename = os.path.basename(weight_path)
  269. save_path = os.path.join(DEFAULT_CACHE_DIR, filename)
  270. download(weight_path, save_path, print_progress=True, overwrite=True)
  271. weight_path = save_path
  272. cli_args.append(CLIArgument("--model_path", weight_path))
  273. if save_dir is not None:
  274. save_dir = abspath(save_dir)
  275. else:
  276. # `save_dir` is None
  277. save_dir = abspath(os.path.join("output", "export"))
  278. cli_args.append(CLIArgument("--save_dir", save_dir))
  279. input_shape = kwargs.pop("input_shape", None)
  280. if input_shape is not None:
  281. cli_args.append(CLIArgument("--input_shape", *input_shape))
  282. output_op = kwargs.pop("output_op", None)
  283. if output_op is not None:
  284. assert output_op in [
  285. "softmax",
  286. "argmax",
  287. ], "`output_op` must be 'softmax' or 'argmax'."
  288. cli_args.append(CLIArgument("--output_op", output_op))
  289. # PDX related settings
  290. config.set_val("pdx_model_name", self.name)
  291. hpi_config_path = self.model_info.get("hpi_config_path", None)
  292. if hpi_config_path:
  293. hpi_config_path = hpi_config_path.as_posix()
  294. config.set_val("hpi_config_path", hpi_config_path)
  295. self._assert_empty_kwargs(kwargs)
  296. with self._create_new_config_file() as config_path:
  297. config.dump(config_path)
  298. return self.runner.export(config_path, cli_args, None)
  299. def infer(
  300. self,
  301. model_dir: str,
  302. input_path: str,
  303. device: str = "gpu",
  304. save_dir: str = None,
  305. **kwargs,
  306. ) -> CompletedProcess:
  307. """predict image using infernece model
  308. Args:
  309. model_dir (str): the directory path of inference model files that would use to predict.
  310. input_path (str): the path of image that would be predict.
  311. device (str, optional): the running device. Defaults to 'gpu'.
  312. save_dir (str, optional): the directory path to save output. Defaults to None.
  313. Returns:
  314. CompletedProcess: the result of infering subprocess execution.
  315. """
  316. config = self.config.copy()
  317. cli_args = []
  318. model_dir = abspath(model_dir)
  319. input_path = abspath(input_path)
  320. cli_args.append(CLIArgument("--image_path", input_path))
  321. if device is not None:
  322. device_type, _ = self.runner.parse_device(device)
  323. cli_args.append(CLIArgument("--device", device_type))
  324. if save_dir is not None:
  325. save_dir = abspath(save_dir)
  326. else:
  327. # `save_dir` is None
  328. save_dir = abspath(os.path.join("output", "infer"))
  329. cli_args.append(CLIArgument("--save_dir", save_dir))
  330. self._assert_empty_kwargs(kwargs)
  331. with self._create_new_config_file() as config_path:
  332. config.dump(config_path)
  333. deploy_config_path = os.path.join(model_dir, "inference.yml")
  334. return self.runner.infer(deploy_config_path, cli_args, device)
  335. def compression(
  336. self,
  337. weight_path: str,
  338. batch_size: int = None,
  339. learning_rate: float = None,
  340. epochs_iters: int = None,
  341. device: str = "gpu",
  342. use_vdl: bool = True,
  343. save_dir: str = None,
  344. **kwargs,
  345. ) -> CompletedProcess:
  346. """compression model
  347. Args:
  348. weight_path (str): the path to weight file of model.
  349. batch_size (int, optional): the batch size value of compression training. Defaults to None.
  350. learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
  351. epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
  352. device (str, optional): the device to run compression training. Defaults to 'gpu'.
  353. use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
  354. save_dir (str, optional): the directory to save output. Defaults to None.
  355. Returns:
  356. CompletedProcess: the result of compression subprocess execution.
  357. """
  358. # Update YAML config file
  359. # NOTE: In PaddleSeg, QAT does not use a different config file than regular training
  360. # Reusing `self.config` preserves the config items modified by the user when
  361. # `SegModel` is initialized with a `SegConfig` object.
  362. config = self.config.copy()
  363. train_cli_args = []
  364. export_cli_args = []
  365. weight_path = abspath(weight_path)
  366. train_cli_args.append(CLIArgument("--model_path", weight_path))
  367. if batch_size is not None:
  368. train_cli_args.append(CLIArgument("--batch_size", batch_size))
  369. if learning_rate is not None:
  370. train_cli_args.append(CLIArgument("--learning_rate", learning_rate))
  371. if epochs_iters is not None:
  372. train_cli_args.append(CLIArgument("--iters", epochs_iters))
  373. if device is not None:
  374. device_type, _ = self.runner.parse_device(device)
  375. train_cli_args.append(CLIArgument("--device", device_type))
  376. if use_vdl:
  377. train_cli_args.append(CLIArgument("--use_vdl"))
  378. if save_dir is not None:
  379. save_dir = abspath(save_dir)
  380. else:
  381. # `save_dir` is None
  382. save_dir = abspath(os.path.join("output", "compress"))
  383. train_cli_args.append(CLIArgument("--save_dir", save_dir))
  384. # The exported model saved in a subdirectory named `export`
  385. export_cli_args.append(
  386. CLIArgument("--save_dir", os.path.join(save_dir, "export"))
  387. )
  388. input_shape = kwargs.pop("input_shape", None)
  389. if input_shape is not None:
  390. export_cli_args.append(CLIArgument("--input_shape", *input_shape))
  391. self._assert_empty_kwargs(kwargs)
  392. with self._create_new_config_file() as config_path:
  393. config.dump(config_path)
  394. return self.runner.compression(
  395. config_path, train_cli_args, export_cli_args, device, save_dir
  396. )