# operate.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import traceback
import os.path as osp
import multiprocessing as mp

from .cls_dataset import ClsDataset
from .det_dataset import DetDataset
from .seg_dataset import SegDataset
from .ins_seg_dataset import InsSegDataset
from ..utils import (set_folder_status, get_folder_status, DatasetStatus,
                     DownloadStatus, download, list_files)

dataset_url_list = [
    'https://bj.bcebos.com/paddlex/demos/vegetables_cls.tar.gz',
    'https://bj.bcebos.com/paddlex/demos/insect_det.tar.gz',
    'https://bj.bcebos.com/paddlex/demos/optic_disc_seg.tar.gz',
    'https://bj.bcebos.com/paddlex/demos/xiaoduxiong_ins_det.tar.gz',
    'https://bj.bcebos.com/paddlex/demos/remote_sensing_seg.tar.gz'
]

dataset_url_dict = {
    'classification':
    'https://bj.bcebos.com/paddlex/demos/vegetables_cls.tar.gz',
    'detection':
    'https://bj.bcebos.com/paddlex/demos/insect_det.tar.gz',
    'segmentation':
    'https://bj.bcebos.com/paddlex/demos/optic_disc_seg.tar.gz',
    'instance_segmentation':
    'https://bj.bcebos.com/paddlex/demos/xiaoduxiong_ins_det.tar.gz'
}


def _check_and_copy(dataset, dataset_path, source_path):
    """Check the dataset at source_path and copy its files into dataset_path,
    recording progress and failures via the folder status."""
    try:
        dataset.check_dataset(source_path)
    except Exception:
        error_info = traceback.format_exc()
        set_folder_status(dataset_path, DatasetStatus.XCHECKFAIL, error_info)
        return
    set_folder_status(dataset_path, DatasetStatus.XCOPYING, os.getpid())
    try:
        dataset.copy_dataset(source_path, dataset.all_files)
    except Exception:
        error_info = traceback.format_exc()
        set_folder_status(dataset_path, DatasetStatus.XCOPYFAIL, error_info)
        return
    # If the uploaded dataset already comes pre-split, mark it as split.
    if len(dataset.train_files) != 0:
        set_folder_status(dataset_path, DatasetStatus.XSPLITED)


def import_dataset(dataset_id, dataset_type, dataset_path, source_path):
    """Start importing a dataset in a background process and return it."""
    if not osp.isdir(dataset_path):
        if osp.exists(dataset_path):
            os.remove(dataset_path)
        os.makedirs(dataset_path)
    set_folder_status(dataset_path, DatasetStatus.XCHECKING)
    if dataset_type == 'classification':
        ds = ClsDataset(dataset_id, dataset_path)
    elif dataset_type == 'detection':
        ds = DetDataset(dataset_id, dataset_path)
    elif dataset_type == 'segmentation':
        ds = SegDataset(dataset_id, dataset_path)
    elif dataset_type == 'instance_segmentation':
        ds = InsSegDataset(dataset_id, dataset_path)
    p = mp.Process(
        target=_check_and_copy, args=(ds, dataset_path, source_path))
    p.start()
    return p


def _download_proc(url, target_path, dataset_type):
    """Download a demo dataset archive into target_path/dataset_type and
    decompress it."""
    from paddlex.utils import decompress
    target_path = osp.join(target_path, dataset_type)
    # Download the dataset archive.
    fname = download(url, target_path)
    # Decompress the archive.
    decompress(fname)
    set_folder_status(target_path, DownloadStatus.XDDECOMPRESSED)


def download_demo_dataset(prj_type, target_path):
    """Download the demo dataset for the given project type in a background
    process and return the process."""
    url = dataset_url_list[prj_type.value]
    dataset_type = prj_type.name
    p = mp.Process(
        target=_download_proc, args=(url, target_path, dataset_type))
    p.start()
    return p


def get_dataset_status(dataset_id, dataset_type, dataset_path):
    """Return the current status of the dataset folder plus a message
    (a dict with 'pid' and 'percent' while files are being copied)."""
    status, message = get_folder_status(dataset_path, True)
    if status is None:
        status = DatasetStatus.XEMPTY
    if status == DatasetStatus.XCOPYING:
        items = message.strip().split()
        if len(items) < 2:
            pid = None
            percent = 0.0
        else:
            pid = int(items[0])
            if int(items[1]) == 0:
                percent = 1.0
            else:
                # Progress = files already in dataset_path (less one) over
                # the expected total recorded in the status message.
                copied_files_num = len(list_files(dataset_path)) - 1
                percent = copied_files_num * 1.0 / int(items[1])
        message = {'pid': pid, 'percent': percent}
    if status == DatasetStatus.XCOPYDONE or status == DatasetStatus.XSPLITED:
        # statis.pkl is missing, so re-check the dataset to regenerate it.
        if not osp.exists(osp.join(dataset_path, 'statis.pkl')):
            import_dataset(dataset_id, dataset_type, dataset_path,
                           dataset_path)
            status = DatasetStatus.XCHECKING
    return status, message


def split_dataset(dataset_id, dataset_type, dataset_path, val_split,
                  test_split):
    """Split an imported dataset into train/val/test subsets."""
    status, message = get_folder_status(dataset_path, True)
    if status != DatasetStatus.XCOPYDONE and status != DatasetStatus.XSPLITED:
        raise Exception("The dataset has not finished importing. Please wait "
                        "until the import succeeds before splitting it.")
    if not osp.exists(osp.join(dataset_path, 'statis.pkl')):
        raise Exception("The dataset needs to be re-checked. Please refresh "
                        "the dataset before splitting it.")
    if dataset_type == 'classification':
        ds = ClsDataset(dataset_id, dataset_path)
    elif dataset_type == 'detection':
        ds = DetDataset(dataset_id, dataset_path)
    elif dataset_type == 'segmentation':
        ds = SegDataset(dataset_id, dataset_path)
    elif dataset_type == 'instance_segmentation':
        ds = InsSegDataset(dataset_id, dataset_path)
    ds.load_statis_info()
    ds.split(val_split, test_split)
    set_folder_status(dataset_path, DatasetStatus.XSPLITED)


def get_dataset_details(dataset_path):
    """Return the cached dataset statistics (statis.pkl) once the dataset
    has been imported or split, otherwise None."""
    status, message = get_folder_status(dataset_path, True)
    if status == DatasetStatus.XCOPYDONE or status == DatasetStatus.XSPLITED:
        with open(osp.join(dataset_path, 'statis.pkl'), 'rb') as f:
            details = pickle.load(f)
        return details
    return None
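

# Illustrative usage sketch: the dataset id, type, and both paths below are
# hypothetical placeholders a caller would supply; these functions are
# typically driven by the PaddleX RESTful/GUI service layer rather than
# called directly.
#
#     proc = import_dataset(
#         dataset_id='1',
#         dataset_type='classification',
#         dataset_path='/path/to/workspace/datasets/1',  # hypothetical
#         source_path='/path/to/uploaded_dataset')       # hypothetical
#     proc.join()  # wait for the background check-and-copy to finish
#
#     status, message = get_dataset_status(
#         '1', 'classification', '/path/to/workspace/datasets/1')
#     if status in (DatasetStatus.XCOPYDONE, DatasetStatus.XSPLITED):
#         # hold out 20% of samples for validation and 10% for testing
#         split_dataset('1', 'classification',
#                       '/path/to/workspace/datasets/1', 0.2, 0.1)
#         details = get_dataset_details('/path/to/workspace/datasets/1')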