runner.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from ...base import BaseRunner
  16. from ...base.utils.subprocess import CompletedProcess
  17. class FormulaRecRunner(BaseRunner):
  18. """Formula Recognition Runner"""
  19. def train(
  20. self,
  21. config_path: str,
  22. cli_args: list,
  23. device: str,
  24. ips: str,
  25. save_dir: str,
  26. do_eval=True,
  27. ) -> CompletedProcess:
  28. """train model
  29. Args:
  30. config_path (str): the config file path used to train.
  31. cli_args (list): the additional parameters.
  32. device (str): the training device.
  33. ips (str): the ip addresses of nodes when using distribution.
  34. save_dir (str): the directory path to save training output.
  35. do_eval (bool, optional): whether or not to evaluate model during training. Defaults to True.
  36. Returns:
  37. CompletedProcess: the result of training subprocess execution.
  38. """
  39. args, env = self.distributed(device, ips, log_dir=save_dir)
  40. cmd = [*args, "tools/train.py", "-c", config_path, *cli_args]
  41. if do_eval:
  42. # We simply pass here because in PaddleOCR periodic evaluation cannot be switched off
  43. pass
  44. else:
  45. inf = int(1.0e11)
  46. cmd.extend(["-o", f"Global.eval_batch_step={inf}"])
  47. return self.run_cmd(
  48. cmd,
  49. env=env,
  50. switch_wdir=True,
  51. echo=True,
  52. silent=False,
  53. capture_output=True,
  54. log_path=self._get_train_log_path(save_dir),
  55. )
  56. def evaluate(
  57. self, config_path: str, cli_args: list, device: str, ips: str
  58. ) -> CompletedProcess:
  59. """run model evaluating
  60. Args:
  61. config_path (str): the config file path used to evaluate.
  62. cli_args (list): the additional parameters.
  63. device (str): the evaluating device.
  64. ips (str): the ip addresses of nodes when using distribution.
  65. Returns:
  66. CompletedProcess: the result of evaluating subprocess execution.
  67. """
  68. args, env = self.distributed(device, ips)
  69. cmd = [*args, "tools/eval.py", "-c", config_path]
  70. cp = self.run_cmd(
  71. cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
  72. )
  73. if cp.returncode == 0:
  74. metric_dict = _extract_eval_metrics(cp.stdout)
  75. cp.metrics = metric_dict
  76. return cp
  77. def predict(
  78. self, config_path: str, cli_args: list, device: str
  79. ) -> CompletedProcess:
  80. """run predicting using dynamic mode
  81. Args:
  82. config_path (str): the config file path used to predict.
  83. cli_args (list): the additional parameters.
  84. device (str): unused.
  85. Returns:
  86. CompletedProcess: the result of predicting subprocess execution.
  87. """
  88. cmd = [self.python, "tools/infer_rec.py", "-c", config_path]
  89. return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
  90. def export(
  91. self, config_path: str, cli_args: list, device: str, save_dir: str = None
  92. ) -> CompletedProcess:
  93. """run exporting
  94. Args:
  95. config_path (str): the path of config file used to export.
  96. cli_args (list): the additional parameters.
  97. device (str): unused.
  98. save_dir (str, optional): the directory path to save exporting output. Defaults to None.
  99. Returns:
  100. CompletedProcess: the result of exporting subprocess execution.
  101. """
  102. # `device` unused
  103. cmd = [self.python, "tools/export_model.py", "-c", config_path]
  104. cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
  105. return cp
  106. def infer(self, config_path: str, cli_args: list, device: str) -> CompletedProcess:
  107. """run predicting using inference model
  108. Args:
  109. config_path (str): the path of config file used to predict.
  110. cli_args (list): the additional parameters.
  111. device (str): unused.
  112. Returns:
  113. CompletedProcess: the result of inferring subprocess execution.
  114. """
  115. cmd = [self.python, "tools/infer/predict_rec.py", *cli_args]
  116. return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
  117. def compression(
  118. self,
  119. config_path: str,
  120. train_cli_args: list,
  121. export_cli_args: list,
  122. device: str,
  123. train_save_dir: str,
  124. ) -> CompletedProcess:
  125. """run compression model
  126. Args:
  127. config_path (str): the path of config file used to predict.
  128. train_cli_args (list): the additional training parameters.
  129. export_cli_args (list): the additional exporting parameters.
  130. device (str): the running device.
  131. train_save_dir (str): the directory path to save output.
  132. Returns:
  133. CompletedProcess: the result of compression subprocess execution.
  134. """
  135. # Step 1: Train model
  136. args, env = self.distributed(device, log_dir=train_save_dir)
  137. cmd = [*args, "deploy/slim/quantization/quant.py", "-c", config_path]
  138. cp_train = self.run_cmd(
  139. cmd,
  140. env=env,
  141. switch_wdir=True,
  142. echo=True,
  143. silent=False,
  144. capture_output=True,
  145. log_path=self._get_train_log_path(train_save_dir),
  146. )
  147. # Step 2: Export model
  148. export_cli_args = [
  149. *export_cli_args,
  150. "-o",
  151. f"Global.checkpoints={train_save_dir}/latest",
  152. ]
  153. cmd = [
  154. self.python,
  155. "deploy/slim/quantization/export_model.py",
  156. "-c",
  157. config_path,
  158. *export_cli_args,
  159. ]
  160. cp_export = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
  161. return cp_train, cp_export
  162. def _extract_eval_metrics(stdout: str) -> dict:
  163. """extract evaluation metrics from training log
  164. Args:
  165. stdout (str): the training log
  166. Returns:
  167. dict: the training metric
  168. """
  169. import re
  170. def _lazy_split_lines(s):
  171. prev_idx = 0
  172. while True:
  173. curr_idx = s.find(os.linesep, prev_idx)
  174. if curr_idx == -1:
  175. curr_idx = len(s)
  176. yield s[prev_idx:curr_idx]
  177. prev_idx = curr_idx + len(os.linesep)
  178. if prev_idx >= len(s):
  179. break
  180. _DP = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
  181. pattern_key_pairs = [
  182. (re.compile(r"acc:(_dp)$".replace("_dp", _DP)), "acc"),
  183. (re.compile(r"norm_edit_dis:(_dp)$".replace("_dp", _DP)), "norm_edit_dis"),
  184. (re.compile(r"Teacher_acc:(_dp)$".replace("_dp", _DP)), "teacher_acc"),
  185. (
  186. re.compile(r"Teacher_norm_edit_dis:(_dp)$".replace("_dp", _DP)),
  187. "teacher_norm_edit_dis",
  188. ),
  189. (re.compile(r"precision:(_dp)$".replace("_dp", _DP)), "precision"),
  190. (re.compile(r"recall:(_dp)$".replace("_dp", _DP)), "recall"),
  191. (re.compile(r"hmean:(_dp)$".replace("_dp", _DP)), "hmean"),
  192. (re.compile(r"exp_rate:(_dp)$".replace("_dp", _DP)), "exp_rate"),
  193. ]
  194. metric_dict = dict()
  195. start_match = False
  196. for line in _lazy_split_lines(stdout):
  197. if "metric eval" in line:
  198. start_match = True
  199. if start_match:
  200. for pattern, key in pattern_key_pairs:
  201. match = pattern.search(line)
  202. if match:
  203. assert len(match.groups()) == 1
  204. # Newer overwrites older
  205. metric_dict[key] = float(match.group(1))
  206. return metric_dict