# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import (set_folder_status, get_folder_status, DatasetStatus,
                     TaskStatus, is_available, DownloadStatus,
                     PretrainedModelStatus, ProjectType)
from threading import Thread
import random
from .utils import copy_directory, get_label_count
import traceback
import shutil
import psutil
import pickle
import os
import os.path as osp
import time
import json
import base64
import cv2
from .. import workspace_pb2 as w


def create_dataset(data, workspace):
    """Create a dataset."""
    create_time = time.time()
    time_array = time.localtime(create_time)
    create_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
    id = workspace.max_dataset_id + 1
    if id < 10000:
        did = 'D%04d' % id
    else:
        did = 'D{}'.format(id)
    assert did not in workspace.datasets, \
        "[Create dataset] ID '{}' is already taken.".format(did)
    path = osp.join(workspace.path, 'datasets', did)
    if osp.exists(path):
        if not osp.isdir(path):
            os.remove(path)
        else:
            shutil.rmtree(path)
    os.makedirs(path)
    set_folder_status(path, DatasetStatus.XEMPTY)
    workspace.max_dataset_id = id
    ds = w.Dataset(
        id=did,
        name=data['name'],
        desc=data['desc'],
        type=data['dataset_type'],
        create_time=create_time,
        path=path)
    workspace.datasets[did].CopyFrom(ds)
    return {'status': 1, 'did': did}
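
# Usage sketch (hypothetical; `workspace` stands for the in-memory workspace
# protobuf the server passes into every handler, and the name/desc values
# below are made up):
#
#     ret = create_dataset(
#         {'name': 'fruits', 'desc': 'demo set', 'dataset_type': 'detection'},
#         workspace)
#     did = ret['did']  # e.g. 'D0001'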


def import_dataset(data, workspace, monitored_processes, load_demo_proc_dict):
    """Import a dataset into the workspace, including validation and copying.

    Args:
        data (dict): keys include
            'did': dataset id, 'path': path to the source dataset directory,
            'demo' (optional): whether this dataset is a demo dataset.
    """
    dataset_id = data['did']
    source_path = data['path']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_type = workspace.datasets[dataset_id].type
    dataset_path = workspace.datasets[dataset_id].path
    valid_dataset_type = [
        'classification', 'detection', 'segmentation', 'instance_segmentation',
        'remote_segmentation'
    ]
    assert dataset_type in valid_dataset_type, \
        "Unrecognized dataset type {}".format(dataset_type)
    from .operate import import_dataset
    process = import_dataset(dataset_id, dataset_type, dataset_path,
                             source_path)
    monitored_processes.put(process.pid)
    if 'demo' in data:
        prj_type = getattr(ProjectType, dataset_type)
        if prj_type not in load_demo_proc_dict:
            load_demo_proc_dict[prj_type] = []
        load_demo_proc_dict[prj_type].append(process)
    return {'status': 1}
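
# Usage sketch (hypothetical; `monitored_processes` is the server's queue of
# child PIDs to watch, `load_demo_proc_dict` its registry of demo loaders,
# and the source path below is made up):
#
#     import_dataset({'did': 'D0001', 'path': '/data/raw/fruits'},
#                    workspace, monitored_processes, load_demo_proc_dict)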


def delete_dataset(data, workspace):
    """Delete a dataset.

    Args:
        data (dict): keys include
            'did': dataset id.
    """
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    counter = 0
    for key in workspace.projects:
        if workspace.projects[key].did == dataset_id:
            counter += 1
    assert counter == 0, ("Cannot delete the dataset: it is still used by {} "
                          "project(s); please delete those projects first."
                          ).format(counter)
    path = workspace.datasets[dataset_id].path
    if osp.exists(path):
        shutil.rmtree(path)
    del workspace.datasets[dataset_id]
    return {'status': 1}
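
# Usage sketch (hypothetical): the call raises an AssertionError while any
# project still references the dataset.
#
#     delete_dataset({'did': 'D0001'}, workspace)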


def get_dataset_status(data, workspace):
    """Get the current status of a dataset.

    Args:
        data (dict): keys include
            'did': dataset id.
    """
    from .operate import get_dataset_status
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_type = workspace.datasets[dataset_id].type
    dataset_path = workspace.datasets[dataset_id].path
    dataset_name = workspace.datasets[dataset_id].name
    dataset_desc = workspace.datasets[dataset_id].desc
    dataset_create_time = workspace.datasets[dataset_id].create_time
    status, message = get_dataset_status(dataset_id, dataset_type,
                                         dataset_path)
    dataset_pids = list()
    for key in workspace.projects:
        if dataset_id == workspace.projects[key].did:
            dataset_pids.append(workspace.projects[key].id)
    attr = {
        "type": dataset_type,
        "id": dataset_id,
        "name": dataset_name,
        "path": dataset_path,
        "desc": dataset_desc,
        "create_time": dataset_create_time,
        "pids": dataset_pids
    }
    return {
        'status': 1,
        'id': dataset_id,
        'dataset_status': status.value,
        'message': message,
        'attr': attr
    }
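
# Usage sketch (hypothetical): `dataset_status` carries the integer value of
# a DatasetStatus member, e.g. DatasetStatus.XEMPTY right after creation.
#
#     ret = get_dataset_status({'did': 'D0001'}, workspace)
#     print(ret['dataset_status'], ret['attr']['pids'])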


def list_datasets(workspace):
    """List all datasets; the result can be filtered by request parameters."""
    from .operate import get_dataset_status
    dataset_list = list()
    for key in workspace.datasets:
        dataset_type = workspace.datasets[key].type
        dataset_id = workspace.datasets[key].id
        dataset_name = workspace.datasets[key].name
        dataset_path = workspace.datasets[key].path
        dataset_desc = workspace.datasets[key].desc
        dataset_create_time = workspace.datasets[key].create_time
        status, message = get_dataset_status(dataset_id, dataset_type,
                                             dataset_path)
        attr = {
            "type": dataset_type,
            "id": dataset_id,
            "name": dataset_name,
            "path": dataset_path,
            "desc": dataset_desc,
            "create_time": dataset_create_time,
            'dataset_status': status.value,
            'message': message
        }
        dataset_list.append({"id": dataset_id, "attr": attr})
    return {'status': 1, "datasets": dataset_list}
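
# Usage sketch (hypothetical):
#
#     for item in list_datasets(workspace)['datasets']:
#         print(item['id'], item['attr']['name'],
#               item['attr']['dataset_status'])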


def get_dataset_details(data, workspace):
    """Get dataset details.

    Args:
        data (dict): keys include
            'did': dataset id.
    Return:
        details (dict): 'file_info': mapping from every file in the dataset
            to its annotation, 'label_info': mapping from every label to the
            files that carry it, 'labels': label list, 'train_files': training
            file list, 'val_files': validation file list, 'test_files': test
            file list.
    """
    from .operate import get_dataset_details
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_path = workspace.datasets[dataset_id].path
    details = get_dataset_details(dataset_path)
    return {'status': 1, 'details': details}
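
# Usage sketch (hypothetical): the train/val/test lists stay empty until
# split_dataset below has been run.
#
#     details = get_dataset_details({'did': 'D0001'}, workspace)['details']
#     print(len(details['train_files']), len(details['val_files']))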


def split_dataset(data, workspace):
    """Split a dataset into training, validation and test sets.

    Args:
        data (dict): keys include
            'did': dataset id, 'val_split': validation set fraction,
            'test_split': test set fraction.
    """
    from .operate import split_dataset
    from .operate import get_dataset_details
    dataset_id = data['did']
    assert dataset_id in workspace.datasets, \
        "Dataset ID '{}' does not exist.".format(dataset_id)
    dataset_type = workspace.datasets[dataset_id].type
    dataset_path = workspace.datasets[dataset_id].path
    val_split = data['val_split']
    test_split = data['test_split']
    split_dataset(dataset_id, dataset_type, dataset_path, val_split,
                  test_split)
    return {'status': 1}
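
# Usage sketch (hypothetical): a 20% validation and 10% test split leaves
# 70% of the samples for training; both fractions are of the whole set.
#
#     split_dataset({'did': 'D0001', 'val_split': 0.2, 'test_split': 0.1},
#                   workspace)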


def img_base64(data, workspace=None):
    """Return the base64-encoded image at 'path'; when a dataset id is given,
    the image's annotation is rendered on top first.

    Args:
        data (dict): keys include
            'path': absolute path to the image,
            'did' (optional): dataset id used to look up the annotation.
    """
    path = data['path']
    path = '/'.join(path.split('\\'))
    if 'did' in data:
        did = data['did']
        label_type = workspace.datasets[did].type
        ds_path = workspace.datasets[did].path
        ret = get_dataset_details(data, workspace)
        dataset_details = ret['details']
        ds_label_count = get_label_count(dataset_details['label_info'])
        image_path = 'JPEGImages/' + path.split('/')[-1]
        anno = osp.join(ds_path, dataset_details["file_info"][image_path])
        if label_type == 'detection':
            from ..project.visualize import plot_det_label
            labels = list(ds_label_count.keys())
            img = plot_det_label(path, anno, labels)
            base64_str = base64.b64encode(cv2.imencode('.png', img)[
                1]).decode()
            return {'status': 1, 'img_data': base64_str}
        elif label_type == 'segmentation' or label_type == 'remote_segmentation':
            from ..project.visualize import plot_seg_label
            im = plot_seg_label(anno)
            img = cv2.imread(path)
            im = cv2.addWeighted(img, 0.5, im, 0.5, 0).astype('uint8')
            base64_str = base64.b64encode(cv2.imencode('.png', im)[1]).decode()
            return {'status': 1, 'img_data': base64_str}
        elif label_type == 'instance_segmentation':
            labels = list(ds_label_count.keys())
            from ..project.visualize import plot_insseg_label
            img = plot_insseg_label(path, anno, labels)
            base64_str = base64.b64encode(cv2.imencode('.png', img)[
                1]).decode()
            return {'status': 1, 'img_data': base64_str}
        else:
            raise Exception(
                "Dataset type {} is not supported yet.".format(label_type))
    with open(path, 'rb') as f:
        base64_data = base64.b64encode(f.read())
        base64_str = str(base64_data, 'utf-8')
    return {'status': 1, 'img_data': base64_str}
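
# Usage sketch (hypothetical; the image path is made up): decode the returned
# string back into an OpenCV image to verify the round trip.
#
#     import numpy as np
#     ret = img_base64({'path': '/data/raw/fruits/JPEGImages/0001.jpg',
#                       'did': 'D0001'}, workspace)
#     buf = np.frombuffer(base64.b64decode(ret['img_data']), dtype='uint8')
#     img = cv2.imdecode(buf, cv2.IMREAD_COLOR)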