register.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. from ...base.register import register_model_info, register_suite_info
  17. from .model import SegModel
  18. from .runner import SegRunner
  19. from .config import SegConfig
  20. REPO_ROOT_PATH = os.environ.get('PADDLE_PDX_PADDLESEG_PATH')
  21. PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', 'configs'))
  22. register_suite_info({
  23. 'suite_name': 'Seg',
  24. 'model': SegModel,
  25. 'runner': SegRunner,
  26. 'config': SegConfig,
  27. 'runner_root_path': REPO_ROOT_PATH
  28. })
  29. ################ Models Using Universal Config ################
  30. # OCRNet
  31. register_model_info({
  32. 'model_name': 'OCRNet_HRNet-W48',
  33. 'suite': 'Seg',
  34. 'config_path': osp.join(PDX_CONFIG_DIR, 'OCRNet_HRNet-W48.yaml'),
  35. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  36. })
  37. register_model_info({
  38. 'model_name': 'OCRNet_HRNet-W18',
  39. 'suite': 'Seg',
  40. 'config_path': osp.join(PDX_CONFIG_DIR, 'OCRNet_HRNet-W18.yaml'),
  41. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  42. })
  43. # PP-LiteSeg
  44. register_model_info({
  45. 'model_name': 'PP-LiteSeg-T',
  46. 'suite': 'Seg',
  47. 'config_path': osp.join(PDX_CONFIG_DIR, 'PP-LiteSeg-T.yaml'),
  48. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer'],
  49. 'supported_train_opts': {
  50. 'device': ['cpu', 'gpu_nxcx', 'xpu', 'npu', 'mlu'],
  51. 'dy2st': True,
  52. 'amp': ['O1', 'O2']
  53. },
  54. 'supported_evaluate_opts': {
  55. 'device': ['cpu', 'gpu_nxcx', 'xpu', 'npu', 'mlu'],
  56. 'amp': []
  57. },
  58. 'supported_predict_opts': {
  59. 'device': ['cpu', 'gpu', 'xpu', 'npu', 'mlu']
  60. },
  61. 'supported_infer_opts': {
  62. 'device': ['cpu', 'gpu', 'xpu', 'npu', 'mlu']
  63. },
  64. 'supported_dataset_types': []
  65. })
  66. # seaformer
  67. register_model_info({
  68. 'model_name': 'SeaFormer_base',
  69. 'suite': 'Seg',
  70. 'config_path': osp.join(PDX_CONFIG_DIR, 'SeaFormer_base.yaml'),
  71. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  72. })
  73. register_model_info({
  74. 'model_name': 'SeaFormer_tiny',
  75. 'suite': 'Seg',
  76. 'config_path': osp.join(PDX_CONFIG_DIR, 'SeaFormer_tiny.yaml'),
  77. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  78. })
  79. register_model_info({
  80. 'model_name': 'SeaFormer_small',
  81. 'suite': 'Seg',
  82. 'config_path': osp.join(PDX_CONFIG_DIR, 'SeaFormer_small.yaml'),
  83. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  84. })
  85. register_model_info({
  86. 'model_name': 'SeaFormer_large',
  87. 'suite': 'Seg',
  88. 'config_path': osp.join(PDX_CONFIG_DIR, 'SeaFormer_large.yaml'),
  89. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  90. })
  91. # SegFormer
  92. register_model_info({
  93. 'model_name': 'SegFormer-B0',
  94. 'suite': 'Seg',
  95. 'config_path': osp.join(PDX_CONFIG_DIR, 'SegFormer-B0.yaml'),
  96. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  97. })
  98. register_model_info({
  99. 'model_name': 'SegFormer-B1',
  100. 'suite': 'Seg',
  101. 'config_path': osp.join(PDX_CONFIG_DIR, 'SegFormer-B1.yaml'),
  102. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  103. })
  104. register_model_info({
  105. 'model_name': 'SegFormer-B2',
  106. 'suite': 'Seg',
  107. 'config_path': osp.join(PDX_CONFIG_DIR, 'SegFormer-B2.yaml'),
  108. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  109. })
  110. register_model_info({
  111. 'model_name': 'SegFormer-B3',
  112. 'suite': 'Seg',
  113. 'config_path': osp.join(PDX_CONFIG_DIR, 'SegFormer-B3.yaml'),
  114. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  115. })
  116. register_model_info({
  117. 'model_name': 'SegFormer-B4',
  118. 'suite': 'Seg',
  119. 'config_path': osp.join(PDX_CONFIG_DIR, 'SegFormer-B4.yaml'),
  120. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  121. })
  122. register_model_info({
  123. 'model_name': 'SegFormer-B5',
  124. 'suite': 'Seg',
  125. 'config_path': osp.join(PDX_CONFIG_DIR, 'SegFormer-B5.yaml'),
  126. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  127. })
  128. # deeplab
  129. register_model_info({
  130. 'model_name': 'Deeplabv3-R50',
  131. 'suite': 'Seg',
  132. 'config_path': osp.join(PDX_CONFIG_DIR, 'Deeplabv3-R50.yaml'),
  133. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  134. })
  135. register_model_info({
  136. 'model_name': 'Deeplabv3-R101',
  137. 'suite': 'Seg',
  138. 'config_path': osp.join(PDX_CONFIG_DIR, 'Deeplabv3-R101.yaml'),
  139. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  140. })
  141. register_model_info({
  142. 'model_name': 'Deeplabv3_Plus-R50',
  143. 'suite': 'Seg',
  144. 'config_path': osp.join(PDX_CONFIG_DIR, 'Deeplabv3_Plus-R50.yaml'),
  145. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  146. })
  147. register_model_info({
  148. 'model_name': 'Deeplabv3_Plus-R101',
  149. 'suite': 'Seg',
  150. 'config_path': osp.join(PDX_CONFIG_DIR, 'Deeplabv3_Plus-R101.yaml'),
  151. 'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
  152. })
  153. # For compatibility
  154. def _set_alias(model_name, alias):
  155. from ...base.register import get_registered_model_info
  156. record = get_registered_model_info(model_name)
  157. record = dict(**record)
  158. record['model_name'] = alias
  159. register_model_info(record)
  160. _set_alias('OCRNet_HRNet-W48', 'ocrnet_hrnetw48')
  161. _set_alias('OCRNet_HRNet-W18', 'ocrnet_hrnetw18')
  162. _set_alias('PP-LiteSeg-T', 'pp_liteseg_stdc1')