runner.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from ...base import BaseRunner
  16. def raise_unsupported_api_error(api_name, cls=None):
  17. # TODO: Automatically extract `api_name` and `cls` from stack frame
  18. if cls is not None:
  19. name = f"{cls.__name__}.{api_name}"
  20. else:
  21. name = api_name
  22. raise UnsupportedAPIError(f"The API `{name}` is not supported.")
  23. class UnsupportedAPIError(Exception):
  24. pass
  25. class BEVFusionRunner(BaseRunner):
  26. def train(self, config_path, cli_args, device, ips, save_dir, do_eval=True):
  27. args, env = self.distributed(device, ips, log_dir=save_dir)
  28. cmd = [*args, "tools/train.py"]
  29. if do_eval:
  30. cmd.append("--do_eval")
  31. cmd.extend(["--config", config_path, *cli_args])
  32. return self.run_cmd(
  33. cmd,
  34. env=env,
  35. switch_wdir=True,
  36. echo=True,
  37. silent=False,
  38. capture_output=True,
  39. log_path=self._get_train_log_path(save_dir),
  40. )
  41. def evaluate(self, config_path, cli_args, device, ips):
  42. args, env = self.distributed(device, ips)
  43. cmd = [*args, "tools/evaluate.py", "--config", config_path, *cli_args]
  44. cp = self.run_cmd(
  45. cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
  46. )
  47. if cp.returncode == 0:
  48. metric_dict = _extract_eval_metrics(cp.stdout)
  49. cp.metrics = metric_dict
  50. return cp
  51. def predict(self, config_path, cli_args, device):
  52. raise_unsupported_api_error("predict", self.__class__)
  53. def export(self, config_path, cli_args, device):
  54. # `device` unused
  55. cmd = [self.python, "tools/export.py", "--config", config_path, *cli_args]
  56. return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
  57. def infer(self, config_path, cli_args, device, infer_dir, save_dir=None):
  58. # `config_path` and `device` unused
  59. cmd = [self.python, "infer.py", *cli_args]
  60. python_infer_dir = os.path.join(infer_dir, "python")
  61. cp = self.run_cmd(cmd, switch_wdir=python_infer_dir, echo=True, silent=False)
  62. return cp
  63. def compression(
  64. self, config_path, train_cli_args, export_cli_args, device, train_save_dir
  65. ):
  66. raise_unsupported_api_error("compression", self.__class__)
  67. def _extract_eval_metrics(stdout):
  68. import re
  69. _DP = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
  70. metrics = ["mAP", "NDS"]
  71. patterns = {}
  72. for metric in metrics:
  73. pattern = f"{metric}: (_dp)".replace("_dp", _DP)
  74. patterns[metric] = pattern
  75. metric_dict = dict()
  76. # TODO: Use lazy version to make it more efficient
  77. lines = stdout.splitlines()
  78. for line in lines:
  79. for m in patterns:
  80. p = re.compile(patterns[m])
  81. match = p.search(line)
  82. if match:
  83. metric_dict[m] = float(match.groups()[0])
  84. return metric_dict