demo.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import os
  2. import json
  3. from os import path as osp
  4. from .utils import DownloadStatus, DatasetStatus, ProjectType, get_folder_status
  5. from .project.train.params import PARAMS_CLASS_LIST
  6. from .utils import CustomEncoder
  7. prj_type_list = [
  8. 'classification', 'detection', 'segmentation', 'instance_segmentation'
  9. ]
  10. def download_demo_dataset(data, workspace, load_demo_proc_dict):
  11. """下载样例工程
  12. Args:
  13. data为dict, key包括
  14. 'prj_type' 样例类型(ProjectType)
  15. """
  16. if isinstance(data['prj_type'], str):
  17. prj_type = ProjectType(prj_type_list.index(data['prj_type']))
  18. else:
  19. prj_type = ProjectType(data['prj_type'])
  20. assert prj_type.value >= 0 and prj_type.value <= 4, "不支持此样例类型的导入(type:{})".format(
  21. prj_type)
  22. target_path = osp.join(workspace.path, "demo_datasets")
  23. if not osp.exists(target_path):
  24. os.makedirs(target_path)
  25. from .dataset.operate import download_demo_dataset
  26. proc = download_demo_dataset(prj_type, target_path)
  27. if prj_type not in load_demo_proc_dict:
  28. load_demo_proc_dict[prj_type] = []
  29. load_demo_proc_dict[prj_type].append(proc)
  30. return {'status': 1}
  31. def load_demo_project(data, workspace, monitored_processes,
  32. load_demo_proj_data_dict, load_demo_proc_dict):
  33. """导入样例工程
  34. Args:
  35. data为dict, key包括
  36. 'prj_type' 样例类型(ProjectType)
  37. """
  38. if isinstance(data['prj_type'], str):
  39. prj_type = ProjectType(prj_type_list.index(data['prj_type']))
  40. else:
  41. prj_type = ProjectType(data['prj_type'])
  42. assert prj_type.value >= 0 and prj_type.value <= 4, "不支持此样例类型的导入(type:{})".format(
  43. prj_type)
  44. target_path = osp.join(workspace.path, "demo_datasets")
  45. assert osp.exists(target_path), "样例数据集暂未下载,无法导入样例工程"
  46. target_path = osp.join(target_path, prj_type.name)
  47. assert osp.exists(target_path), "样例{}数据集暂未下载,无法导入样例工程".format(
  48. prj_type.name)
  49. status = get_folder_status(target_path)
  50. assert status == DownloadStatus.XDDECOMPRESSED, "样例{}数据集暂未解压,无法导入样例工程".format(
  51. prj_type.name)
  52. from .dataset.operate import dataset_url_list
  53. url = dataset_url_list[prj_type.value]
  54. fname = osp.split(url)[-1]
  55. for suffix in ['tar', 'tgz', 'zip']:
  56. pos = fname.find(suffix)
  57. if pos >= 2:
  58. fname = fname[0:pos - 1]
  59. break
  60. source_dataset_path = osp.join(target_path, fname)
  61. params_path = osp.join(target_path, fname, fname + "_params.json")
  62. params = {}
  63. with open(params_path, "r", encoding="utf-8") as f:
  64. params = json.load(f)
  65. dataset_params = params['dataset_info']
  66. proj_params = params['project_info']
  67. train_params = params['train_params']
  68. # 判断数据集、项目名称是否已存在
  69. dataset_name = dataset_params['name']
  70. project_name = proj_params['name']
  71. for id in workspace.datasets:
  72. if dataset_name == workspace.datasets[id].name:
  73. return {'status': 1, 'loading_status': 'dataset already exists'}
  74. for id in workspace.projects:
  75. if project_name == workspace.projects[id].name:
  76. return {'status': 1, 'loading_status': 'project already exists'}
  77. # 创建数据集
  78. from .dataset.dataset import create_dataset
  79. results = create_dataset(dataset_params, workspace)
  80. dataset_id = results['did']
  81. # 导入数据集
  82. from .dataset.dataset import import_dataset
  83. data = {'did': dataset_id, 'path': source_dataset_path}
  84. import_dataset(data, workspace, monitored_processes, load_demo_proc_dict)
  85. # 创建项目
  86. from .project.project import create_project
  87. results = create_project(proj_params, workspace)
  88. pid = results['pid']
  89. # 绑定数据集
  90. from .workspace import set_attr
  91. attr_dict = {'did': dataset_id}
  92. params = {'struct': 'project', 'id': pid, 'attr_dict': attr_dict}
  93. set_attr(params, workspace)
  94. # 创建任务
  95. task_params = PARAMS_CLASS_LIST[prj_type.value]()
  96. for k, v in train_params.items():
  97. if hasattr(task_params, k):
  98. setattr(task_params, k, v)
  99. task_params = CustomEncoder().encode(task_params)
  100. from .project.task import create_task
  101. params = {'pid': pid, 'train': task_params}
  102. create_task(params, workspace)
  103. load_demo_proj_data_dict[prj_type] = (pid, dataset_id)
  104. return {'status': 1, 'did': dataset_id, 'pid': pid}
  105. def get_download_demo_progress(data, workspace):
  106. """查询样例工程的下载进度
  107. Args:
  108. data为dict, key包括
  109. 'prj_type' 样例类型(ProjectType)
  110. """
  111. if isinstance(data['prj_type'], str):
  112. target_path = osp.join(workspace.path, "demo_datasets",
  113. data['prj_type'])
  114. else:
  115. prj_type = ProjectType(data['prj_type'])
  116. target_path = osp.join(workspace.path, "demo_datasets", prj_type.name)
  117. status, message = get_folder_status(target_path, True)
  118. if status == DownloadStatus.XDDOWNLOADING:
  119. if isinstance(data['prj_type'], str):
  120. from .dataset.operate import dataset_url_dict
  121. url = dataset_url_dict[data['prj_type']]
  122. else:
  123. from .dataset.operate import dataset_url_list
  124. url = dataset_url_list[prj_type.value]
  125. fname = osp.split(url)[-1] + "_tmp"
  126. fullname = osp.join(target_path, fname)
  127. total_size = int(message)
  128. download_size = osp.getsize(fullname)
  129. message = download_size * 100 / total_size
  130. if status is not None:
  131. attr = {'status': status.value, 'progress': message}
  132. else:
  133. attr = {'status': status, 'progress': message}
  134. return {'status': 1, 'attr': attr}
  135. def stop_import_demo(data, workspace, load_demo_proc_dict,
  136. load_demo_proj_data_dict):
  137. """停止样例工程的导入进度
  138. Args:
  139. request(comm.Request): 其中request.params为dict, key包括
  140. 'prj_type' 样例类型(ProjectType)
  141. """
  142. if isinstance(data['prj_type'], str):
  143. prj_type = ProjectType(prj_type_list.index(data['prj_type']))
  144. else:
  145. prj_type = ProjectType(data['prj_type'])
  146. for proc in load_demo_proc_dict[prj_type]:
  147. if proc.is_alive():
  148. proc.terminate()
  149. # 只删除未完成导入的样例项目
  150. if prj_type in load_demo_proj_data_dict:
  151. pid, did = load_demo_proj_data_dict[prj_type]
  152. params = {'did': did}
  153. from .dataset.dataset import get_dataset_status
  154. results = get_dataset_status(params, workspace)
  155. dataset_status = DatasetStatus(results['dataset_status'])
  156. if dataset_status not in [
  157. DatasetStatus.XCOPYDONE, DatasetStatus.XSPLITED
  158. ]:
  159. params = {'pid': pid}
  160. from .project.project import delete_project
  161. delete_project(params, workspace)
  162. from .dataset.dataset import delete_dataset
  163. params = {'did': did}
  164. delete_dataset(params, workspace)
  165. return {'status': 1}