model.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from ...base import BaseModel
  16. from ...base.utils.arg import CLIArgument
  17. from ...base.utils.subprocess import CompletedProcess
  18. from ....utils.misc import abspath
  19. from ....utils.errors import raise_unsupported_api_error
  20. class TSModel(BaseModel):
  21. """TS Model"""
  22. def train(
  23. self,
  24. batch_size: int = None,
  25. learning_rate: float = None,
  26. epochs_iters: int = None,
  27. ips: str = None,
  28. device: str = "gpu",
  29. resume_path: str = None,
  30. dy2st: bool = False,
  31. amp: str = "OFF",
  32. num_workers: int = None,
  33. use_vdl: bool = False,
  34. save_dir: str = None,
  35. **kwargs,
  36. ) -> CompletedProcess:
  37. """train self
  38. Args:
  39. batch_size (int, optional): the train batch size value. Defaults to None.
  40. learning_rate (float, optional): the train learning rate value. Defaults to None.
  41. epochs_iters (int, optional): the train epochs value. Defaults to None.
  42. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  43. device (str, optional): the running device. Defaults to 'gpu'.
  44. resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
  45. to None. Defaults to None.
  46. dy2st (bool, optional): Enable dynamic to static. Defaults to False.
  47. amp (str, optional): the amp settings. Defaults to 'OFF'.
  48. num_workers (int, optional): the workers number. Defaults to None.
  49. use_vdl (bool, optional): enable VisualDL. Defaults to False.
  50. save_dir (str, optional): the directory path to save train output. Defaults to None.
  51. Returns:
  52. CompletedProcess: the result of training subprocess execution.
  53. """
  54. config = self.config.copy()
  55. cli_args = []
  56. if batch_size is not None:
  57. cli_args.append(CLIArgument("--batch_size", batch_size))
  58. if learning_rate is not None:
  59. cli_args.append(CLIArgument("--learning_rate", learning_rate))
  60. if epochs_iters is not None:
  61. cli_args.append(CLIArgument("--epoch", epochs_iters))
  62. if resume_path:
  63. raise ValueError("`resume_path` is not supported.")
  64. # No need to handle `ips`
  65. if amp is not None and amp != "OFF":
  66. raise ValueError(f"`amp`={amp} is not supported.")
  67. if dy2st:
  68. raise ValueError(f"`dy2st`={dy2st} is not supported.")
  69. if use_vdl:
  70. raise ValueError(f"`use_vdl`={use_vdl} is not supported.")
  71. if device is not None:
  72. device_type, _ = self.runner.parse_device(device)
  73. cli_args.append(CLIArgument("--device", device_type))
  74. if save_dir is not None:
  75. save_dir = abspath(save_dir)
  76. else:
  77. # `save_dir` is None
  78. save_dir = abspath(os.path.join("output", "train"))
  79. cli_args.append(CLIArgument("--save_dir", save_dir))
  80. # Benchmarking mode settings
  81. benchmark = kwargs.pop("benchmark", None)
  82. if benchmark is not None:
  83. envs = benchmark.get("env", None)
  84. num_workers = benchmark.get("num_workers", None)
  85. config.update_log_ranks(device)
  86. config.update_print_mem_info(benchmark.get("print_mem_info", True))
  87. if num_workers is not None:
  88. assert isinstance(num_workers, int), "num_workers must be an integer"
  89. cli_args.append(CLIArgument("--num_workers", num_workers))
  90. if envs is not None:
  91. for env_name, env_value in envs.items():
  92. os.environ[env_name] = str(env_value)
  93. else:
  94. if num_workers is not None:
  95. cli_args.append(CLIArgument("--num_workers", num_workers))
  96. config.update({"uniform_output_enabled": True})
  97. config.update({"pdx_model_name": self.name})
  98. self._assert_empty_kwargs(kwargs)
  99. with self._create_new_config_file() as config_path:
  100. config.dump(config_path)
  101. return self.runner.train(config_path, cli_args, device, ips, save_dir)
  102. def evaluate(
  103. self,
  104. weight_path: str,
  105. batch_size: int = None,
  106. ips: str = None,
  107. device: str = "gpu",
  108. amp: str = "OFF",
  109. num_workers: int = None,
  110. **kwargs,
  111. ) -> CompletedProcess:
  112. """evaluate self using specified weight
  113. Args:
  114. weight_path (str): the path of model weight file to be evaluated.
  115. batch_size (int, optional): the batch size value in evaluating. Defaults to None.
  116. ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
  117. device (str, optional): the running device. Defaults to 'gpu'.
  118. amp (str, optional): the AMP setting. Defaults to 'OFF'.
  119. num_workers (int, optional): the workers number in evaluating. Defaults to None.
  120. Returns:
  121. CompletedProcess: the result of evaluating subprocess execution.
  122. """
  123. config = self.config.copy()
  124. cli_args = []
  125. weight_path = abspath(weight_path)
  126. cli_args.append(CLIArgument("--checkpoints", weight_path))
  127. if batch_size is not None:
  128. if batch_size != 1:
  129. raise ValueError("Batch size other than 1 is not supported.")
  130. # No need to handle `ips`
  131. if device is not None:
  132. device_type, _ = self.runner.parse_device(device)
  133. cli_args.append(CLIArgument("--device", device_type))
  134. if amp is not None:
  135. if amp != "OFF":
  136. raise ValueError(f"`amp`={amp} is not supported.")
  137. if num_workers is not None:
  138. cli_args.append(CLIArgument("--num_workers", num_workers))
  139. self._assert_empty_kwargs(kwargs)
  140. with self._create_new_config_file() as config_path:
  141. config.dump(config_path)
  142. cp = self.runner.evaluate(config_path, cli_args, device, ips)
  143. return cp
  144. def predict(
  145. self,
  146. weight_path: str,
  147. input_path: str,
  148. device: str = "gpu",
  149. save_dir: str = None,
  150. **kwargs,
  151. ) -> CompletedProcess:
  152. """predict using specified weight
  153. Args:
  154. weight_path (str): the path of model weight file used to predict.
  155. input_path (str): the path of image file to be predicted.
  156. device (str, optional): the running device. Defaults to 'gpu'.
  157. save_dir (str, optional): the directory path to save predict output. Defaults to None.
  158. Returns:
  159. CompletedProcess: the result of predicting subprocess execution.
  160. """
  161. config = self.config.copy()
  162. cli_args = []
  163. weight_path = abspath(weight_path)
  164. cli_args.append(CLIArgument("--checkpoints", weight_path))
  165. input_path = abspath(input_path)
  166. cli_args.append(CLIArgument("--csv_path", input_path))
  167. if device is not None:
  168. device_type, _ = self.runner.parse_device(device)
  169. cli_args.append(CLIArgument("--device", device_type))
  170. if save_dir is not None:
  171. save_dir = abspath(save_dir)
  172. else:
  173. # `save_dir` is None
  174. save_dir = abspath(os.path.join("output", "predict"))
  175. cli_args.append(CLIArgument("--save_dir", save_dir))
  176. self._assert_empty_kwargs(kwargs)
  177. with self._create_new_config_file() as config_path:
  178. config.dump(config_path)
  179. return self.runner.predict(config_path, cli_args, device)
  180. def export(
  181. self, weight_path: str, save_dir: str = None, device: str = "gpu", **kwargs
  182. ):
  183. """export"""
  184. weight_path = abspath(weight_path)
  185. save_dir = abspath(save_dir)
  186. cli_args = []
  187. weight_path = abspath(weight_path)
  188. cli_args.append(CLIArgument("--checkpoints", weight_path))
  189. if save_dir is not None:
  190. save_dir = abspath(save_dir)
  191. else:
  192. save_dir = abspath(os.path.join("output", "inference"))
  193. cli_args.append(CLIArgument("--save_dir", save_dir))
  194. if device is not None:
  195. device_type, _ = self.runner.parse_device(device)
  196. cli_args.append(CLIArgument("--device", device_type))
  197. self._assert_empty_kwargs(kwargs)
  198. with self._create_new_config_file() as config_path:
  199. # Update YAML config file
  200. config = self.config.copy()
  201. config.update_pretrained_weights(weight_path)
  202. config.update({"pdx_model_name": self.name})
  203. config.dump(config_path)
  204. return self.runner.export(config_path, cli_args, device)
  205. def infer(
  206. self,
  207. model_dir: str,
  208. input_path: str,
  209. device: str = "gpu",
  210. save_dir: str = None,
  211. **kwargs,
  212. ):
  213. """infer"""
  214. raise_unsupported_api_error("infer", self.__class__)
  215. def compression(
  216. self,
  217. weight_path: str,
  218. batch_size=None,
  219. learning_rate=None,
  220. epochs_iters=None,
  221. device: str = "gpu",
  222. use_vdl=True,
  223. save_dir=None,
  224. **kwargs,
  225. ):
  226. """compression"""
  227. raise_unsupported_api_error("compression", self.__class__)