operate.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. import traceback
  17. import os.path as osp
  18. import multiprocessing as mp
  19. from .cls_dataset import ClsDataset
  20. from .det_dataset import DetDataset
  21. from .seg_dataset import SegDataset
  22. from .ins_seg_dataset import InsSegDataset
  23. from ..utils import set_folder_status, get_folder_status, DatasetStatus, DownloadStatus, download, list_files
  24. dataset_url_list = [
  25. 'https://bj.bcebos.com/paddlex/demos/vegetables_cls.tar.gz',
  26. 'https://bj.bcebos.com/paddlex/demos/insect_det.tar.gz',
  27. 'https://bj.bcebos.com/paddlex/demos/optic_disc_seg.tar.gz',
  28. 'https://bj.bcebos.com/paddlex/demos/xiaoduxiong_ins_det.tar.gz',
  29. 'https://bj.bcebos.com/paddlex/demos/remote_sensing_seg.tar.gz'
  30. ]
  31. dataset_url_dict = {
  32. 'classification':
  33. 'https://bj.bcebos.com/paddlex/demos/vegetables_cls.tar.gz',
  34. 'detection': 'https://bj.bcebos.com/paddlex/demos/insect_det.tar.gz',
  35. 'segmentation':
  36. 'https://bj.bcebos.com/paddlex/demos/optic_disc_seg.tar.gz',
  37. 'instance_segmentation':
  38. 'https://bj.bcebos.com/paddlex/demos/xiaoduxiong_ins_det.tar.gz'
  39. }
  40. def _check_and_copy(dataset, dataset_path, source_path):
  41. try:
  42. dataset.check_dataset(source_path)
  43. except Exception as e:
  44. error_info = traceback.format_exc()
  45. set_folder_status(dataset_path, DatasetStatus.XCHECKFAIL, error_info)
  46. return
  47. set_folder_status(dataset_path, DatasetStatus.XCOPYING, os.getpid())
  48. try:
  49. dataset.copy_dataset(source_path, dataset.all_files)
  50. except Exception as e:
  51. error_info = traceback.format_exc()
  52. set_folder_status(dataset_path, DatasetStatus.XCOPYFAIL, error_info)
  53. return
  54. # 若上传已切分好的数据集
  55. if len(dataset.train_files) != 0:
  56. set_folder_status(dataset_path, DatasetStatus.XSPLITED)
  57. def import_dataset(dataset_id, dataset_type, dataset_path, source_path):
  58. set_folder_status(dataset_path, DatasetStatus.XCHECKING)
  59. if dataset_type == 'classification':
  60. ds = ClsDataset(dataset_id, dataset_path)
  61. elif dataset_type == 'detection':
  62. ds = DetDataset(dataset_id, dataset_path)
  63. elif dataset_type == 'segmentation':
  64. ds = SegDataset(dataset_id, dataset_path)
  65. elif dataset_type == 'instance_segmentation':
  66. ds = InsSegDataset(dataset_id, dataset_path)
  67. p = mp.Process(
  68. target=_check_and_copy, args=(ds, dataset_path, source_path))
  69. p.start()
  70. return p
  71. def _download_proc(url, target_path, dataset_type):
  72. # 下载数据集压缩包
  73. from paddlex.utils import decompress
  74. target_path = osp.join(target_path, dataset_type)
  75. fname = download(url, target_path)
  76. # 解压
  77. decompress(fname)
  78. set_folder_status(target_path, DownloadStatus.XDDECOMPRESSED)
  79. def download_demo_dataset(prj_type, target_path):
  80. url = dataset_url_list[prj_type.value]
  81. dataset_type = prj_type.name
  82. p = mp.Process(
  83. target=_download_proc, args=(url, target_path, dataset_type))
  84. p.start()
  85. return p
  86. def get_dataset_status(dataset_id, dataset_type, dataset_path):
  87. status, message = get_folder_status(dataset_path, True)
  88. if status is None:
  89. status = DatasetStatus.XEMPTY
  90. if status == DatasetStatus.XCOPYING:
  91. items = message.strip().split()
  92. pid = None
  93. if len(items) < 2:
  94. percent = 0.0
  95. else:
  96. pid = int(items[0])
  97. if int(items[1]) == 0:
  98. percent = 1.0
  99. else:
  100. copyed_files_num = len(list_files(dataset_path)) - 1
  101. percent = copyed_files_num * 1.0 / int(items[1])
  102. message = {'pid': pid, 'percent': percent}
  103. if status == DatasetStatus.XCOPYDONE or status == DatasetStatus.XSPLITED:
  104. if not osp.exists(osp.join(dataset_path, 'statis.pkl')):
  105. p = import_dataset(dataset_id, dataset_type, dataset_path,
  106. dataset_path)
  107. status = DatasetStatus.XCHECKING
  108. return status, message
  109. def split_dataset(dataset_id, dataset_type, dataset_path, val_split,
  110. test_split):
  111. status, message = get_folder_status(dataset_path, True)
  112. if status != DatasetStatus.XCOPYDONE and status != DatasetStatus.XSPLITED:
  113. raise Exception("数据集还未导入完成,请等数据集导入成功后再进行切分")
  114. if not osp.exists(osp.join(dataset_path, 'statis.pkl')):
  115. raise Exception("数据集需重新校验,请刷新数据集后再进行切分")
  116. if dataset_type == 'classification':
  117. ds = ClsDataset(dataset_id, dataset_path)
  118. elif dataset_type == 'detection':
  119. ds = DetDataset(dataset_id, dataset_path)
  120. elif dataset_type == 'segmentation':
  121. ds = SegDataset(dataset_id, dataset_path)
  122. elif dataset_type == 'instance_segmentation':
  123. ds = InsSegDataset(dataset_id, dataset_path)
  124. ds.load_statis_info()
  125. ds.split(val_split, test_split)
  126. set_folder_status(dataset_path, DatasetStatus.XSPLITED)
  127. def get_dataset_details(dataset_path):
  128. status, message = get_folder_status(dataset_path, True)
  129. if status == DatasetStatus.XCOPYDONE or status == DatasetStatus.XSPLITED:
  130. with open(osp.join(dataset_path, 'statis.pkl'), 'rb') as f:
  131. details = pickle.load(f)
  132. return details
  133. return None