exportor.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from abc import ABC
  16. from pathlib import Path
  17. from ...utils import logging
  18. from ...utils.config import AttrDict
  19. from ...utils.device import (
  20. check_supported_device,
  21. set_env_for_device,
  22. update_device_num,
  23. )
  24. from ...utils.misc import AutoRegisterABCMetaClass
  25. from .build_model import build_model
  26. def build_exportor(config: AttrDict) -> "BaseExportor":
  27. """build model exportor
  28. Args:
  29. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  30. Returns:
  31. BaseExportor: the exportor, which is subclass of BaseExportor.
  32. """
  33. model_name = config.Global.model
  34. try:
  35. pass
  36. except ModuleNotFoundError:
  37. pass
  38. return BaseExportor.get(model_name)(config)
  39. class BaseExportor(ABC, metaclass=AutoRegisterABCMetaClass):
  40. """Base Model Exportor"""
  41. __is_base = True
  42. def __init__(self, config):
  43. """Initialize the instance.
  44. Args:
  45. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  46. """
  47. super().__init__()
  48. self.global_config = config.Global
  49. self.export_config = config.Export
  50. config_path = self.get_config_path(self.export_config.weight_path)
  51. if self.export_config.get("basic_config_path", None):
  52. config_path = self.export_config.get("basic_config_path", None)
  53. self.pdx_config, self.pdx_model = build_model(
  54. self.global_config.model, config_path=config_path
  55. )
  56. def get_config_path(self, weight_path):
  57. """
  58. get config path
  59. Args:
  60. weight_path (str): The path to the weight
  61. Returns:
  62. config_path (str): The path to the config
  63. """
  64. config_path = Path(weight_path).parent / "config.yaml"
  65. # `Path("https://xxx/xxx")` would cause error on Windows
  66. try:
  67. is_exists = config_path.exists()
  68. except Exception:
  69. is_exists = False
  70. if not is_exists:
  71. logging.warning(
  72. f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
  73. )
  74. config_path = None
  75. return config_path
  76. def export(self) -> dict:
  77. """execute model exporting
  78. Returns:
  79. dict: the export metrics
  80. """
  81. self.update_config()
  82. export_result = self.pdx_model.export(**self.get_export_kwargs())
  83. assert (
  84. export_result.returncode == 0
  85. ), f"Encountered an unexpected error({export_result.returncode}) in \
  86. exporting!"
  87. return None
  88. def get_device(self, using_device_number: int = None) -> str:
  89. """get device setting from config
  90. Args:
  91. using_device_number (int, optional): specify device number to use.
  92. Defaults to None, means that base on config setting.
  93. Returns:
  94. str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
  95. """
  96. check_supported_device(self.global_config.device, self.global_config.model)
  97. set_env_for_device(self.global_config.device)
  98. device_setting = (
  99. update_device_num(self.global_config.device, using_device_number)
  100. if using_device_number
  101. else self.global_config.device
  102. )
  103. # replace "dcu" with "gpu"
  104. device_setting = device_setting.replace("dcu", "gpu")
  105. return device_setting
  106. def update_config(self):
  107. """update export config"""
  108. def get_export_kwargs(self):
  109. """get key-value arguments of model export function"""
  110. export_with_pir = self.global_config.get("export_with_pir", False) or os.getenv(
  111. "FLAGS_json_format_model"
  112. ) in ["1", "True"]
  113. return {
  114. "weight_path": self.export_config.weight_path,
  115. "save_dir": self.global_config.output,
  116. "device": self.get_device(1),
  117. "export_with_pir": export_with_pir,
  118. }