demo.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import os
  2. import json
  3. from os import path as osp
  4. from .utils import DownloadStatus, DatasetStatus, ProjectType, get_folder_status
  5. from .project.train.params import PARAMS_CLASS_LIST
  6. from .utils import CustomEncoder
  # Demo type names accepted as a string ``prj_type``; a name's position in
  # this list is used as the ProjectType value it maps to
  # (``ProjectType(prj_type_list.index(name))``).
  prj_type_list = [
      'classification',
      'detection',
      'segmentation',
      'instance_segmentation',
  ]
  def download_demo_dataset(data, workspace, load_demo_proc_dict):
      """Start downloading the demo dataset for one demo project type.

      Args:
          data (dict): request payload. ``data['prj_type']`` is either a
              ProjectType value or one of the names in ``prj_type_list``
              (str), which is mapped to ProjectType by its list position.
          workspace: workspace object; ``workspace.path`` is the root under
              which the ``demo_datasets`` folder is created.
          load_demo_proc_dict (dict): maps ProjectType to a list of the
              handles returned by the download helper. The new handle is
              appended so it can later be terminated (see stop_import_demo).

      Returns:
          dict: ``{'status': 1}``.

      Raises:
          ValueError: if a string ``prj_type`` is not in ``prj_type_list``
              (raised by ``list.index``); ``ProjectType(...)`` may also reject
              an unknown value.
          AssertionError: if the project type's value is outside 0..4.
      """
      raw_type = data['prj_type']
      if isinstance(raw_type, str):
          prj_type = ProjectType(prj_type_list.index(raw_type))
      else:
          prj_type = ProjectType(raw_type)
      # Explicit raise instead of ``assert`` so the check is not stripped under
      # ``python -O``; AssertionError is kept so existing handlers still match.
      if not 0 <= prj_type.value <= 4:
          raise AssertionError(
              "不支持此样例类型的导入(type:{})".format(prj_type))
      target_path = osp.join(workspace.path, "demo_datasets")
      # exist_ok avoids the exists()/makedirs() race between concurrent requests.
      os.makedirs(target_path, exist_ok=True)
      # Aliased so the imported helper does not shadow this function's name.
      from .dataset.operate import download_demo_dataset as _start_download
      proc = _start_download(prj_type, target_path)
      load_demo_proc_dict.setdefault(prj_type, []).append(proc)
      return {'status': 1}
  def _archive_base_name(fname):
      """Return ``fname`` without its archive extension (.tar..., .tgz, .zip).

      Only a dot-prefixed extension that is not at the very start of the name
      is stripped (so ``a.tar.gz`` -> ``a``); otherwise ``fname`` is returned
      unchanged. Matching the dot avoids truncating names that merely contain
      the letters ``tar``/``tgz``/``zip`` (e.g. ``guitar_det.tar``).
      """
      for suffix in ('.tar', '.tgz', '.zip'):
          pos = fname.find(suffix)
          if pos >= 1:
              return fname[:pos]
      return fname


  def load_demo_project(data, workspace, monitored_processes,
                        load_demo_proj_data_dict, load_demo_proc_dict):
      """Import a downloaded demo dataset and build its demo project.

      Checks that the demo dataset for ``data['prj_type']`` is downloaded and
      decompressed, reads ``<name>_params.json`` from it, then creates a
      dataset, starts importing the demo files into it, creates a project
      bound to that dataset and adds one training task to the project.

      Args:
          data (dict): request payload; ``data['prj_type']`` is a ProjectType
              value.
          workspace: workspace object (``path``, ``datasets`` and ``projects``
              are read).
          monitored_processes: forwarded to ``import_dataset``.
          load_demo_proj_data_dict (dict): on success ``prj_type`` is mapped to
              ``(pid, did)`` of the created project and dataset.
          load_demo_proc_dict (dict): forwarded to ``import_dataset``.

      Returns:
          dict: ``{'status': 1, 'did': ..., 'pid': ...}`` on success, or
          ``{'status': 1, 'loading_status': ...}`` when a dataset or project
          with the demo's name already exists in the workspace.

      Raises:
          AssertionError: if the type is unsupported or the demo dataset has
              not been downloaded / decompressed yet.
      """
      prj_type = ProjectType(data['prj_type'])
      # Explicit raises (not ``assert``) so the checks survive ``python -O``.
      if not 0 <= prj_type.value <= 4:
          raise AssertionError(
              "不支持此样例类型的导入(type:{})".format(prj_type))
      demo_root = osp.join(workspace.path, "demo_datasets")
      if not osp.exists(demo_root):
          raise AssertionError("样例数据集暂未下载,无法导入样例工程")
      target_path = osp.join(demo_root, prj_type.name)
      if not osp.exists(target_path):
          raise AssertionError(
              "样例{}数据集暂未下载,无法导入样例工程".format(prj_type.name))
      status = get_folder_status(target_path)
      if status != DownloadStatus.XDDECOMPRESSED:
          raise AssertionError(
              "样例{}数据集暂未解压,无法导入样例工程".format(prj_type.name))

      # The decompressed data is expected in a sub-folder named after the
      # archive file without its extension.
      from .dataset.operate import dataset_url_list
      fname = _archive_base_name(osp.split(dataset_url_list[prj_type.value])[-1])
      source_dataset_path = osp.join(target_path, fname)
      params_path = osp.join(source_dataset_path, fname + "_params.json")
      with open(params_path, "r", encoding="utf-8") as f:
          params = json.load(f)
      dataset_params = params['dataset_info']
      proj_params = params['project_info']
      train_params = params['train_params']

      # Do not import twice: stop if the demo's dataset or project name is
      # already used in this workspace.
      dataset_name = dataset_params['name']
      project_name = proj_params['name']
      if any(dataset_name == workspace.datasets[key].name
             for key in workspace.datasets):
          return {'status': 1, 'loading_status': 'dataset already exists'}
      if any(project_name == workspace.projects[key].name
             for key in workspace.projects):
          return {'status': 1, 'loading_status': 'project already exists'}

      # Create the dataset and start importing the demo files into it.
      from .dataset.dataset import create_dataset, import_dataset
      dataset_id = create_dataset(dataset_params, workspace)['did']
      import_request = {'did': dataset_id, 'path': source_dataset_path}
      import_dataset(import_request, workspace, monitored_processes,
                     load_demo_proc_dict)

      # Create the project and bind it to the new dataset.
      from .project.project import create_project
      pid = create_project(proj_params, workspace)['pid']
      from .workspace import set_attr
      set_attr({
          'struct': 'project',
          'id': pid,
          'attr_dict': {'did': dataset_id}
      }, workspace)

      # Create one training task; only attributes the params class already
      # defines are overridden from the demo's train_params.
      task_params = PARAMS_CLASS_LIST[prj_type.value]()
      for key, value in train_params.items():
          if hasattr(task_params, key):
              setattr(task_params, key, value)
      encoded_train = CustomEncoder().encode(task_params)
      from .project.task import create_task
      create_task({'pid': pid, 'train': encoded_train}, workspace)

      load_demo_proj_data_dict[prj_type] = (pid, dataset_id)
      return {'status': 1, 'did': dataset_id, 'pid': pid}
  def get_download_demo_progress(data, workspace):
      """Report the download status and progress of one demo dataset.

      Args:
          data (dict): request payload; ``data['prj_type']`` is either a
              ProjectType value or a demo name (str). A str is used verbatim
              as the folder name and as the key into ``dataset_url_dict``.
          workspace: workspace object; ``workspace.path`` holds the
              ``demo_datasets`` folder.

      Returns:
          dict: ``{'status': 1, 'attr': {'status': ..., 'progress': ...}}``.
          ``attr['status']`` is ``status.value`` from get_folder_status, or
          None when it returns no status. While downloading, ``progress`` is
          the downloaded percentage; otherwise it is the raw status message.
      """
      raw_type = data['prj_type']
      if isinstance(raw_type, str):
          folder_name = raw_type
      else:
          prj_type = ProjectType(raw_type)
          folder_name = prj_type.name
      target_path = osp.join(workspace.path, "demo_datasets", folder_name)
      status, message = get_folder_status(target_path, True)
      if status == DownloadStatus.XDDOWNLOADING:
          if isinstance(raw_type, str):
              from .dataset.operate import dataset_url_dict
              url = dataset_url_dict[raw_type]
          else:
              from .dataset.operate import dataset_url_list
              url = dataset_url_list[prj_type.value]
          # The in-progress file is read from "<archive name>_tmp"; while
          # downloading, the status message is parsed as the total size.
          tmp_file = osp.join(target_path, osp.split(url)[-1] + "_tmp")
          total_size = int(message)
          try:
              download_size = osp.getsize(tmp_file)
          except FileNotFoundError:
              # NOTE(review): assumed to mean the temp file is not created yet;
              # report 0% instead of failing the progress request.
              download_size = 0
          # Guard against an unknown/zero total size (ZeroDivisionError).
          if total_size > 0:
              message = download_size * 100 / total_size
          else:
              message = 0.0
      attr = {
          'status': status.value if status is not None else None,
          'progress': message,
      }
      return {'status': 1, 'attr': attr}
  def stop_import_demo(data, workspace, load_demo_proc_dict,
                       load_demo_proj_data_dict):
      """Stop a running demo download/import and roll back an unfinished import.

      Args:
          data (dict): request payload; ``data['prj_type']`` is a ProjectType
              value.
          workspace: workspace object passed to the dataset/project helpers.
          load_demo_proc_dict (dict): ProjectType -> list of process handles;
              every handle still alive for this type is terminated.
          load_demo_proj_data_dict (dict): ProjectType -> ``(pid, did)`` of the
              demo project/dataset recorded by load_demo_project.

      Returns:
          dict: ``{'status': 1}``.
      """
      prj_type = ProjectType(data['prj_type'])
      # ``.get`` so stopping a type that never started anything is a no-op
      # instead of a KeyError.
      for proc in load_demo_proc_dict.get(prj_type, ()):
          if proc.is_alive():
              proc.terminate()
      if prj_type not in load_demo_proj_data_dict:
          return {'status': 1}
      pid, did = load_demo_proj_data_dict[prj_type]
      from .dataset.dataset import get_dataset_status
      results = get_dataset_status({'did': did}, workspace)
      dataset_status = DatasetStatus(results['dataset_status'])
      # Only delete demo projects whose import has not completed, i.e. whose
      # dataset status is neither XCOPYDONE nor XSPLITED.
      if dataset_status not in (DatasetStatus.XCOPYDONE, DatasetStatus.XSPLITED):
          from .project.project import delete_project
          delete_project({'pid': pid}, workspace)
          from .dataset.dataset import delete_dataset
          delete_dataset({'did': did}, workspace)
      return {'status': 1}