demo.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import os
  2. import json
  3. from os import path as osp
  4. from .utils import DownloadStatus, DatasetStatus, ProjectType, get_folder_status
  5. from .project.train.params import PARAMS_CLASS_LIST
  6. from .utils import CustomEncoder
  7. def download_demo_dataset(data, workspace, load_demo_proc_dict):
  8. """下载样例工程
  9. Args:
  10. data为dict, key包括
  11. 'prj_type' 样例类型(ProjectType)
  12. """
  13. prj_type = ProjectType(data['prj_type'])
  14. assert prj_type.value >= 0 and prj_type.value <= 4, "不支持此样例类型的导入(type:{})".format(
  15. prj_type)
  16. target_path = osp.join(workspace.path, "demo_datasets")
  17. if not osp.exists(target_path):
  18. os.makedirs(target_path)
  19. from .dataset.operate import download_demo_dataset
  20. proc = download_demo_dataset(prj_type, target_path)
  21. if prj_type not in load_demo_proc_dict:
  22. load_demo_proc_dict[prj_type] = []
  23. load_demo_proc_dict[prj_type].append(proc)
  24. return {'status': 1}
  25. def load_demo_project(data, workspace, monitored_processes,
  26. load_demo_proj_data_dict, load_demo_proc_dict):
  27. """导入样例工程
  28. Args:
  29. data为dict, key包括
  30. 'prj_type' 样例类型(ProjectType)
  31. """
  32. prj_type = ProjectType(data['prj_type'])
  33. assert prj_type.value >= 0 and prj_type.value <= 4, "不支持此样例类型的导入(type:{})".format(
  34. prj_type)
  35. target_path = osp.join(workspace.path, "demo_datasets")
  36. assert osp.exists(target_path), "样例数据集暂未下载,无法导入样例工程"
  37. target_path = osp.join(target_path, prj_type.name)
  38. assert osp.exists(target_path), "样例{}数据集暂未下载,无法导入样例工程".format(
  39. prj_type.name)
  40. status = get_folder_status(target_path)
  41. assert status == DownloadStatus.XDDECOMPRESSED, "样例{}数据集暂未解压,无法导入样例工程".format(
  42. prj_type.name)
  43. from .dataset.operate import dataset_url_list
  44. url = dataset_url_list[prj_type.value]
  45. fname = osp.split(url)[-1]
  46. for suffix in ['tar', 'tgz', 'zip']:
  47. pos = fname.find(suffix)
  48. if pos >= 2:
  49. fname = fname[0:pos - 1]
  50. break
  51. source_dataset_path = osp.join(target_path, fname)
  52. params_path = osp.join(target_path, fname, fname + "_params.json")
  53. params = {}
  54. with open(params_path, "r", encoding="utf-8") as f:
  55. params = json.load(f)
  56. dataset_params = params['dataset_info']
  57. proj_params = params['project_info']
  58. train_params = params['train_params']
  59. # 判断数据集、项目名称是否已存在
  60. dataset_name = dataset_params['name']
  61. project_name = proj_params['name']
  62. for id in workspace.datasets:
  63. if dataset_name == workspace.datasets[id].name:
  64. return {'status': 1, 'loading_status': 'dataset already exists'}
  65. for id in workspace.projects:
  66. if project_name == workspace.projects[id].name:
  67. return {'status': 1, 'loading_status': 'project already exists'}
  68. # 创建数据集
  69. from .dataset.dataset import create_dataset
  70. results = create_dataset(dataset_params, workspace)
  71. dataset_id = results['did']
  72. # 导入数据集
  73. from .dataset.dataset import import_dataset
  74. data = {'did': dataset_id, 'path': source_dataset_path}
  75. import_dataset(data, workspace, monitored_processes, load_demo_proc_dict)
  76. # 创建项目
  77. from .project.project import create_project
  78. results = create_project(proj_params, workspace)
  79. pid = results['pid']
  80. # 绑定数据集
  81. from .workspace import set_attr
  82. attr_dict = {'did': dataset_id}
  83. params = {'struct': 'project', 'id': pid, 'attr_dict': attr_dict}
  84. set_attr(params, workspace)
  85. # 创建任务
  86. task_params = PARAMS_CLASS_LIST[prj_type.value]()
  87. for k, v in train_params.items():
  88. if hasattr(task_params, k):
  89. setattr(task_params, k, v)
  90. task_params = CustomEncoder().encode(task_params)
  91. from .project.task import create_task
  92. params = {'pid': pid, 'train': task_params}
  93. create_task(params, workspace)
  94. load_demo_proj_data_dict[prj_type] = (pid, dataset_id)
  95. return {'status': 1, 'did': dataset_id, 'pid': pid}
  96. def get_download_demo_progress(data, workspace):
  97. """查询样例工程的下载进度
  98. Args:
  99. data为dict, key包括
  100. 'prj_type' 样例类型(ProjectType)
  101. """
  102. prj_type = ProjectType(data['prj_type'])
  103. target_path = osp.join(workspace.path, "demo_datasets", prj_type.name)
  104. status, message = get_folder_status(target_path, True)
  105. if status == DownloadStatus.XDDOWNLOADING:
  106. from .dataset.operate import dataset_url_list
  107. url = dataset_url_list[prj_type.value]
  108. fname = osp.split(url)[-1] + "_tmp"
  109. fullname = osp.join(target_path, fname)
  110. total_size = int(message)
  111. download_size = osp.getsize(fullname)
  112. message = download_size * 100 / total_size
  113. if status is not None:
  114. attr = {'status': status.value, 'progress': message}
  115. else:
  116. attr = {'status': status, 'progress': message}
  117. return {'status': 1, 'attr': attr}
  118. def stop_import_demo(data, workspace, load_demo_proc_dict,
  119. load_demo_proj_data_dict):
  120. """停止样例工程的导入进度
  121. Args:
  122. request(comm.Request): 其中request.params为dict, key包括
  123. 'prj_type' 样例类型(ProjectType)
  124. """
  125. prj_type = ProjectType(data['prj_type'])
  126. for proc in load_demo_proc_dict[prj_type]:
  127. if proc.is_alive():
  128. proc.terminate()
  129. # 只删除未完成导入的样例项目
  130. if prj_type in load_demo_proj_data_dict:
  131. pid, did = load_demo_proj_data_dict[prj_type]
  132. params = {'did': did}
  133. from .dataset.dataset import get_dataset_status
  134. results = get_dataset_status(params, workspace)
  135. dataset_status = DatasetStatus(results['dataset_status'])
  136. if dataset_status not in [
  137. DatasetStatus.XCOPYDONE, DatasetStatus.XSPLITED
  138. ]:
  139. params = {'pid': pid}
  140. from .project.project import delete_project
  141. delete_project(params, workspace)
  142. from .dataset.dataset import delete_dataset
  143. params = {'did': did}
  144. delete_dataset(params, workspace)
  145. return {'status': 1}