runner.py

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from ...base import BaseRunner
from ...base.utils.subprocess import CompletedProcess


class TextRecRunner(BaseRunner):
    """Text Recognition Runner"""

    def train(
        self,
        config_path: str,
        cli_args: list,
        device: str,
        ips: str,
        save_dir: str,
        do_eval=True,
    ) -> CompletedProcess:
        """train model

        Args:
            config_path (str): the config file path used to train.
            cli_args (list): the additional parameters.
            device (str): the training device.
            ips (str): the ip addresses of nodes when using distributed training.
            save_dir (str): the directory path to save training output.
            do_eval (bool, optional): whether to evaluate the model during training. Defaults to True.

        Returns:
            CompletedProcess: the result of training subprocess execution.
        """
        args, env = self.distributed(device, ips, log_dir=save_dir)
        cmd = [*args, "tools/train.py", "-c", config_path, *cli_args]
        if do_eval:
            # We simply pass here because in PaddleOCR periodic evaluation cannot be switched off
            pass
        else:
            # Push the evaluation interval far beyond any realistic step count,
            # which effectively disables periodic evaluation.
            inf = int(1.0e11)
            cmd.extend(["-o", f"Global.eval_batch_step={inf}"])
        return self.run_cmd(
            cmd,
            env=env,
            switch_wdir=True,
            echo=True,
            silent=False,
            capture_output=True,
            log_path=self._get_train_log_path(save_dir),
        )
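
    # Illustrative call only; `runner` is an instance of TextRecRunner and the
    # argument values below are hypothetical, not taken from this module:
    #
    #   cp = runner.train(
    #       config_path="configs/rec/my_rec_config.yml",
    #       cli_args=["-o", "Global.epoch_num=20"],
    #       device="gpu:0,1",
    #       ips=None,
    #       save_dir="output/rec",
    #       do_eval=False,
    #   )
    #   if cp.returncode == 0:
    #       ...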

    def evaluate(
        self, config_path: str, cli_args: list, device: str, ips: str
    ) -> CompletedProcess:
        """run model evaluating

        Args:
            config_path (str): the config file path used to evaluate.
            cli_args (list): the additional parameters.
            device (str): the evaluating device.
            ips (str): the ip addresses of nodes when using distributed evaluation.

        Returns:
            CompletedProcess: the result of evaluating subprocess execution.
        """
        args, env = self.distributed(device, ips)
        cmd = [*args, "tools/eval.py", "-c", config_path, *cli_args]
        cp = self.run_cmd(
            cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
        )
        if cp.returncode == 0:
            metric_dict = _extract_eval_metrics(cp.stdout)
            cp.metrics = metric_dict
        return cp
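
    # When evaluation succeeds, the metrics parsed from stdout by
    # `_extract_eval_metrics` (defined at the bottom of this file) are attached
    # to the returned object, e.g. `cp.metrics.get("acc")`; the exact keys
    # depend on which metrics appear in the PaddleOCR eval log.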

    def predict(
        self, config_path: str, cli_args: list, device: str
    ) -> CompletedProcess:
        """run predicting using dynamic mode

        Args:
            config_path (str): the config file path used to predict.
            cli_args (list): the additional parameters.
            device (str): unused.

        Returns:
            CompletedProcess: the result of predicting subprocess execution.
        """
        # `device` unused
        cmd = [self.python, "tools/infer_rec.py", "-c", config_path, *cli_args]
        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)

    def export(
        self, config_path: str, cli_args: list, device: str, save_dir: str = None
    ) -> CompletedProcess:
        """run exporting

        Args:
            config_path (str): the path of config file used to export.
            cli_args (list): the additional parameters.
            device (str): unused.
            save_dir (str, optional): the directory path to save exporting output. Defaults to None.

        Returns:
            CompletedProcess: the result of exporting subprocess execution.
        """
        # `device` unused
        cmd = [self.python, "tools/export_model.py", "-c", config_path, *cli_args]
        cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
        return cp
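
    # `save_dir` is not referenced in the body above; presumably the caller
    # supplies the output location through `cli_args` (for example via
    # PaddleOCR's `Global.save_inference_dir` option).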

    def infer(self, config_path: str, cli_args: list, device: str) -> CompletedProcess:
        """run predicting using inference model

        Args:
            config_path (str): the path of config file used to predict.
            cli_args (list): the additional parameters.
            device (str): unused.

        Returns:
            CompletedProcess: the result of inferring subprocess execution.
        """
        cmd = [self.python, "tools/infer/predict_rec.py", *cli_args]
        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
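
    # Unlike the other entry points, `tools/infer/predict_rec.py` takes no "-c"
    # config file, so `config_path` is unused here and everything (inference
    # model directory, input images, and so on) is expected to arrive through
    # `cli_args`, typically PaddleOCR flags such as `--rec_model_dir` and
    # `--image_dir`; the exact flag set is an assumption to verify against the
    # PaddleOCR utility.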

    def compression(
        self,
        config_path: str,
        train_cli_args: list,
        export_cli_args: list,
        device: str,
        train_save_dir: str,
    ) -> tuple:
        """run model compression (quantization-aware training followed by export)

        Args:
            config_path (str): the path of config file used for compression.
            train_cli_args (list): the additional training parameters.
            export_cli_args (list): the additional exporting parameters.
            device (str): the running device.
            train_save_dir (str): the directory path to save output.

        Returns:
            tuple: the CompletedProcess results of the training and exporting
                subprocess executions.
        """
        # Step 1: Train model
        args, env = self.distributed(device, log_dir=train_save_dir)
        cmd = [*args, "deploy/slim/quantization/quant.py", "-c", config_path, *train_cli_args]
        cp_train = self.run_cmd(
            cmd,
            env=env,
            switch_wdir=True,
            echo=True,
            silent=False,
            capture_output=True,
            log_path=self._get_train_log_path(train_save_dir),
        )

        # Step 2: Export model from the latest checkpoint written in Step 1
        export_cli_args = [
            *export_cli_args,
            "-o",
            f"Global.checkpoints={train_save_dir}/latest",
        ]
        cmd = [
            self.python,
            "deploy/slim/quantization/export_model.py",
            "-c",
            config_path,
            *export_cli_args,
        ]
        cp_export = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
        return cp_train, cp_export
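
    # Illustrative call only; the argument values are hypothetical:
    #
    #   cp_train, cp_export = runner.compression(
    #       config_path="configs/rec/my_rec_config.yml",
    #       train_cli_args=["-o", "Global.epoch_num=5"],
    #       export_cli_args=["-o", "Global.save_inference_dir=output/rec_quant_infer"],
    #       device="gpu:0",
    #       train_save_dir="output/rec_quant",
    #   )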


def _extract_eval_metrics(stdout: str) -> dict:
    """extract evaluation metrics from training log

    Args:
        stdout (str): the training log

    Returns:
        dict: the extracted evaluation metrics
    """
    import re

    def _lazy_split_lines(s):
        # Yield lines one at a time instead of materializing the whole log.
        prev_idx = 0
        while True:
            curr_idx = s.find(os.linesep, prev_idx)
            if curr_idx == -1:
                curr_idx = len(s)
            yield s[prev_idx:curr_idx]
            prev_idx = curr_idx + len(os.linesep)
            if prev_idx >= len(s):
                break

    # Pattern for a (possibly signed) decimal number with an optional exponent
    _DP = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
    pattern_key_pairs = [
        (re.compile(r"acc:(_dp)$".replace("_dp", _DP)), "acc"),
        (re.compile(r"norm_edit_dis:(_dp)$".replace("_dp", _DP)), "norm_edit_dis"),
        (re.compile(r"Teacher_acc:(_dp)$".replace("_dp", _DP)), "teacher_acc"),
        (
            re.compile(r"Teacher_norm_edit_dis:(_dp)$".replace("_dp", _DP)),
            "teacher_norm_edit_dis",
        ),
        (re.compile(r"precision:(_dp)$".replace("_dp", _DP)), "precision"),
        (re.compile(r"recall:(_dp)$".replace("_dp", _DP)), "recall"),
        (re.compile(r"hmean:(_dp)$".replace("_dp", _DP)), "hmean"),
        (re.compile(r"exp_rate:(_dp)$".replace("_dp", _DP)), "exp_rate"),
    ]
    metric_dict = dict()
    start_match = False
    for line in _lazy_split_lines(stdout):
        # Only start collecting after the "metric eval" marker has been seen
        if "metric eval" in line:
            start_match = True
        if start_match:
            for pattern, key in pattern_key_pairs:
                match = pattern.search(line)
                if match:
                    assert len(match.groups()) == 1
                    # Newer overwrites older
                    metric_dict[key] = float(match.group(1))
    return metric_dict
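
# Illustrative behaviour (assuming a PaddleOCR-style eval log): once a line
# containing "metric eval" has been seen, subsequent lines such as
#     acc:0.805
#     norm_edit_dis:0.912
# would produce {"acc": 0.805, "norm_edit_dis": 0.912}; later matches for the
# same key overwrite earlier ones.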