runner.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from ...base import BaseRunner
  16. from ...base.utils.subprocess import CompletedProcess
  17. class ClsRunner(BaseRunner):
  18. """Cls Runner"""
  19. def train(
  20. self,
  21. config_path: str,
  22. cli_args: list,
  23. device: str,
  24. ips: str,
  25. save_dir: str,
  26. do_eval=True,
  27. ) -> CompletedProcess:
  28. """train model
  29. Args:
  30. config_path (str): the config file path used to train.
  31. cli_args (list): the additional parameters.
  32. device (str): the training device.
  33. ips (str): the ip addresses of nodes when using distribution.
  34. save_dir (str): the directory path to save training output.
  35. do_eval (bool, optional): whether or not to evaluate model during training. Defaults to True.
  36. Returns:
  37. CompletedProcess: the result of training subprocess execution.
  38. """
  39. args, env = self.distributed(device, ips, log_dir=save_dir)
  40. cmd = [*args, "tools/train.py", "-c", config_path, *cli_args]
  41. cmd.extend(["-o", f"Global.eval_during_train={do_eval}"])
  42. return self.run_cmd(
  43. cmd,
  44. env=env,
  45. switch_wdir=True,
  46. echo=True,
  47. silent=False,
  48. capture_output=True,
  49. log_path=self._get_train_log_path(save_dir),
  50. )
  51. def evaluate(
  52. self, config_path: str, cli_args: list, device: str, ips: str
  53. ) -> CompletedProcess:
  54. """run model evaluating
  55. Args:
  56. config_path (str): the config file path used to evaluate.
  57. cli_args (list): the additional parameters.
  58. device (str): the evaluating device.
  59. ips (str): the ip addresses of nodes when using distribution.
  60. Returns:
  61. CompletedProcess: the result of evaluating subprocess execution.
  62. """
  63. args, env = self.distributed(device, ips)
  64. cmd = [*args, "tools/eval.py", "-c", config_path, *cli_args]
  65. cp = self.run_cmd(
  66. cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
  67. )
  68. if cp.returncode == 0:
  69. metric_dict = _extract_eval_metrics(cp.stdout)
  70. cp.metrics = metric_dict
  71. return cp
  72. def predict(
  73. self, config_path: str, cli_args: list, device: str
  74. ) -> CompletedProcess:
  75. """run predicting using dynamic mode
  76. Args:
  77. config_path (str): the config file path used to predict.
  78. cli_args (list): the additional parameters.
  79. device (str): unused.
  80. Returns:
  81. CompletedProcess: the result of predicting subprocess execution.
  82. """
  83. # `device` unused
  84. cmd = [self.python, "tools/infer.py", "-c", config_path, *cli_args]
  85. return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
  86. def export(
  87. self, config_path: str, cli_args: list, device: str, save_dir: str = None
  88. ) -> CompletedProcess:
  89. """run exporting
  90. Args:
  91. config_path (str): the path of config file used to export.
  92. cli_args (list): the additional parameters.
  93. device (str): unused.
  94. save_dir (str, optional): the directory path to save exporting output. Defaults to None.
  95. Returns:
  96. CompletedProcess: the result of exporting subprocess execution.
  97. """
  98. # `device` unused
  99. cmd = [
  100. self.python,
  101. "tools/export_model.py",
  102. "-c",
  103. config_path,
  104. *cli_args,
  105. "-o",
  106. "Global.export_for_fd=True",
  107. ]
  108. cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
  109. return cp
  110. def infer(self, config_path: str, cli_args: list, device: str) -> CompletedProcess:
  111. """run predicting using inference model
  112. Args:
  113. config_path (str): the path of config file used to predict.
  114. cli_args (list): the additional parameters.
  115. device (str): unused.
  116. Returns:
  117. CompletedProcess: the result of inferring subprocess execution.
  118. """
  119. # `device` unused
  120. cmd = [self.python, "python/predict_cls.py", "-c", config_path, *cli_args]
  121. return self.run_cmd(cmd, switch_wdir="deploy", echo=True, silent=False)
  122. def compression(
  123. self,
  124. config_path: str,
  125. train_cli_args: list,
  126. export_cli_args: list,
  127. device: str,
  128. train_save_dir: str,
  129. ) -> CompletedProcess:
  130. """run compression model
  131. Args:
  132. config_path (str): the path of config file used to predict.
  133. train_cli_args (list): the additional training parameters.
  134. export_cli_args (list): the additional exporting parameters.
  135. device (str): the running device.
  136. train_save_dir (str): the directory path to save output.
  137. Returns:
  138. CompletedProcess: the result of compression subprocess execution.
  139. """
  140. # Step 1: Train model
  141. cp_train = self.train(config_path, train_cli_args, device, None, train_save_dir)
  142. # Step 2: Export model
  143. weight_path = os.path.join(train_save_dir, "best_model", "model")
  144. export_cli_args = [
  145. *export_cli_args,
  146. "-o",
  147. f"Global.pretrained_model={weight_path}",
  148. ]
  149. cp_export = self.export(config_path, export_cli_args, device)
  150. return cp_train, cp_export
  151. def _extract_eval_metrics(stdout: str) -> dict:
  152. """extract evaluation metrics from training log
  153. Args:
  154. stdout (str): the training log
  155. Returns:
  156. dict: the training metric
  157. """
  158. import re
  159. _DP = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
  160. patterns = [
  161. r"\[Eval\]\[Epoch 0\]\[Avg\].*top1: (_dp)".replace("_dp", _DP),
  162. r"\[Eval\]\[Epoch 0\]\[Avg\].*top1: (_dp), top5: (_dp)".replace("_dp", _DP),
  163. r"\[Eval\]\[Epoch 0\]\[Avg\].*recall1: (_dp), recall5: (_dp), mAP: (_dp)".replace(
  164. "_dp", _DP
  165. ),
  166. r"\[Eval\]\[Epoch 0\]\[Avg\].*MultiLabelMAP\(integral\): (_dp)".replace(
  167. "_dp", _DP
  168. ),
  169. r"\[Eval\]\[Epoch 0\]\[Avg\].*evalres:\ ma: (_dp)".replace("_dp", _DP),
  170. ]
  171. keys = [
  172. ["val.top1"],
  173. ["val.top1", "val.top5"],
  174. ["recall1", "recall5", "mAP"],
  175. ["MultiLabelMAP"],
  176. ["evalres: ma"],
  177. ]
  178. metric_dict = dict()
  179. for pattern, key in zip(patterns, keys):
  180. pattern = re.compile(pattern)
  181. for line in stdout.splitlines():
  182. match = pattern.search(line)
  183. if match:
  184. for k, v in zip(key, map(float, match.groups())):
  185. metric_dict[k] = v
  186. return metric_dict