register.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. from pathlib import Path
  17. from ...base.register import register_model_info, register_suite_info
  18. from .model import SegModel
  19. from .runner import SegRunner
  20. from .config import SegConfig
  21. REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLESEG_PATH")
  22. PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
  23. register_suite_info(
  24. {
  25. "suite_name": "Seg",
  26. "model": SegModel,
  27. "runner": SegRunner,
  28. "config": SegConfig,
  29. "runner_root_path": REPO_ROOT_PATH,
  30. }
  31. )
  32. ################ Models Using Universal Config ################
  33. # OCRNet
  34. register_model_info(
  35. {
  36. "model_name": "OCRNet_HRNet-W48",
  37. "suite": "Seg",
  38. "config_path": osp.join(PDX_CONFIG_DIR, "OCRNet_HRNet-W48.yaml"),
  39. "supported_apis": ["train", "evaluate", "predict", "export"],
  40. }
  41. )
  42. register_model_info(
  43. {
  44. "model_name": "OCRNet_HRNet-W18",
  45. "suite": "Seg",
  46. "config_path": osp.join(PDX_CONFIG_DIR, "OCRNet_HRNet-W18.yaml"),
  47. "supported_apis": ["train", "evaluate", "predict", "export"],
  48. }
  49. )
  50. # PP-LiteSeg
  51. register_model_info(
  52. {
  53. "model_name": "PP-LiteSeg-T",
  54. "suite": "Seg",
  55. "config_path": osp.join(PDX_CONFIG_DIR, "PP-LiteSeg-T.yaml"),
  56. "supported_apis": ["train", "evaluate", "predict", "export"],
  57. "supported_train_opts": {
  58. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  59. "dy2st": True,
  60. "amp": ["O1", "O2"],
  61. },
  62. "supported_evaluate_opts": {
  63. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  64. "amp": [],
  65. },
  66. "supported_predict_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  67. "supported_infer_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  68. "supported_dataset_types": [],
  69. }
  70. )
  71. # PP-LiteSeg
  72. register_model_info(
  73. {
  74. "model_name": "PP-LiteSeg-B",
  75. "suite": "Seg",
  76. "config_path": osp.join(PDX_CONFIG_DIR, "PP-LiteSeg-B.yaml"),
  77. "supported_apis": ["train", "evaluate", "predict", "export"],
  78. "supported_train_opts": {
  79. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  80. "dy2st": True,
  81. "amp": ["O1", "O2"],
  82. },
  83. "supported_evaluate_opts": {
  84. "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
  85. "amp": [],
  86. },
  87. "supported_predict_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  88. "supported_infer_opts": {"device": ["cpu", "gpu", "xpu", "npu", "mlu"]},
  89. "supported_dataset_types": [],
  90. }
  91. )
  92. # seaformer
  93. register_model_info(
  94. {
  95. "model_name": "SeaFormer_base",
  96. "suite": "Seg",
  97. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_base.yaml"),
  98. "supported_apis": ["train", "evaluate", "predict", "export"],
  99. }
  100. )
  101. register_model_info(
  102. {
  103. "model_name": "SeaFormer_tiny",
  104. "suite": "Seg",
  105. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_tiny.yaml"),
  106. "supported_apis": ["train", "evaluate", "predict", "export"],
  107. }
  108. )
  109. register_model_info(
  110. {
  111. "model_name": "SeaFormer_small",
  112. "suite": "Seg",
  113. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_small.yaml"),
  114. "supported_apis": ["train", "evaluate", "predict", "export"],
  115. }
  116. )
  117. register_model_info(
  118. {
  119. "model_name": "SeaFormer_large",
  120. "suite": "Seg",
  121. "config_path": osp.join(PDX_CONFIG_DIR, "SeaFormer_large.yaml"),
  122. "supported_apis": ["train", "evaluate", "predict", "export"],
  123. }
  124. )
  125. # SegFormer
  126. register_model_info(
  127. {
  128. "model_name": "SegFormer-B0",
  129. "suite": "Seg",
  130. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B0.yaml"),
  131. "supported_apis": ["train", "evaluate", "predict", "export"],
  132. }
  133. )
  134. register_model_info(
  135. {
  136. "model_name": "SegFormer-B1",
  137. "suite": "Seg",
  138. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B1.yaml"),
  139. "supported_apis": ["train", "evaluate", "predict", "export"],
  140. }
  141. )
  142. register_model_info(
  143. {
  144. "model_name": "SegFormer-B2",
  145. "suite": "Seg",
  146. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B2.yaml"),
  147. "supported_apis": ["train", "evaluate", "predict", "export"],
  148. }
  149. )
  150. register_model_info(
  151. {
  152. "model_name": "SegFormer-B3",
  153. "suite": "Seg",
  154. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B3.yaml"),
  155. "supported_apis": ["train", "evaluate", "predict", "export"],
  156. }
  157. )
  158. register_model_info(
  159. {
  160. "model_name": "SegFormer-B4",
  161. "suite": "Seg",
  162. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B4.yaml"),
  163. "supported_apis": ["train", "evaluate", "predict", "export"],
  164. }
  165. )
  166. register_model_info(
  167. {
  168. "model_name": "SegFormer-B5",
  169. "suite": "Seg",
  170. "config_path": osp.join(PDX_CONFIG_DIR, "SegFormer-B5.yaml"),
  171. "supported_apis": ["train", "evaluate", "predict", "export"],
  172. }
  173. )
  174. # deeplab
  175. register_model_info(
  176. {
  177. "model_name": "Deeplabv3-R50",
  178. "suite": "Seg",
  179. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3-R50.yaml"),
  180. "supported_apis": ["train", "evaluate", "predict", "export"],
  181. }
  182. )
  183. register_model_info(
  184. {
  185. "model_name": "Deeplabv3-R101",
  186. "suite": "Seg",
  187. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3-R101.yaml"),
  188. "supported_apis": ["train", "evaluate", "predict", "export"],
  189. }
  190. )
  191. register_model_info(
  192. {
  193. "model_name": "Deeplabv3_Plus-R50",
  194. "suite": "Seg",
  195. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-R50.yaml"),
  196. "supported_apis": ["train", "evaluate", "predict", "export"],
  197. }
  198. )
  199. register_model_info(
  200. {
  201. "model_name": "Deeplabv3_Plus-R101",
  202. "suite": "Seg",
  203. "config_path": osp.join(PDX_CONFIG_DIR, "Deeplabv3_Plus-R101.yaml"),
  204. "supported_apis": ["train", "evaluate", "predict", "export"],
  205. }
  206. )
  207. register_model_info(
  208. {
  209. "model_name": "STFPM",
  210. "suite": "Seg",
  211. "config_path": osp.join(PDX_CONFIG_DIR, "STFPM.yaml"),
  212. "supported_apis": ["train", "evaluate", "predict", "export"],
  213. }
  214. )
  215. # For compatibility
  216. def _set_alias(model_name, alias):
  217. from ...base.register import get_registered_model_info
  218. record = get_registered_model_info(model_name)
  219. record = dict(**record)
  220. record["model_name"] = alias
  221. register_model_info(record)
  222. _set_alias("OCRNet_HRNet-W48", "ocrnet_hrnetw48")
  223. _set_alias("OCRNet_HRNet-W18", "ocrnet_hrnetw18")
  224. _set_alias("PP-LiteSeg-T", "pp_liteseg_stdc1")