pretrain_weights.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import paddlex
  2. import paddlehub as hub
  3. import os
  4. import os.path as osp
  5. image_pretrain = {
  6. 'ResNet18':
  7. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_pretrained.tar',
  8. 'ResNet34':
  9. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar',
  10. 'ResNet50':
  11. 'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
  12. 'ResNet101':
  13. 'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar',
  14. 'ResNet50_vd':
  15. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
  16. 'ResNet101_vd':
  17. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
  18. 'ResNet50_vd_ssld':
  19. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
  20. 'ResNet101_vd_ssld':
  21. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
  22. 'MobileNetV1':
  23. 'http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
  24. 'MobileNetV2_x1.0':
  25. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar',
  26. 'MobileNetV2_x0.5':
  27. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar',
  28. 'MobileNetV2_x2.0':
  29. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x2_0_pretrained.tar',
  30. 'MobileNetV2_x0.25':
  31. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar',
  32. 'MobileNetV2_x1.5':
  33. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x1_5_pretrained.tar',
  34. 'MobileNetV3_small':
  35. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
  36. 'MobileNetV3_large':
  37. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
  38. 'MobileNetV3_small_x1_0_ssld':
  39. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
  40. 'MobileNetV3_large_x1_0_ssld':
  41. 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
  42. 'DarkNet53':
  43. 'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
  44. 'DenseNet121':
  45. 'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar',
  46. 'DenseNet161':
  47. 'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar',
  48. 'DenseNet201':
  49. 'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar',
  50. 'DetResNet50':
  51. 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar',
  52. 'SegXception41':
  53. 'https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_deeplab_pretrained.tar',
  54. 'SegXception65':
  55. 'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
  56. 'ShuffleNetV2':
  57. 'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
  58. 'HRNet_W18':
  59. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
  60. 'HRNet_W30':
  61. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
  62. 'HRNet_W32':
  63. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
  64. 'HRNet_W40':
  65. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
  66. 'HRNet_W48':
  67. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
  68. 'HRNet_W60':
  69. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
  70. 'HRNet_W64':
  71. 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
  72. 'AlexNet':
  73. 'http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar'
  74. }
  75. coco_pretrain = {
  76. 'UNet': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz'
  77. }
  78. def get_pretrain_weights(flag, model_type, backbone, save_dir):
  79. if flag is None:
  80. return None
  81. elif osp.isdir(flag):
  82. return flag
  83. elif flag == 'IMAGENET':
  84. new_save_dir = save_dir
  85. if hasattr(paddlex, 'pretrain_dir'):
  86. new_save_dir = paddlex.pretrain_dir
  87. if backbone.startswith('Xception'):
  88. backbone = 'Seg{}'.format(backbone)
  89. elif backbone == 'MobileNetV2':
  90. backbone = 'MobileNetV2_x1.0'
  91. elif backbone == 'MobileNetV3_small_ssld':
  92. backbone = 'MobileNetV3_small_x1_0_ssld'
  93. elif backbone == 'MobileNetV3_large_ssld':
  94. backbone = 'MobileNetV3_large_x1_0_ssld'
  95. if model_type == 'detector':
  96. if backbone == 'ResNet50':
  97. backbone = 'DetResNet50'
  98. assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
  99. backbone)
  100. if backbone == 'AlexNet':
  101. url = image_pretrain[backbone]
  102. fname = osp.split(url)[-1].split('.')[0]
  103. paddlex.utils.download_and_decompress(url, path=new_save_dir)
  104. return osp.join(new_save_dir, fname)
  105. try:
  106. hub.download(backbone, save_path=new_save_dir)
  107. except Exception as e:
  108. if isinstance(e, hub.ResourceNotFoundError):
  109. raise Exception("Resource for backbone {} not found".format(
  110. backbone))
  111. elif isinstance(e, hub.ServerConnectionError):
  112. raise Exception(
  113. "Cannot get reource for backbone {}, please check your internet connecgtion"
  114. .format(backbone))
  115. else:
  116. raise Exception(
  117. "Unexpected error, please make sure paddlehub >= 1.6.2")
  118. return osp.join(new_save_dir, backbone)
  119. elif flag == 'COCO':
  120. new_save_dir = save_dir
  121. if hasattr(paddlex, 'pretrain_dir'):
  122. new_save_dir = paddlex.pretrain_dir
  123. url = coco_pretrain[backbone]
  124. fname = osp.split(url)[-1].split('.')[0]
  125. # paddlex.utils.download_and_decompress(url, path=new_save_dir)
  126. # return osp.join(new_save_dir, fname)
  127. assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
  128. backbone)
  129. try:
  130. hub.download(backbone, save_path=new_save_dir)
  131. except Exception as e:
  132. if isinstance(hub.ResourceNotFoundError):
  133. raise Exception("Resource for backbone {} not found".format(
  134. backbone))
  135. elif isinstance(hub.ServerConnectionError):
  136. raise Exception(
  137. "Cannot get reource for backbone {}, please check your internet connecgtion"
  138. .format(backbone))
  139. else:
  140. raise Exception(
  141. "Unexpected error, please make sure paddlehub >= 1.6.2")
  142. return osp.join(new_save_dir, backbone)
  143. else:
  144. raise Exception(
  145. "pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
  146. )