pretrain_weights.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import paddlex
  2. import paddlex.utils.logging as logging
  3. import paddlehub as hub
  4. import os
  5. import os.path as osp
  6. image_pretrain = {
  7. 'ResNet18':
  8. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_pretrained.tar',
  9. 'ResNet34':
  10. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar',
  11. 'ResNet50':
  12. 'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
  13. 'ResNet101':
  14. 'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar',
  15. 'ResNet50_vd':
  16. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
  17. 'ResNet101_vd':
  18. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
  19. 'ResNet50_vd_ssld':
  20. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
  21. 'ResNet101_vd_ssld':
  22. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
  23. 'MobileNetV1':
  24. 'http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
  25. 'MobileNetV2_x1.0':
  26. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar',
  27. 'MobileNetV2_x0.5':
  28. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar',
  29. 'MobileNetV2_x2.0':
  30. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x2_0_pretrained.tar',
  31. 'MobileNetV2_x0.25':
  32. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar',
  33. 'MobileNetV2_x1.5':
  34. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x1_5_pretrained.tar',
  35. 'MobileNetV3_small':
  36. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
  37. 'MobileNetV3_large':
  38. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
  39. 'MobileNetV3_small_x1_0_ssld':
  40. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
  41. 'MobileNetV3_large_x1_0_ssld':
  42. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
  43. 'DarkNet53':
  44. 'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
  45. 'DenseNet121':
  46. 'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar',
  47. 'DenseNet161':
  48. 'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar',
  49. 'DenseNet201':
  50. 'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar',
  51. 'DetResNet50':
  52. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar',
  53. 'SegXception41':
  54. 'https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_deeplab_pretrained.tar',
  55. 'SegXception65':
  56. 'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
  57. 'ShuffleNetV2':
  58. 'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
  59. 'HRNet_W18':
  60. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
  61. 'HRNet_W30':
  62. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
  63. 'HRNet_W32':
  64. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
  65. 'HRNet_W40':
  66. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
  67. 'HRNet_W44':
  68. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W44_C_pretrained.tar',
  69. 'HRNet_W48':
  70. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
  71. 'HRNet_W60':
  72. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
  73. 'HRNet_W64':
  74. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
  75. 'AlexNet':
  76. 'http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar'
  77. }
  78. coco_pretrain = {
  79. 'YOLOv3_DarkNet53_COCO':
  80. 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar',
  81. 'YOLOv3_MobileNetV1_COCO':
  82. 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar',
  83. 'YOLOv3_MobileNetV3_large_COCO':
  84. 'https://bj.bcebos.com/paddlex/models/yolov3_mobilenet_v3.tar',
  85. 'YOLOv3_ResNet34_COCO':
  86. 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar',
  87. 'YOLOv3_ResNet50_vd_COCO':
  88. 'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar',
  89. 'FasterRCNN_ResNet50_COCO':
  90. 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar',
  91. 'FasterRCNN_ResNet50_vd_COCO':
  92. 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar',
  93. 'FasterRCNN_ResNet101_COCO':
  94. 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
  95. 'FasterRCNN_ResNet101_vd_COCO':
  96. 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar',
  97. 'FasterRCNN_HRNet_W18_COCO':
  98. 'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar',
  99. 'MaskRCNN_ResNet50_COCO':
  100. 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar',
  101. 'MaskRCNN_ResNet50_vd_COCO':
  102. 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar',
  103. 'MaskRCNN_ResNet101_COCO':
  104. 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_fpn_1x.tar',
  105. 'MaskRCNN_ResNet101_vd_COCO':
  106. 'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar',
  107. 'UNet_COCO': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz',
  108. 'DeepLabv3p_MobileNetV2_x1.0_COCO':
  109. 'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz',
  110. 'DeepLabv3p_Xception65_COCO':
  111. 'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz'
  112. }
  113. cityscapes_pretrain = {
  114. 'DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES':
  115. 'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz',
  116. 'DeepLabv3p_Xception65_CITYSCAPES':
  117. 'https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz',
  118. 'HRNet_W18_CITYSCAPES':
  119. 'https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz',
  120. 'FastSCNN_CITYSCAPES':
  121. 'https://paddleseg.bj.bcebos.com/models/fast_scnn_cityscape.tar'
  122. }
  123. def get_pretrain_weights(flag, class_name, backbone, save_dir):
  124. if flag is None:
  125. return None
  126. elif osp.isdir(flag):
  127. return flag
  128. elif osp.isfile(flag):
  129. return flag
  130. warning_info = "{} does not support to be finetuned with weights pretrained on the {} dataset, so pretrain_weights is forced to be set to {}"
  131. if flag == 'COCO':
  132. if class_name == "FasterRCNN" and backbone in ['ResNet18'] or \
  133. class_name == "MaskRCNN" and backbone in ['ResNet18', 'HRNet_W18'] or \
  134. class_name == 'DeepLabv3p' and backbone in ['Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']:
  135. model_name = '{}_{}'.format(class_name, backbone)
  136. logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
  137. flag = 'IMAGENET'
  138. elif class_name == 'HRNet':
  139. logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
  140. flag = 'IMAGENET'
  141. elif class_name == 'FastSCNN':
  142. logging.warning(
  143. warning_info.format(class_name, flag, 'CITYSCAPES'))
  144. flag = 'CITYSCAPES'
  145. elif flag == 'CITYSCAPES':
  146. model_name = '{}_{}'.format(class_name, backbone)
  147. if class_name == 'UNet':
  148. logging.warning(warning_info.format(class_name, flag, 'COCO'))
  149. flag = 'COCO'
  150. if class_name == 'HRNet' and backbone.split('_')[
  151. -1] in ['W30', 'W32', 'W40', 'W48', 'W60', 'W64']:
  152. logging.warning(warning_info.format(backbone, flag, 'IMAGENET'))
  153. flag = 'IMAGENET'
  154. if class_name == 'DeepLabv3p' and backbone in [
  155. 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
  156. 'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
  157. ]:
  158. model_name = '{}_{}'.format(class_name, backbone)
  159. logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
  160. flag = 'IMAGENET'
  161. elif flag == 'IMAGENET':
  162. if class_name == 'UNet':
  163. logging.warning(warning_info.format(class_name, flag, 'COCO'))
  164. flag = 'COCO'
  165. elif class_name == 'FastSCNN':
  166. logging.warning(
  167. warning_info.format(class_name, flag, 'CITYSCAPES'))
  168. flag = 'CITYSCAPES'
  169. if flag == 'IMAGENET':
  170. new_save_dir = save_dir
  171. if hasattr(paddlex, 'pretrain_dir'):
  172. new_save_dir = paddlex.pretrain_dir
  173. if backbone.startswith('Xception'):
  174. backbone = 'Seg{}'.format(backbone)
  175. elif backbone == 'MobileNetV2':
  176. backbone = 'MobileNetV2_x1.0'
  177. elif backbone == 'MobileNetV3_small_ssld':
  178. backbone = 'MobileNetV3_small_x1_0_ssld'
  179. elif backbone == 'MobileNetV3_large_ssld':
  180. backbone = 'MobileNetV3_large_x1_0_ssld'
  181. if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
  182. if backbone == 'ResNet50':
  183. backbone = 'DetResNet50'
  184. assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
  185. backbone)
  186. # if backbone == 'AlexNet':
  187. # url = image_pretrain[backbone]
  188. # fname = osp.split(url)[-1].split('.')[0]
  189. # paddlex.utils.download_and_decompress(url, path=new_save_dir)
  190. # return osp.join(new_save_dir, fname)
  191. try:
  192. hub.download(backbone, save_path=new_save_dir)
  193. except Exception as e:
  194. if isinstance(e, hub.ResourceNotFoundError):
  195. raise Exception("Resource for backbone {} not found".format(
  196. backbone))
  197. elif isinstance(e, hub.ServerConnectionError):
  198. raise Exception(
  199. "Cannot get reource for backbone {}, please check your internet connection"
  200. .format(backbone))
  201. else:
  202. raise Exception(
  203. "Unexpected error, please make sure paddlehub >= 1.6.2")
  204. return osp.join(new_save_dir, backbone)
  205. elif flag in ['COCO', 'CITYSCAPES']:
  206. new_save_dir = save_dir
  207. if hasattr(paddlex, 'pretrain_dir'):
  208. new_save_dir = paddlex.pretrain_dir
  209. if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p']:
  210. backbone = '{}_{}'.format(class_name, backbone)
  211. backbone = "{}_{}".format(backbone, flag)
  212. if flag == 'COCO':
  213. url = coco_pretrain[backbone]
  214. elif flag == 'CITYSCAPES':
  215. url = cityscapes_pretrain[backbone]
  216. fname = osp.split(url)[-1].split('.')[0]
  217. # paddlex.utils.download_and_decompress(url, path=new_save_dir)
  218. # return osp.join(new_save_dir, fname)
  219. try:
  220. hub.download(backbone, save_path=new_save_dir)
  221. except Exception as e:
  222. if isinstance(hub.ResourceNotFoundError):
  223. raise Exception("Resource for backbone {} not found".format(
  224. backbone))
  225. elif isinstance(hub.ServerConnectionError):
  226. raise Exception(
  227. "Cannot get reource for backbone {}, please check your internet connection"
  228. .format(backbone))
  229. else:
  230. raise Exception(
  231. "Unexpected error, please make sure paddlehub >= 1.6.2")
  232. return osp.join(new_save_dir, backbone)
  233. else:
  234. raise Exception(
  235. "pretrain_weights need to be defined as directory path or 'IMAGENET' or 'COCO' or 'Cityscapes' (download pretrain weights automatically)."
  236. )