dataset.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import set_folder_status, DatasetStatus, ProjectType
from .utils import get_label_count
import shutil
import os
import os.path as osp
import time
import base64
import cv2
from .. import workspace_pb2 as w

def create_dataset(data, workspace):
    """Create a dataset.

    Args:
        data (dict): keys include
            'name': dataset name, 'desc': dataset description,
            'dataset_type': dataset type
    """
    create_time = time.time()
    time_array = time.localtime(create_time)
    create_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
    # Dataset ids are sequential and padded to four digits below 10000.
    id = workspace.max_dataset_id + 1
    if id < 10000:
        did = 'D%04d' % id
    else:
        did = 'D{}'.format(id)
    assert did not in workspace.datasets, \
        "[Create dataset] ID '{}' is already taken.".format(did)
    path = osp.join(workspace.path, 'datasets', did)
    # Clear any leftover file or directory occupying the dataset path.
    if osp.exists(path):
        if not osp.isdir(path):
            os.remove(path)
        else:
            shutil.rmtree(path)
    os.makedirs(path)
    set_folder_status(path, DatasetStatus.XEMPTY)
    workspace.max_dataset_id = id
    ds = w.Dataset(
        id=did,
        name=data['name'],
        desc=data['desc'],
        type=data['dataset_type'],
        create_time=create_time,
        path=path)
    workspace.datasets[did].CopyFrom(ds)
    return {'status': 1, 'did': did}
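
# A minimal usage sketch (hypothetical caller code, not part of this module;
# assumes `workspace` is an already-loaded workspace_pb2.Workspace with
# `path` and `max_dataset_id` populated):
#
#     resp = create_dataset({'name': 'fruits', 'desc': 'demo fruit images',
#                            'dataset_type': 'classification'}, workspace)
#     resp['did']  # e.g. 'D0001'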

def import_dataset(data, workspace, monitored_processes, load_demo_proc_dict):
    """Import a dataset into the workspace, including validation and copying.

    Args:
        data (dict): keys include
            'did': dataset id, 'path': path of the source dataset,
            'demo' (optional): mark this dataset as a demo dataset
    """
    dataset_id = data['did']
    source_path = data['path']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_type = workspace.datasets[dataset_id].type
    dataset_path = workspace.datasets[dataset_id].path
    valid_dataset_type = [
        'classification', 'detection', 'segmentation', 'instance_segmentation',
        'remote_segmentation'
    ]
    assert dataset_type in valid_dataset_type, \
        "Unrecognized dataset type {}".format(dataset_type)
    from .operate import import_dataset
    # The import runs in a separate process; register its pid so the server
    # can monitor it.
    process = import_dataset(dataset_id, dataset_type, dataset_path,
                             source_path)
    monitored_processes.put(process.pid)
    if 'demo' in data:
        prj_type = getattr(ProjectType, dataset_type)
        if prj_type not in load_demo_proc_dict:
            load_demo_proc_dict[prj_type] = []
        load_demo_proc_dict[prj_type].append(process)
    return {'status': 1}
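
# A minimal usage sketch (hypothetical; `procs` stands in for the server's
# monitored-process queue and `demo_procs` for its demo-loading dict):
#
#     import_dataset({'did': 'D0001', 'path': '/data/fruits'},
#                    workspace, procs, demo_procs)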

def delete_dataset(data, workspace):
    """Delete a dataset.

    Args:
        data (dict): keys include
            'did': dataset id
    """
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    # A dataset may only be deleted once no project references it.
    counter = 0
    for key in workspace.projects:
        if workspace.projects[key].did == dataset_id:
            counter += 1
    assert counter == 0, (
        "Cannot delete the dataset: it is still used by {} project(s); "
        "please delete those projects first.").format(counter)
    path = workspace.datasets[dataset_id].path
    if osp.exists(path):
        shutil.rmtree(path)
    del workspace.datasets[dataset_id]
    return {'status': 1}
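
# A minimal usage sketch (hypothetical; raises AssertionError while any
# project still references the dataset):
#
#     delete_dataset({'did': 'D0001'}, workspace)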

def get_dataset_status(data, workspace):
    """Get the current status of a dataset.

    Args:
        data (dict): keys include
            'did': dataset id
    """
    from .operate import get_dataset_status
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_type = workspace.datasets[dataset_id].type
    dataset_path = workspace.datasets[dataset_id].path
    dataset_name = workspace.datasets[dataset_id].name
    dataset_desc = workspace.datasets[dataset_id].desc
    dataset_create_time = workspace.datasets[dataset_id].create_time
    status, message = get_dataset_status(dataset_id, dataset_type,
                                         dataset_path)
    # Collect the ids of all projects that use this dataset.
    dataset_pids = list()
    for key in workspace.projects:
        if dataset_id == workspace.projects[key].did:
            dataset_pids.append(workspace.projects[key].id)
    attr = {
        "type": dataset_type,
        "id": dataset_id,
        "name": dataset_name,
        "path": dataset_path,
        "desc": dataset_desc,
        "create_time": dataset_create_time,
        "pids": dataset_pids
    }
    return {
        'status': 1,
        'id': dataset_id,
        'dataset_status': status.value,
        'message': message,
        'attr': attr
    }
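
# A minimal usage sketch (hypothetical; 'dataset_status' carries the numeric
# value of the DatasetStatus enum imported from ..utils):
#
#     resp = get_dataset_status({'did': 'D0001'}, workspace)
#     resp['dataset_status'], resp['attr']['pids']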

def list_datasets(workspace):
    """List all datasets; the result can be filtered by request parameters."""
    from .operate import get_dataset_status
    dataset_list = list()
    for key in workspace.datasets:
        dataset_type = workspace.datasets[key].type
        dataset_id = workspace.datasets[key].id
        dataset_name = workspace.datasets[key].name
        dataset_path = workspace.datasets[key].path
        dataset_desc = workspace.datasets[key].desc
        dataset_create_time = workspace.datasets[key].create_time
        status, message = get_dataset_status(dataset_id, dataset_type,
                                             dataset_path)
        attr = {
            "type": dataset_type,
            "id": dataset_id,
            "name": dataset_name,
            "path": dataset_path,
            "desc": dataset_desc,
            "create_time": dataset_create_time,
            'dataset_status': status.value,
            'message': message
        }
        dataset_list.append({"id": dataset_id, "attr": attr})
    return {'status': 1, "datasets": dataset_list}
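
# A minimal usage sketch (hypothetical):
#
#     for item in list_datasets(workspace)['datasets']:
#         print(item['id'], item['attr']['name'],
#               item['attr']['dataset_status'])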

def get_dataset_details(data, workspace):
    """Get dataset details.

    Args:
        data (dict): keys include
            'did': dataset id

    Return:
        details (dict): 'file_info': mapping from every dataset file to its
            annotation, 'label_info': mapping from each label to the dataset
            files that carry it, 'labels': list of labels,
            'train_files': list of training-set files,
            'val_files': list of validation-set files,
            'test_files': list of test-set files
    """
    from .operate import get_dataset_details
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_path = workspace.datasets[dataset_id].path
    details = get_dataset_details(dataset_path)
    return {'status': 1, 'details': details}
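
# A minimal usage sketch (hypothetical):
#
#     details = get_dataset_details({'did': 'D0001'}, workspace)['details']
#     details['labels'], len(details['train_files'])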

def split_dataset(data, workspace):
    """Split the dataset into training, validation and test sets.

    Args:
        data (dict): keys include
            'did': dataset id, 'val_split': validation-set ratio,
            'test_split': test-set ratio
    """
    from .operate import split_dataset
    from .operate import get_dataset_details
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_type = workspace.datasets[dataset_id].type
    dataset_path = workspace.datasets[dataset_id].path
    val_split = data['val_split']
    test_split = data['test_split']
    split_dataset(dataset_id, dataset_type, dataset_path, val_split,
                  test_split)
    return {'status': 1}
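
# A minimal usage sketch (hypothetical; with the ratios below, the remaining
# 70% of files form the training set):
#
#     split_dataset({'did': 'D0001', 'val_split': 0.2, 'test_split': 0.1},
#                   workspace)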

def img_base64(data, workspace=None):
    """Encode an image as a base64 string; when a dataset id is given, the
    dataset annotation is rendered onto the image first.

    Args:
        data (dict): keys include
            'path': absolute path of the image,
            'did' (optional): id of the dataset the image belongs to
    """
    path = data['path']
    # Normalize Windows-style path separators.
    path = '/'.join(path.split('\\'))
    if 'did' in data:
        did = data['did']
        label_type = workspace.datasets[did].type
        ds_path = workspace.datasets[did].path
        ret = get_dataset_details(data, workspace)
        dataset_details = ret['details']
        ds_label_count = get_label_count(dataset_details['label_info'])
        # Look up the annotation file for this image in file_info.
        image_path = 'JPEGImages/' + path.split('/')[-1]
        anno = osp.join(ds_path, dataset_details["file_info"][image_path])
        if label_type == 'detection':
            from ..project.visualize import plot_det_label
            labels = list(ds_label_count.keys())
            img = plot_det_label(path, anno, labels)
            base64_str = base64.b64encode(
                cv2.imencode('.png', img)[1]).decode()
            return {'status': 1, 'img_data': base64_str}
        elif label_type in ('segmentation', 'remote_segmentation'):
            from ..project.visualize import plot_seg_label
            # Blend the segmentation mask with the original image.
            im = plot_seg_label(anno)
            img = cv2.imread(path)
            im = cv2.addWeighted(img, 0.5, im, 0.5, 0).astype('uint8')
            base64_str = base64.b64encode(
                cv2.imencode('.png', im)[1]).decode()
            return {'status': 1, 'img_data': base64_str}
        elif label_type == 'instance_segmentation':
            labels = list(ds_label_count.keys())
            from ..project.visualize import plot_insseg_label
            img = plot_insseg_label(path, anno, labels)
            base64_str = base64.b64encode(
                cv2.imencode('.png', img)[1]).decode()
            return {'status': 1, 'img_data': base64_str}
        else:
            raise Exception(
                "Dataset type {} is not supported yet.".format(label_type))
    # No dataset id: return the raw image file, base64-encoded.
    with open(path, 'rb') as f:
        base64_data = base64.b64encode(f.read())
        base64_str = str(base64_data, 'utf-8')
    return {'status': 1, 'img_data': base64_str}
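
# A minimal usage sketch (hypothetical; with 'did' the annotation is drawn
# onto the image, without it the raw file is encoded):
#
#     raw = img_base64({'path': '/data/fruits/JPEGImages/001.jpg'})
#     vis = img_base64({'path': '/data/fruits/JPEGImages/001.jpg',
#                       'did': 'D0001'}, workspace)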