# model.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os.path as osp
  15. from ....utils import logging
  16. from ....utils.misc import abspath
  17. from ...base import BaseModel
  18. from ...base.utils.arg import CLIArgument
  19. from .runner import raise_unsupported_api_error
  20. class BEVFusionModel(BaseModel):
  21. def train(
  22. self,
  23. batch_size=None,
  24. learning_rate=None,
  25. epochs_iters=None,
  26. pretrained=None,
  27. ips=None,
  28. device="gpu",
  29. resume_path=None,
  30. dy2st=False,
  31. amp="OFF",
  32. num_workers=None,
  33. use_vdl=True,
  34. save_dir=None,
  35. **kwargs,
  36. ):
  37. if resume_path is not None:
  38. resume_path = abspath(resume_path)
  39. if not use_vdl:
  40. logging.warning("Currently, VisualDL cannot be disabled during training.")
  41. if save_dir is not None:
  42. save_dir = abspath(save_dir)
  43. else:
  44. # `save_dir` is None
  45. save_dir = abspath(osp.join("output", "train"))
  46. if dy2st:
  47. raise ValueError(f"`dy2st`={dy2st} is not supported.")
  48. if device in ("cpu", "gpu"):
  49. logging.warning(
  50. f"The device type to use will be automatically determined, which may differ from the specified type: {repr(device)}."
  51. )
  52. # Update YAML config file
  53. config = self.config.copy()
  54. if epochs_iters is not None:
  55. config.update_iters(epochs_iters)
  56. if amp is not None:
  57. if amp != "OFF":
  58. config._update_amp(amp)
  59. # Parse CLI arguments
  60. cli_args = []
  61. if batch_size is not None:
  62. cli_args.append(CLIArgument("--batch_size", batch_size))
  63. if learning_rate is not None:
  64. cli_args.append(CLIArgument("--learning_rate", learning_rate))
  65. if num_workers is not None:
  66. cli_args.append(CLIArgument("--num_workers", num_workers))
  67. if resume_path is not None:
  68. if save_dir is not None:
  69. raise ValueError(
  70. "When `resume_path` is not None, `save_dir` must be set to None."
  71. )
  72. model_dir = osp.dirname(resume_path)
  73. cli_args.append(CLIArgument("--resume"))
  74. cli_args.append(CLIArgument("--save_dir", model_dir))
  75. if save_dir is not None:
  76. cli_args.append(CLIArgument("--save_dir", save_dir))
  77. if pretrained is not None:
  78. cli_args.append(CLIArgument("--model", abspath(pretrained)))
  79. do_eval = kwargs.pop("do_eval", True)
  80. profile = kwargs.pop("profile", None)
  81. if profile is not None:
  82. cli_args.append(CLIArgument("--profiler_options", profile))
  83. log_interval = kwargs.pop("log_interval", 1)
  84. if log_interval is not None:
  85. cli_args.append(CLIArgument("--log_interval", log_interval))
  86. save_interval = kwargs.pop("save_interval", 1)
  87. if save_interval is not None:
  88. cli_args.append(CLIArgument("--save_interval", save_interval))
  89. seed = kwargs.pop("seed", None)
  90. if seed is not None:
  91. cli_args.append(CLIArgument("--seed", seed))
  92. self._assert_empty_kwargs(kwargs)
  93. # PDX related settings
  94. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  95. export_with_pir = kwargs.pop("export_with_pir", False)
  96. config.update({"uniform_output_enabled": uniform_output_enabled})
  97. config.update({"pdx_model_name": self.name})
  98. if export_with_pir:
  99. config.update({"export_with_pir": export_with_pir})
  100. with self._create_new_config_file() as config_path:
  101. config.dump(config_path)
  102. return self.runner.train(
  103. config_path, cli_args, device, ips, save_dir, do_eval=do_eval
  104. )
  105. def evaluate(
  106. self,
  107. weight_path,
  108. batch_size=None,
  109. ips=None,
  110. device="gpu",
  111. amp="OFF",
  112. num_workers=None,
  113. **kwargs,
  114. ):
  115. weight_path = abspath(weight_path)
  116. if device in ("cpu", "gpu"):
  117. logging.warning(
  118. f"The device type to use will be automatically determined, which may differ from the specified type: {repr(device)}."
  119. )
  120. # Update YAML config file
  121. config = self.config.copy()
  122. if amp is not None:
  123. if amp != "OFF":
  124. raise ValueError("AMP evaluation is not supported.")
  125. # Parse CLI arguments
  126. cli_args = []
  127. if weight_path is not None:
  128. cli_args.append(CLIArgument("--model", weight_path))
  129. if batch_size is not None:
  130. cli_args.append(CLIArgument("--batch_size", batch_size))
  131. if batch_size != 1:
  132. raise ValueError("Batch size other than 1 is not supported.")
  133. if num_workers is not None:
  134. cli_args.append(CLIArgument("--num_workers", num_workers))
  135. self._assert_empty_kwargs(kwargs)
  136. # PDX related settings
  137. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  138. export_with_pir = kwargs.pop("export_with_pir", False)
  139. config.update({"uniform_output_enabled": uniform_output_enabled})
  140. config.update({"pdx_model_name": self.name})
  141. if export_with_pir:
  142. config.update({"export_with_pir": export_with_pir})
  143. with self._create_new_config_file() as config_path:
  144. config.dump(config_path)
  145. cp = self.runner.evaluate(config_path, cli_args, device, ips)
  146. return cp
  147. def predict(self, weight_path, input_path, device="gpu", save_dir=None, **kwargs):
  148. raise_unsupported_api_error("predict", self.__class__)
  149. def export(self, weight_path, save_dir, **kwargs):
  150. if not weight_path.startswith("http"):
  151. weight_path = abspath(weight_path)
  152. save_dir = abspath(save_dir)
  153. # Update YAML config file
  154. config = self.config.copy()
  155. # Parse CLI arguments
  156. cli_args = []
  157. if weight_path is not None:
  158. cli_args.append(CLIArgument("--model", weight_path))
  159. if save_dir is not None:
  160. cli_args.append(CLIArgument("--save_dir", save_dir))
  161. cli_args.append(CLIArgument("--save_name", "inference"))
  162. cli_args.append(CLIArgument("--save_inference_yml"))
  163. # PDX related settings
  164. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  165. export_with_pir = kwargs.pop("export_with_pir", False)
  166. config.update({"uniform_output_enabled": uniform_output_enabled})
  167. config.update({"pdx_model_name": self.name})
  168. if export_with_pir:
  169. config.update({"export_with_pir": export_with_pir})
  170. self._assert_empty_kwargs(kwargs)
  171. with self._create_new_config_file() as config_path:
  172. config.dump(config_path)
  173. return self.runner.export(config_path, cli_args, None)
  174. def infer(self, model_dir, device="gpu", **kwargs):
  175. model_dir = abspath(model_dir)
  176. # Parse CLI arguments
  177. cli_args = []
  178. model_file_path = osp.join(model_dir, ".pdmodel")
  179. params_file_path = osp.join(model_dir, ".pdiparams")
  180. cli_args.append(CLIArgument("--model_file", model_file_path))
  181. cli_args.append(CLIArgument("--params_file", params_file_path))
  182. if device is not None:
  183. device_type, _ = self.runner.parse_device(device)
  184. if device_type not in ("cpu", "gpu"):
  185. raise ValueError(f"`device`={repr(device)} is not supported.")
  186. infer_dir = osp.join(self.runner.runner_root_path, self.model_info["infer_dir"])
  187. self._assert_empty_kwargs(kwargs)
  188. # The inference script does not require a config file
  189. return self.runner.infer(None, cli_args, device, infer_dir, None)
  190. def compression(
  191. self,
  192. weight_path,
  193. ann_file=None,
  194. class_names=None,
  195. batch_size=None,
  196. learning_rate=None,
  197. epochs_iters=None,
  198. device="gpu",
  199. use_vdl=True,
  200. save_dir=None,
  201. **kwargs,
  202. ):
  203. raise_unsupported_api_error("compression", self.__class__)