model.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from ...base import BaseModel
from ...base.utils.arg import CLIArgument
from ...base.utils.subprocess import CompletedProcess
from ....utils.device import parse_device
from ....utils.misc import abspath
from ....utils.errors import raise_unsupported_api_error


class TSModel(BaseModel):
    """TS Model"""

    def train(
        self,
        batch_size: int = None,
        learning_rate: float = None,
        epochs_iters: int = None,
        ips: str = None,
        device: str = "gpu",
        resume_path: str = None,
        dy2st: bool = False,
        amp: str = "OFF",
        num_workers: int = None,
        use_vdl: bool = False,
        save_dir: str = None,
        **kwargs,
    ) -> CompletedProcess:
        """train self

        Args:
            batch_size (int, optional): the train batch size value. Defaults to None.
            learning_rate (float, optional): the train learning rate value. Defaults to None.
            epochs_iters (int, optional): the train epochs value. Defaults to None.
            ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
            device (str, optional): the running device. Defaults to 'gpu'.
            resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
                to None. Defaults to None.
            dy2st (bool, optional): Enable dynamic to static. Defaults to False.
            amp (str, optional): the amp settings. Defaults to 'OFF'.
            num_workers (int, optional): the workers number. Defaults to None.
            use_vdl (bool, optional): enable VisualDL. Defaults to False.
            save_dir (str, optional): the directory path to save train output. Defaults to None.

        Returns:
            CompletedProcess: the result of training subprocess execution.
        """
        config = self.config.copy()
        cli_args = []

        if batch_size is not None:
            cli_args.append(CLIArgument("--batch_size", batch_size))

        if learning_rate is not None:
            cli_args.append(CLIArgument("--learning_rate", learning_rate))

        if epochs_iters is not None:
            cli_args.append(CLIArgument("--epoch", epochs_iters))

        if resume_path:
            raise ValueError("`resume_path` is not supported.")

        # No need to handle `ips`

        if amp is not None and amp != "OFF":
            raise ValueError(f"`amp`={amp} is not supported.")

        if dy2st:
            raise ValueError(f"`dy2st`={dy2st} is not supported.")

        if use_vdl:
            raise ValueError(f"`use_vdl`={use_vdl} is not supported.")

        if device is not None:
            device_type, _ = parse_device(device)
            cli_args.append(CLIArgument("--device", device_type))

        if save_dir is not None:
            save_dir = abspath(save_dir)
        else:
            # `save_dir` is None
            save_dir = abspath(os.path.join("output", "train"))
        cli_args.append(CLIArgument("--save_dir", save_dir))

        # Benchmarking mode settings
        benchmark = kwargs.pop("benchmark", None)
        if benchmark is not None:
            envs = benchmark.get("env", None)
            num_workers = benchmark.get("num_workers", None)
            config.update_log_ranks(device)
            config.update_print_mem_info(benchmark.get("print_mem_info", True))
            if num_workers is not None:
                assert isinstance(num_workers, int), "num_workers must be an integer"
                cli_args.append(CLIArgument("--num_workers", num_workers))
            if envs is not None:
                for env_name, env_value in envs.items():
                    os.environ[env_name] = str(env_value)
        else:
            if num_workers is not None:
                cli_args.append(CLIArgument("--num_workers", num_workers))

        # PDX related settings
        if device_type in ["npu", "xpu", "mlu"]:
            uniform_output_enabled = False
        else:
            uniform_output_enabled = True
        config.update({"uniform_output_enabled": uniform_output_enabled})
        config.update({"pdx_model_name": self.name})

        self._assert_empty_kwargs(kwargs)

        with self._create_new_config_file() as config_path:
            config.dump(config_path)
            return self.runner.train(config_path, cli_args, device, ips, save_dir)
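
    # Illustrative usage (not part of the original file): assuming `model` is a
    # TSModel instance obtained through the PaddleX model registry, a single-GPU
    # training run could look like the sketch below. All argument values are
    # hypothetical.
    #
    #     model.train(
    #         batch_size=16,
    #         learning_rate=1e-3,
    #         epochs_iters=10,
    #         device="gpu:0",
    #         save_dir="output/train",
    #     )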

    def evaluate(
        self,
        weight_path: str,
        batch_size: int = None,
        ips: str = None,
        device: str = "gpu",
        amp: str = "OFF",
        num_workers: int = None,
        **kwargs,
    ) -> CompletedProcess:
        """evaluate self using specified weight

        Args:
            weight_path (str): the path of model weight file to be evaluated.
            batch_size (int, optional): the batch size value in evaluating. Defaults to None.
            ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
            device (str, optional): the running device. Defaults to 'gpu'.
            amp (str, optional): the AMP setting. Defaults to 'OFF'.
            num_workers (int, optional): the workers number in evaluating. Defaults to None.

        Returns:
            CompletedProcess: the result of evaluating subprocess execution.
        """
        config = self.config.copy()
        cli_args = []

        weight_path = abspath(weight_path)
        cli_args.append(CLIArgument("--checkpoints", weight_path))

        if batch_size is not None:
            if batch_size != 1:
                raise ValueError("Batch size other than 1 is not supported.")

        # No need to handle `ips`

        if device is not None:
            device_type, _ = parse_device(device)
            cli_args.append(CLIArgument("--device", device_type))

        if amp is not None:
            if amp != "OFF":
                raise ValueError(f"`amp`={amp} is not supported.")

        if num_workers is not None:
            cli_args.append(CLIArgument("--num_workers", num_workers))

        self._assert_empty_kwargs(kwargs)

        with self._create_new_config_file() as config_path:
            config.dump(config_path)
            cp = self.runner.evaluate(config_path, cli_args, device, ips)
            return cp
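
    # Illustrative usage (not part of the original file): evaluating trained
    # weights; note that only `batch_size=1` (or leaving it unset) is accepted.
    # The weight path below is hypothetical.
    #
    #     model.evaluate(
    #         weight_path="output/train/best_model/model.pdparams",
    #         device="gpu:0",
    #     )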

    def predict(
        self,
        weight_path: str,
        input_path: str,
        device: str = "gpu",
        save_dir: str = None,
        **kwargs,
    ) -> CompletedProcess:
        """predict using specified weight

        Args:
            weight_path (str): the path of model weight file used to predict.
            input_path (str): the path of the CSV file to be predicted.
            device (str, optional): the running device. Defaults to 'gpu'.
            save_dir (str, optional): the directory path to save predict output. Defaults to None.

        Returns:
            CompletedProcess: the result of predicting subprocess execution.
        """
        config = self.config.copy()
        cli_args = []

        weight_path = abspath(weight_path)
        cli_args.append(CLIArgument("--checkpoints", weight_path))

        input_path = abspath(input_path)
        cli_args.append(CLIArgument("--csv_path", input_path))

        if device is not None:
            device_type, _ = parse_device(device)
            cli_args.append(CLIArgument("--device", device_type))

        if save_dir is not None:
            save_dir = abspath(save_dir)
        else:
            # `save_dir` is None
            save_dir = abspath(os.path.join("output", "predict"))
        cli_args.append(CLIArgument("--save_dir", save_dir))

        self._assert_empty_kwargs(kwargs)

        with self._create_new_config_file() as config_path:
            config.dump(config_path)
            return self.runner.predict(config_path, cli_args, device)
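
    # Illustrative usage (not part of the original file): since this is a time
    # series model, the input is a CSV file (forwarded to the runner as
    # `--csv_path`). The paths below are hypothetical.
    #
    #     model.predict(
    #         weight_path="output/train/best_model/model.pdparams",
    #         input_path="data/test.csv",
    #         device="gpu:0",
    #         save_dir="output/predict",
    #     )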

    def export(
        self, weight_path: str, save_dir: str = None, device: str = "gpu", **kwargs
    ):
        """export"""
        cli_args = []

        weight_path = abspath(weight_path)
        cli_args.append(CLIArgument("--checkpoints", weight_path))

        if save_dir is not None:
            save_dir = abspath(save_dir)
        else:
            save_dir = abspath(os.path.join("output", "inference"))
        cli_args.append(CLIArgument("--save_dir", save_dir))

        if device is not None:
            device_type, _ = parse_device(device)
            cli_args.append(CLIArgument("--device", device_type))

        self._assert_empty_kwargs(kwargs)

        with self._create_new_config_file() as config_path:
            # Update YAML config file
            config = self.config.copy()
            config.update_pretrained_weights(weight_path)
            config.update({"pdx_model_name": self.name})
            config.dump(config_path)
            return self.runner.export(config_path, cli_args, device)
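
    # Illustrative usage (not part of the original file): exporting trained
    # weights to an inference model directory. The paths are hypothetical.
    #
    #     model.export(
    #         weight_path="output/train/best_model/model.pdparams",
    #         save_dir="output/inference",
    #     )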

    def infer(
        self,
        model_dir: str,
        input_path: str,
        device: str = "gpu",
        save_dir: str = None,
        **kwargs,
    ):
        """infer"""
        raise_unsupported_api_error("infer", self.__class__)

    def compression(
        self,
        weight_path: str,
        batch_size=None,
        learning_rate=None,
        epochs_iters=None,
        device: str = "gpu",
        use_vdl=True,
        save_dir=None,
        **kwargs,
    ):
        """compression"""
        raise_unsupported_api_error("compression", self.__class__)