register.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. from pathlib import Path
  17. from ...base.register import register_model_info, register_suite_info
  18. from .model import SegModel
  19. from .runner import SegRunner
  20. from .config import SegConfig
  21. REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLESEG_PATH")
  22. PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
  23. HPI_CONFIG_DIR = Path(__file__).parent.parent.parent.parent / "utils" / "hpi_configs"
  24. register_suite_info(
  25. {
  26. "suite_name": "Seg",
  27. "model": SegModel,
  28. "runner": SegRunner,
  29. "config": SegConfig,
  30. "runner_root_path": REPO_ROOT_PATH,
  31. }
  32. )
  33. ################ Models Using Universal Config ################
  34. # OCRNet
  35. register_model_info(
  36. {
  37. "model_name": "OCRNet_HRNet-W48",
  38. "suite": "Seg",
  39. "config_path": osp.join(PDX_CONFIG_DIR, "OCRNet_HRNet-W48.yaml"),
  40. "supported_apis": ["train", "evaluate", "predict", "export"],
  41. "hpi_config_path": HPI_CONFIG_DIR / "OCRNet_HRNet-W48.yaml",
  42. }
  43. )
  44. register_model_info(
  45. {
  46. "model_name": "OCRNet_HRNet-W18",
  47. "suite": "Seg",
  48. "config_path": osp.join(PDX_CONFIG_DIR, "OCRNet_HRNet-W18.yaml"),
  49. "supported_apis": ["train", "evaluate", "predict", "export"],
  50. "hpi_config_path": HPI_CONFIG_DIR / "OCRNet_HRNet-W18.yaml",
  51. }
  52. )
  53. # PP-LiteSeg
  54. register_model_info(
  55. {
  56. "model_name": "PP-LiteSeg-T",
  57. "suite": "Seg",
  58. "config_path": osp.join(PDX_CONFIG_DIR, "PP-LiteSeg-T.yaml"),
  59. "supported_apis": ["train", "evaluate", "predict", "export"],
  60. "supported_train_opts": {
  61. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  62. "dy2st": True,
  63. "amp": ["O1", "O2"],
  64. },
  65. "supported_evaluate_opts": {
  66. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  67. "amp": [],
  68. },
  69. "supported_predict_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  70. "supported_infer_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  71. "supported_dataset_types": [],
  72. "hpi_config_path": HPI_CONFIG_DIR / "PP-LiteSeg-T.yaml",
  73. }
  74. )
  75. # seaformer
  76. register_model_info(
  77. {
  78. "model_name": "SeaFormer_base",
  79. "suite": "Seg",
  80. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_base.yaml"),
  81. "supported_apis": ["train", "evaluate", "predict", "export"],
  82. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_base.yaml",
  83. }
  84. )
  85. register_model_info(
  86. {
  87. "model_name": "SeaFormer_tiny",
  88. "suite": "Seg",
  89. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_tiny.yaml"),
  90. "supported_apis": ["train", "evaluate", "predict", "export"],
  91. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_tiny.yaml",
  92. }
  93. )
  94. register_model_info(
  95. {
  96. "model_name": "SeaFormer_small",
  97. "suite": "Seg",
  98. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_small.yaml"),
  99. "supported_apis": ["train", "evaluate", "predict", "export"],
  100. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_small.yaml",
  101. }
  102. )
  103. register_model_info(
  104. {
  105. "model_name": "SeaFormer_large",
  106. "suite": "Seg",
  107. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_large.yaml"),
  108. "supported_apis": ["train", "evaluate", "predict", "export"],
  109. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_large.yaml",
  110. }
  111. )
  112. # SegFormer
  113. register_model_info(
  114. {
  115. "model_name": "SegFormer-B0",
  116. "suite": "Seg",
  117. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B0.yaml"),
  118. "supported_apis": ["train", "evaluate", "predict", "export"],
  119. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B0.yaml",
  120. }
  121. )
  122. register_model_info(
  123. {
  124. "model_name": "SegFormer-B1",
  125. "suite": "Seg",
  126. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B1.yaml"),
  127. "supported_apis": ["train", "evaluate", "predict", "export"],
  128. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B1.yaml",
  129. }
  130. )
  131. register_model_info(
  132. {
  133. "model_name": "SegFormer-B2",
  134. "suite": "Seg",
  135. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B2.yaml"),
  136. "supported_apis": ["train", "evaluate", "predict", "export"],
  137. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B2.yaml",
  138. }
  139. )
  140. register_model_info(
  141. {
  142. "model_name": "SegFormer-B3",
  143. "suite": "Seg",
  144. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B3.yaml"),
  145. "supported_apis": ["train", "evaluate", "predict", "export"],
  146. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B3.yaml",
  147. }
  148. )
  149. register_model_info(
  150. {
  151. "model_name": "SegFormer-B4",
  152. "suite": "Seg",
  153. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B4.yaml"),
  154. "supported_apis": ["train", "evaluate", "predict", "export"],
  155. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B4.yaml",
  156. }
  157. )
  158. register_model_info(
  159. {
  160. "model_name": "SegFormer-B5",
  161. "suite": "Seg",
  162. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B5.yaml"),
  163. "supported_apis": ["train", "evaluate", "predict", "export"],
  164. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B5.yaml",
  165. }
  166. )
  167. # deeplab
  168. register_model_info(
  169. {
  170. "model_name": "Deeplabv3-R50",
  171. "suite": "Seg",
  172. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3-R50.yaml"),
  173. "supported_apis": ["train", "evaluate", "predict", "export"],
  174. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3-R50.yaml",
  175. }
  176. )
  177. register_model_info(
  178. {
  179. "model_name": "Deeplabv3-R101",
  180. "suite": "Seg",
  181. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3-R101.yaml"),
  182. "supported_apis": ["train", "evaluate", "predict", "export"],
  183. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3-R101.yaml",
  184. }
  185. )
  186. register_model_info(
  187. {
  188. "model_name": "Deeplabv3_Plus-R50",
  189. "suite": "Seg",
  190. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-R50.yaml"),
  191. "supported_apis": ["train", "evaluate", "predict", "export"],
  192. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3_Plus-R50.yaml",
  193. }
  194. )
  195. register_model_info(
  196. {
  197. "model_name": "Deeplabv3_Plus-R101",
  198. "suite": "Seg",
  199. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-R101.yaml"),
  200. "supported_apis": ["train", "evaluate", "predict", "export"],
  201. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3_Plus-R101.yaml",
  202. }
  203. )
  204. # For compatibility
  205. def _set_alias(model_name, alias):
  206. from ...base.register import get_registered_model_info
  207. record = get_registered_model_info(model_name)
  208. record = dict(**record)
  209. record["model_name"] = alias
  210. register_model_info(record)
  211. _set_alias("OCRNet_HRNet-W48", "ocrnet_hrnetw48")
  212. _set_alias("OCRNet_HRNet-W18", "ocrnet_hrnetw18")
  213. _set_alias("PP-LiteSeg-T", "pp_liteseg_stdc1")