model.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. from .runner import raise_unsupported_api_error
  16. from ...base import BaseModel
  17. from ....utils import logging
  18. from ...base.utils.arg import CLIArgument
  19. from ....utils.misc import abspath
  20. class BEVFusionModel(BaseModel):
  21. def train(
  22. self,
  23. batch_size=None,
  24. learning_rate=None,
  25. epochs_iters=None,
  26. pretrained=None,
  27. ips=None,
  28. device="gpu",
  29. resume_path=None,
  30. dy2st=False,
  31. amp="OFF",
  32. num_workers=None,
  33. use_vdl=True,
  34. save_dir=None,
  35. **kwargs,
  36. ):
  37. if resume_path is not None:
  38. resume_path = abspath(resume_path)
  39. if not use_vdl:
  40. logging.warning("Currently, VisualDL cannot be disabled during training.")
  41. if save_dir is not None:
  42. save_dir = abspath(save_dir)
  43. else:
  44. # `save_dir` is None
  45. save_dir = abspath(osp.join("output", "train"))
  46. if dy2st:
  47. raise ValueError(f"`dy2st`={dy2st} is not supported.")
  48. if device in ("cpu", "gpu"):
  49. logging.warning(
  50. f"The device type to use will be automatically determined, which may differ from the sepcified type: {repr(device)}."
  51. )
  52. # Update YAML config file
  53. config = self.config.copy()
  54. if epochs_iters is not None:
  55. config.update_iters(epochs_iters)
  56. if amp is not None:
  57. if amp != "OFF":
  58. config._update_amp(amp)
  59. # Parse CLI arguments
  60. cli_args = []
  61. if batch_size is not None:
  62. cli_args.append(CLIArgument("--batch_size", batch_size))
  63. if learning_rate is not None:
  64. cli_args.append(CLIArgument("--learning_rate", learning_rate))
  65. if num_workers is not None:
  66. cli_args.append(CLIArgument("--num_workers", num_workers))
  67. if resume_path is not None:
  68. if save_dir is not None:
  69. raise ValueError(
  70. "When `resume_path` is not None, `save_dir` must be set to None."
  71. )
  72. model_dir = osp.dirname(resume_path)
  73. cli_args.append(CLIArgument("--resume"))
  74. cli_args.append(CLIArgument("--save_dir", model_dir))
  75. if save_dir is not None:
  76. cli_args.append(CLIArgument("--save_dir", save_dir))
  77. if pretrained is not None:
  78. cli_args.append(CLIArgument("--model", abspath(pretrained)))
  79. do_eval = kwargs.pop("do_eval", True)
  80. profile = kwargs.pop("profile", None)
  81. if profile is not None:
  82. cli_args.append(CLIArgument("--profiler_options", profile))
  83. log_interval = kwargs.pop("log_interval", 1)
  84. if log_interval is not None:
  85. cli_args.append(CLIArgument("--log_interval", log_interval))
  86. save_interval = kwargs.pop("save_interval", 1)
  87. if save_interval is not None:
  88. cli_args.append(CLIArgument("--save_interval", save_interval))
  89. seed = kwargs.pop("seed", None)
  90. if seed is not None:
  91. cli_args.append(CLIArgument("--seed", seed))
  92. self._assert_empty_kwargs(kwargs)
  93. # PDX related settings
  94. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  95. export_with_pir = kwargs.pop("export_with_pir", False)
  96. config.update({"uniform_output_enabled": uniform_output_enabled})
  97. config.update({"pdx_model_name": self.name})
  98. if export_with_pir:
  99. config.update({"export_with_pir": export_with_pir})
  100. with self._create_new_config_file() as config_path:
  101. config.dump(config_path)
  102. return self.runner.train(
  103. config_path, cli_args, device, ips, save_dir, do_eval=do_eval
  104. )
  105. def evaluate(
  106. self,
  107. weight_path,
  108. batch_size=None,
  109. ips=None,
  110. device="gpu",
  111. amp="OFF",
  112. num_workers=None,
  113. **kwargs,
  114. ):
  115. weight_path = abspath(weight_path)
  116. if device in ("cpu", "gpu"):
  117. logging.warning(
  118. f"The device type to use will be automatically determined, which may differ from the sepcified type: {repr(device)}."
  119. )
  120. # Update YAML config file
  121. config = self.config.copy()
  122. if amp is not None:
  123. if amp != "OFF":
  124. raise ValueError("AMP evaluation is not supported.")
  125. # Parse CLI arguments
  126. cli_args = []
  127. if weight_path is not None:
  128. cli_args.append(CLIArgument("--model", weight_path))
  129. if batch_size is not None:
  130. cli_args.append(CLIArgument("--batch_size", batch_size))
  131. if batch_size != 1:
  132. raise ValueError("Batch size other than 1 is not supported.")
  133. if num_workers is not None:
  134. cli_args.append(CLIArgument("--num_workers", num_workers))
  135. self._assert_empty_kwargs(kwargs)
  136. # PDX related settings
  137. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  138. export_with_pir = kwargs.pop("export_with_pir", False)
  139. config.update({"uniform_output_enabled": uniform_output_enabled})
  140. config.update({"pdx_model_name": self.name})
  141. if export_with_pir:
  142. config.update({"export_with_pir": export_with_pir})
  143. with self._create_new_config_file() as config_path:
  144. config.dump(config_path)
  145. cp = self.runner.evaluate(config_path, cli_args, device, ips)
  146. return cp
  147. def predict(self, weight_path, input_path, device="gpu", save_dir=None, **kwargs):
  148. raise_unsupported_api_error("predict", self.__class__)
  149. def export(self, weight_path, save_dir, **kwargs):
  150. if not weight_path.startswith("http"):
  151. weight_path = abspath(weight_path)
  152. save_dir = abspath(save_dir)
  153. # Update YAML config file
  154. config = self.config.copy()
  155. # Parse CLI arguments
  156. cli_args = []
  157. if weight_path is not None:
  158. cli_args.append(CLIArgument("--model", weight_path))
  159. if save_dir is not None:
  160. cli_args.append(CLIArgument("--save_dir", save_dir))
  161. cli_args.append(CLIArgument("--save_name", "inference"))
  162. cli_args.append(CLIArgument("--save_inference_yml"))
  163. # PDX related settings
  164. uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
  165. export_with_pir = kwargs.pop("export_with_pir", False)
  166. config.update({"uniform_output_enabled": uniform_output_enabled})
  167. config.update({"pdx_model_name": self.name})
  168. if export_with_pir:
  169. config.update({"export_with_pir": export_with_pir})
  170. self._assert_empty_kwargs(kwargs)
  171. with self._create_new_config_file() as config_path:
  172. config.dump(config_path)
  173. return self.runner.export(config_path, cli_args, None)
  174. def infer(self, model_dir, device="gpu", **kwargs):
  175. model_dir = abspath(model_dir)
  176. # Parse CLI arguments
  177. cli_args = []
  178. model_file_path = osp.join(model_dir, ".pdmodel")
  179. params_file_path = osp.join(model_dir, ".pdiparams")
  180. cli_args.append(CLIArgument("--model_file", model_file_path))
  181. cli_args.append(CLIArgument("--params_file", params_file_path))
  182. if device is not None:
  183. device_type, _ = self.runner.parse_device(device)
  184. if device_type not in ("cpu", "gpu"):
  185. raise ValueError(f"`device`={repr(device)} is not supported.")
  186. infer_dir = osp.join(self.runner.runner_root_path, self.model_info["infer_dir"])
  187. self._assert_empty_kwargs(kwargs)
  188. # The inference script does not require a config file
  189. return self.runner.infer(None, cli_args, device, infer_dir, None)
  190. def compression(
  191. self,
  192. weight_path,
  193. ann_file=None,
  194. class_names=None,
  195. batch_size=None,
  196. learning_rate=None,
  197. epochs_iters=None,
  198. device="gpu",
  199. use_vdl=True,
  200. save_dir=None,
  201. **kwargs,
  202. ):
  203. raise_unsupported_api_error("compression", self.__class__)