register.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. from pathlib import Path
  17. from ...base.register import register_model_info, register_suite_info
  18. from .model import SegModel
  19. from .runner import SegRunner
  20. from .config import SegConfig
  21. REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLESEG_PATH")
  22. PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
  23. HPI_CONFIG_DIR = Path(__file__).parent.parent.parent.parent / "utils" / "hpi_configs"
  24. register_suite_info(
  25. {
  26. "suite_name": "Seg",
  27. "model": SegModel,
  28. "runner": SegRunner,
  29. "config": SegConfig,
  30. "runner_root_path": REPO_ROOT_PATH,
  31. }
  32. )
  33. ################ Models Using Universal Config ################
  34. # OCRNet
  35. register_model_info(
  36. {
  37. "model_name": "OCRNet_HRNet-W48",
  38. "suite": "Seg",
  39. "config_path": osp.join(PDX_CONFIG_DIR, "OCRNet_HRNet-W48.yaml"),
  40. "supported_apis": ["train", "evaluate", "predict", "export"],
  41. "hpi_config_path": HPI_CONFIG_DIR / "OCRNet_HRNet-W48.yaml",
  42. }
  43. )
  44. register_model_info(
  45. {
  46. "model_name": "OCRNet_HRNet-W18",
  47. "suite": "Seg",
  48. "config_path": osp.join(PDX_CONFIG_DIR, "OCRNet_HRNet-W18.yaml"),
  49. "supported_apis": ["train", "evaluate", "predict", "export"],
  50. "hpi_config_path": HPI_CONFIG_DIR / "OCRNet_HRNet-W18.yaml",
  51. }
  52. )
  53. # PP-LiteSeg
  54. register_model_info(
  55. {
  56. "model_name": "PP-LiteSeg-T",
  57. "suite": "Seg",
  58. "config_path": osp.join(PDX_CONFIG_DIR, "PP-LiteSeg-T.yaml"),
  59. "supported_apis": ["train", "evaluate", "predict", "export"],
  60. "supported_train_opts": {
  61. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  62. "dy2st": True,
  63. "amp": ["O1", "O2"],
  64. },
  65. "supported_evaluate_opts": {
  66. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  67. "amp": [],
  68. },
  69. "supported_predict_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  70. "supported_infer_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  71. "supported_dataset_types": [],
  72. "hpi_config_path": HPI_CONFIG_DIR / "PP-LiteSeg-T.yaml",
  73. }
  74. )
  75. # PP-LiteSeg
  76. register_model_info(
  77. {
  78. "model_name": "PP-LiteSeg-B",
  79. "suite": "Seg",
  80. "config_path": osp.join(PDX_CONFIG_DIR, "PP-LiteSeg-B.yaml"),
  81. "supported_apis": ["train", "evaluate", "predict", "export"],
  82. "supported_train_opts": {
  83. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  84. "dy2st": True,
  85. "amp": ["O1", "O2"],
  86. },
  87. "supported_evaluate_opts": {
  88. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  89. "amp": [],
  90. },
  91. "supported_predict_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  92. "supported_infer_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  93. "supported_dataset_types": [],
  94. "hpi_config_path": HPI_CONFIG_DIR / "PP-LiteSeg-B.yaml",
  95. }
  96. )
  97. # seaformer
  98. register_model_info(
  99. {
  100. "model_name": "SeaFormer_base",
  101. "suite": "Seg",
  102. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_base.yaml"),
  103. "supported_apis": ["train", "evaluate", "predict", "export"],
  104. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_base.yaml",
  105. }
  106. )
  107. register_model_info(
  108. {
  109. "model_name": "SeaFormer_tiny",
  110. "suite": "Seg",
  111. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_tiny.yaml"),
  112. "supported_apis": ["train", "evaluate", "predict", "export"],
  113. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_tiny.yaml",
  114. }
  115. )
  116. register_model_info(
  117. {
  118. "model_name": "SeaFormer_small",
  119. "suite": "Seg",
  120. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_small.yaml"),
  121. "supported_apis": ["train", "evaluate", "predict", "export"],
  122. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_small.yaml",
  123. }
  124. )
  125. register_model_info(
  126. {
  127. "model_name": "SeaFormer_large",
  128. "suite": "Seg",
  129. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_large.yaml"),
  130. "supported_apis": ["train", "evaluate", "predict", "export"],
  131. "hpi_config_path": HPI_CONFIG_DIR / "SeaFormer_large.yaml",
  132. }
  133. )
  134. # SegFormer
  135. register_model_info(
  136. {
  137. "model_name": "SegFormer-B0",
  138. "suite": "Seg",
  139. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B0.yaml"),
  140. "supported_apis": ["train", "evaluate", "predict", "export"],
  141. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B0.yaml",
  142. }
  143. )
  144. register_model_info(
  145. {
  146. "model_name": "SegFormer-B1",
  147. "suite": "Seg",
  148. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B1.yaml"),
  149. "supported_apis": ["train", "evaluate", "predict", "export"],
  150. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B1.yaml",
  151. }
  152. )
  153. register_model_info(
  154. {
  155. "model_name": "SegFormer-B2",
  156. "suite": "Seg",
  157. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B2.yaml"),
  158. "supported_apis": ["train", "evaluate", "predict", "export"],
  159. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B2.yaml",
  160. }
  161. )
  162. register_model_info(
  163. {
  164. "model_name": "SegFormer-B3",
  165. "suite": "Seg",
  166. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B3.yaml"),
  167. "supported_apis": ["train", "evaluate", "predict", "export"],
  168. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B3.yaml",
  169. }
  170. )
  171. register_model_info(
  172. {
  173. "model_name": "SegFormer-B4",
  174. "suite": "Seg",
  175. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B4.yaml"),
  176. "supported_apis": ["train", "evaluate", "predict", "export"],
  177. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B4.yaml",
  178. }
  179. )
  180. register_model_info(
  181. {
  182. "model_name": "SegFormer-B5",
  183. "suite": "Seg",
  184. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B5.yaml"),
  185. "supported_apis": ["train", "evaluate", "predict", "export"],
  186. "hpi_config_path": HPI_CONFIG_DIR / "SegFormer-B5.yaml",
  187. }
  188. )
  189. # deeplab
  190. register_model_info(
  191. {
  192. "model_name": "Deeplabv3-R50",
  193. "suite": "Seg",
  194. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3-R50.yaml"),
  195. "supported_apis": ["train", "evaluate", "predict", "export"],
  196. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3-R50.yaml",
  197. }
  198. )
  199. register_model_info(
  200. {
  201. "model_name": "Deeplabv3-R101",
  202. "suite": "Seg",
  203. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3-R101.yaml"),
  204. "supported_apis": ["train", "evaluate", "predict", "export"],
  205. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3-R101.yaml",
  206. }
  207. )
  208. register_model_info(
  209. {
  210. "model_name": "Deeplabv3_Plus-R50",
  211. "suite": "Seg",
  212. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-R50.yaml"),
  213. "supported_apis": ["train", "evaluate", "predict", "export"],
  214. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3_Plus-R50.yaml",
  215. }
  216. )
  217. register_model_info(
  218. {
  219. "model_name": "Deeplabv3_Plus-R101",
  220. "suite": "Seg",
  221. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-R101.yaml"),
  222. "supported_apis": ["train", "evaluate", "predict", "export"],
  223. "hpi_config_path": HPI_CONFIG_DIR / "Deeplabv3_Plus-R101.yaml",
  224. }
  225. )
  226. register_model_info(
  227. {
  228. "model_name": "STFPM",
  229. "suite": "Seg",
  230. "config_path": osp.join(PDX_CONFIG_DIR, "STFPM.yaml"),
  231. "supported_apis": ["train", "evaluate", "predict", "export"],
  232. }
  233. )
  234. # For compatibility
  235. def _set_alias(model_name, alias):
  236. from ...base.register import get_registered_model_info
  237. record = get_registered_model_info(model_name)
  238. record = dict(**record)
  239. record["model_name"] = alias
  240. register_model_info(record)
  241. _set_alias("OCRNet_HRNet-W48", "ocrnet_hrnetw48")
  242. _set_alias("OCRNet_HRNet-W18", "ocrnet_hrnetw18")
  243. _set_alias("PP-LiteSeg-T", "pp_liteseg_stdc1")