pretrain_weights.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. import paddlex
  2. import paddlex.utils.logging as logging
  3. import os
  4. import os.path as osp
  # Maps a backbone name to the URL of its ImageNet-pretrained weights archive.
  # get_pretrain_weights() looks entries up by (possibly aliased) backbone name
  # when pretrain_weights='IMAGENET'. All entries live on the same host; they
  # are all fetched over HTTPS so the downloaded weights cannot be tampered
  # with in transit (four entries previously used plain http://).
  image_pretrain = {
      'ResNet18':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_pretrained.tar',
      'ResNet34':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar',
      'ResNet50':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
      'ResNet101':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar',
      'ResNet50_vd':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
      'ResNet101_vd':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
      'ResNet50_vd_ssld':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
      'ResNet101_vd_ssld':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
      'MobileNetV1':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
      'MobileNetV2_x1.0':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar',
      'MobileNetV2_x0.5':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar',
      'MobileNetV2_x2.0':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x2_0_pretrained.tar',
      'MobileNetV2_x0.25':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar',
      'MobileNetV2_x1.5':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x1_5_pretrained.tar',
      'MobileNetV3_small':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
      'MobileNetV3_large':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
      'MobileNetV3_small_x1_0_ssld':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
      'MobileNetV3_large_x1_0_ssld':
      'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
      'DarkNet53':
      'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
      'DenseNet121':
      'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar',
      'DenseNet161':
      'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar',
      'DenseNet201':
      'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar',
      # Used instead of 'ResNet50' for the detectors YOLOv3/FasterRCNN/MaskRCNN.
      'DetResNet50':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar',
      # Xception backbones are looked up with a 'Seg' prefix.
      'SegXception41':
      'https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_deeplab_pretrained.tar',
      'SegXception65':
      'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
      'ShuffleNetV2':
      'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
      'HRNet_W18':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
      'HRNet_W30':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
      'HRNet_W32':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
      'HRNet_W40':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
      'HRNet_W44':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W44_C_pretrained.tar',
      'HRNet_W48':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
      'HRNet_W60':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
      'HRNet_W64':
      'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
      'AlexNet':
      'https://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar',
  }
  # Weights pretrained on the BAIDU10W dataset, keyed by '<backbone>_BAIDU10W'.
  # Only the ResNet50_vd classifier has such weights.
  baidu10w_pretrain = dict(
      ResNet50_vd_BAIDU10W=(
          'https://paddle-imagenet-models-name.bj.bcebos.com/'
          'ResNet50_vd_10w_pretrained.tar'),
  )
  # COCO-pretrained weights. Keys are '<ClassName>_<backbone>_COCO' for the
  # YOLOv3/FasterRCNN/MaskRCNN/DeepLabv3p/PPYOLO models and '<backbone>_COCO'
  # otherwise (e.g. 'UNet_COCO'), matching how get_pretrain_weights builds them.
  coco_pretrain = dict([
      ('YOLOv3_DarkNet53_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar'),
      ('YOLOv3_MobileNetV1_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar'),
      ('YOLOv3_MobileNetV3_large_COCO',
       'https://bj.bcebos.com/paddlex/models/yolov3_mobilenet_v3.tar'),
      ('YOLOv3_ResNet34_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar'),
      ('YOLOv3_ResNet50_vd_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar'),
      ('FasterRCNN_ResNet18_COCO',
       'https://bj.bcebos.com/paddlex/pretrained_weights/faster_rcnn_r18_fpn_1x.tar'),
      ('FasterRCNN_ResNet50_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar'),
      ('FasterRCNN_ResNet50_vd_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar'),
      ('FasterRCNN_ResNet101_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar'),
      ('FasterRCNN_ResNet101_vd_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar'),
      ('FasterRCNN_HRNet_W18_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar'),
      ('MaskRCNN_ResNet18_COCO',
       'https://bj.bcebos.com/paddlex/pretrained_weights/mask_rcnn_r18_fpn_1x.tar'),
      ('MaskRCNN_ResNet50_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar'),
      ('MaskRCNN_ResNet50_vd_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar'),
      ('MaskRCNN_ResNet101_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_fpn_1x.tar'),
      ('MaskRCNN_ResNet101_vd_COCO',
       'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar'),
      ('MaskRCNN_HRNet_W18_COCO',
       'https://bj.bcebos.com/paddlex/pretrained_weights/mask_rcnn_hrnetv2p_w18_2x.tar'),
      ('UNet_COCO',
       'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz'),
      ('DeepLabv3p_MobileNetV2_x1.0_COCO',
       'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz'),
      ('DeepLabv3p_Xception65_COCO',
       'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz'),
      ('PPYOLO_ResNet50_vd_ssld_COCO',
       'https://bj.bcebos.com/paddlex/models/ppyolo_resnet50_vd_ssld.tar'),
  ])
  # Cityscapes-pretrained segmentation weights, keyed the same way as
  # coco_pretrain but with a '_CITYSCAPES' suffix.
  cityscapes_pretrain = dict([
      ('DeepLabv3p_MobileNetV3_large_x1_0_ssld_CITYSCAPES',
       'https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz'),
      ('DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES',
       'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz'),
      ('DeepLabv3p_Xception65_CITYSCAPES',
       'https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz'),
      ('HRNet_W18_CITYSCAPES',
       'https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz'),
      ('FastSCNN_CITYSCAPES',
       'https://paddleseg.bj.bcebos.com/models/fast_scnn_cityscape.tar'),
  ])
  def _resolve_save_dir(save_dir):
      """Return the download directory.

      A ``paddlex.pretrain_dir`` attribute, when present, overrides ``save_dir``.
      """
      return getattr(paddlex, 'pretrain_dir', save_dir)


  def _find_extracted_dir(save_dir, fname):
      """Locate the directory an archive whose stem is ``fname`` was unpacked into.

      Returns ``save_dir/fname`` when it exists. Otherwise returns the first
      sub-directory of ``save_dir`` (in sorted order) whose *name* contains the
      part of ``fname`` before its first underscore. Returns ``None`` when
      nothing matches.
      """
      expected = osp.join(save_dir, fname)
      if osp.exists(expected):
          return expected
      prefix = fname.split('_')[0]
      # Match the entry name, not the joined path: matching the full path
      # would accept *every* sub-directory whenever save_dir itself happens to
      # contain the prefix. Sorting makes the pick independent of listdir order.
      for entry in sorted(os.listdir(save_dir)):
          candidate = osp.join(save_dir, entry)
          if osp.isdir(candidate) and prefix in entry:
              return candidate
      return None


  def _fetch_weights(key, url, save_dir):
      """Download the weights named ``key`` (archive at ``url``) into ``save_dir``.

      In GUI mode the archive at ``url`` is downloaded and decompressed
      directly. If no extracted directory can be located, this falls back to
      PaddleHub, as before. Otherwise the weights are fetched with
      ``paddlehub.download(key, ...)``.

      Returns:
          The local directory holding the weights.

      Raises:
          Exception: if the PaddleHub download fails. The original error is
              chained as ``__cause__``.
      """
      if getattr(paddlex, 'gui_mode', False):
          fname = osp.split(url)[-1].split('.')[0]
          paddlex.utils.download_and_decompress(url, path=save_dir)
          found = _find_extracted_dir(save_dir, fname)
          if found is not None:
              return found
          # Nothing recognisable was extracted: fall through to PaddleHub.
      import paddlehub as hub
      logging.info("Connecting PaddleHub server to get pretrain weights...")
      try:
          hub.download(key, save_path=save_dir)
      except Exception as e:
          logging.error(
              "Couldn't download pretrain weight, you can download it manually "
              "from {} (decompress the file if it is a compressed file), and "
              "set pretrain weights by yourself".format(url),
              exit=False)
          if isinstance(e, hub.ResourceNotFoundError):
              raise Exception(
                  "Resource for backbone {} not found".format(key)) from e
          if isinstance(e, hub.ServerConnectionError):
              raise Exception(
                  "Cannot get resource for backbone {}, please check your "
                  "internet connection".format(key)) from e
          raise Exception(
              "Unexpected error, please make sure paddlehub >= 1.6.2") from e
      return osp.join(save_dir, key)


  def _normalize_flag(flag, class_name, backbone):
      """Redirect ``flag`` to a dataset that has weights for this model.

      A warning is logged whenever the requested dataset is replaced.

      Raises:
          Exception: if 'BAIDU10W' is requested for anything other than the
              ResNet50_vd classifier.
      """
      warning_info = (
          "{} does not support to be finetuned with weights pretrained on the "
          "{} dataset, so pretrain_weights is forced to be set to {}")
      if flag == 'COCO':
          if class_name == 'DeepLabv3p' and backbone in [
                  'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
                  'MobileNetV2_x1.5', 'MobileNetV2_x2.0',
                  'MobileNetV3_large_x1_0_ssld'
          ]:
              model_name = '{}_{}'.format(class_name, backbone)
              logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
              flag = 'IMAGENET'
          elif class_name == 'HRNet':
              logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
              flag = 'IMAGENET'
          elif class_name == 'FastSCNN':
              logging.warning(
                  warning_info.format(class_name, flag, 'CITYSCAPES'))
              flag = 'CITYSCAPES'
      elif flag == 'CITYSCAPES':
          if class_name == 'UNet':
              logging.warning(warning_info.format(class_name, flag, 'COCO'))
              flag = 'COCO'
          elif class_name == 'HRNet' and '{}_CITYSCAPES'.format(
                  backbone) not in cityscapes_pretrain:
              # Fall back for every width without Cityscapes weights. The old
              # hard-coded width list omitted W44, which then hit a KeyError.
              logging.warning(warning_info.format(backbone, flag, 'IMAGENET'))
              flag = 'IMAGENET'
          elif class_name == 'DeepLabv3p' and backbone in [
                  'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
                  'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
          ]:
              model_name = '{}_{}'.format(class_name, backbone)
              logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
              flag = 'IMAGENET'
      elif flag == 'IMAGENET':
          if class_name == 'UNet':
              logging.warning(warning_info.format(class_name, flag, 'COCO'))
              flag = 'COCO'
          elif class_name == 'FastSCNN':
              logging.warning(
                  warning_info.format(class_name, flag, 'CITYSCAPES'))
              flag = 'CITYSCAPES'
      elif flag == 'BAIDU10W':
          if class_name not in ['ResNet50_vd']:
              raise Exception(
                  "Only the classifier ResNet50_vd supports BAIDU10W pretrained weights"
              )
      return flag


  def get_pretrain_weights(flag, class_name, backbone, save_dir):
      """Resolve ``flag`` to a local directory of pretrained weights.

      Args:
          flag: ``None``, an existing file/directory path, or one of
              'IMAGENET', 'COCO', 'CITYSCAPES', 'BAIDU10W'.
          class_name: model class name, e.g. 'YOLOv3', 'DeepLabv3p', 'HRNet'.
          backbone: backbone name used to pick the weights entry.
          save_dir: download directory, unless ``paddlex.pretrain_dir`` is set.

      Returns:
          ``None`` when ``flag`` is ``None``; ``flag`` itself when it is an
          existing path; otherwise the directory the weights were downloaded
          into. For an unrecognised ``flag``, an error is logged via
          ``logging.error`` and ``None`` is returned — NOTE(review): whether
          that call terminates the process depends on
          ``paddlex.utils.logging.error``'s default ``exit`` (unverified here).

      Raises:
          AssertionError: no ImageNet weights exist for the backbone.
          KeyError: no COCO / CITYSCAPES / BAIDU10W entry for the model.
          Exception: unsupported BAIDU10W request or a failed download.
      """
      if flag is None:
          return None
      if osp.isdir(flag) or osp.isfile(flag):
          return flag

      flag = _normalize_flag(flag, class_name, backbone)

      if flag == 'IMAGENET':
          if backbone.startswith('Xception'):
              key = 'Seg{}'.format(backbone)
          else:
              key = {
                  'MobileNetV2': 'MobileNetV2_x1.0',
                  'MobileNetV3_small_ssld': 'MobileNetV3_small_x1_0_ssld',
                  'MobileNetV3_large_ssld': 'MobileNetV3_large_x1_0_ssld',
              }.get(backbone, backbone)
          if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN'] and \
                  key == 'ResNet50':
              key = 'DetResNet50'
          # Explicit raise instead of `assert` so the check survives `python -O`;
          # AssertionError is kept for callers that already catch it.
          if key not in image_pretrain:
              raise AssertionError(
                  "There are no ImageNet pretrain weights for {}, you may try COCO."
                  .format(key))
          url = image_pretrain[key]
      elif flag in ['COCO', 'CITYSCAPES']:
          key = backbone
          if class_name in [
                  'YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p', 'PPYOLO'
          ]:
              key = '{}_{}'.format(class_name, key)
          key = '{}_{}'.format(key, flag)
          table = coco_pretrain if flag == 'COCO' else cityscapes_pretrain
          url = table[key]
      elif flag == 'BAIDU10W':
          key = backbone + '_BAIDU10W'
          url = baidu10w_pretrain[key]
      else:
          logging.error("Path of pretrain weights '{}' does not exist!".format(
              flag))
          return None

      return _fetch_weights(key, url, _resolve_save_dir(save_dir))
  319. else:
  320. logging.error("Path of retrain weights '{}' is not exists!".format(
  321. flag))