demo.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. from os import path as osp
  17. from .utils import DownloadStatus, DatasetStatus, ProjectType, get_folder_status
  18. from .project.train.params import PARAMS_CLASS_LIST
  19. from .utils import CustomEncoder
  20. prj_type_list = [
  21. 'classification', 'detection', 'segmentation', 'instance_segmentation'
  22. ]
  23. def download_demo_dataset(data, workspace, load_demo_proc_dict):
  24. """下载样例工程
  25. Args:
  26. data为dict, key包括
  27. 'prj_type' 样例类型(ProjectType)
  28. """
  29. if isinstance(data['prj_type'], str):
  30. prj_type = ProjectType(prj_type_list.index(data['prj_type']))
  31. else:
  32. prj_type = ProjectType(data['prj_type'])
  33. assert prj_type.value >= 0 and prj_type.value <= 4, "不支持此样例类型的导入(type:{})".format(
  34. prj_type)
  35. target_path = osp.join(workspace.path, "demo_datasets")
  36. if not osp.exists(target_path):
  37. os.makedirs(target_path)
  38. from .dataset.operate import download_demo_dataset
  39. proc = download_demo_dataset(prj_type, target_path)
  40. if prj_type not in load_demo_proc_dict:
  41. load_demo_proc_dict[prj_type] = []
  42. load_demo_proc_dict[prj_type].append(proc)
  43. return {'status': 1}
  44. def load_demo_project(data, workspace, monitored_processes,
  45. load_demo_proj_data_dict, load_demo_proc_dict):
  46. """导入样例工程
  47. Args:
  48. data为dict, key包括
  49. 'prj_type' 样例类型(ProjectType)
  50. """
  51. if isinstance(data['prj_type'], str):
  52. prj_type = ProjectType(prj_type_list.index(data['prj_type']))
  53. else:
  54. prj_type = ProjectType(data['prj_type'])
  55. assert prj_type.value >= 0 and prj_type.value <= 4, "不支持此样例类型的导入(type:{})".format(
  56. prj_type)
  57. target_path = osp.join(workspace.path, "demo_datasets")
  58. assert osp.exists(target_path), "样例数据集暂未下载,无法导入样例工程"
  59. target_path = osp.join(target_path, prj_type.name)
  60. assert osp.exists(target_path), "样例{}数据集暂未下载,无法导入样例工程".format(
  61. prj_type.name)
  62. status = get_folder_status(target_path)
  63. assert status == DownloadStatus.XDDECOMPRESSED, "样例{}数据集暂未解压,无法导入样例工程".format(
  64. prj_type.name)
  65. from .dataset.operate import dataset_url_list
  66. url = dataset_url_list[prj_type.value]
  67. fname = osp.split(url)[-1]
  68. for suffix in ['tar', 'tgz', 'zip']:
  69. pos = fname.find(suffix)
  70. if pos >= 2:
  71. fname = fname[0:pos - 1]
  72. break
  73. source_dataset_path = osp.join(target_path, fname)
  74. params_path = osp.join(target_path, fname, fname + "_params.json")
  75. params = {}
  76. with open(params_path, "r", encoding="utf-8") as f:
  77. params = json.load(f)
  78. dataset_params = params['dataset_info']
  79. proj_params = params['project_info']
  80. train_params = params['train_params']
  81. # 判断数据集、项目名称是否已存在
  82. dataset_name = dataset_params['name']
  83. project_name = proj_params['name']
  84. for id in workspace.datasets:
  85. if dataset_name == workspace.datasets[id].name:
  86. return {'status': 1, 'loading_status': 'dataset already exists'}
  87. for id in workspace.projects:
  88. if project_name == workspace.projects[id].name:
  89. return {'status': 1, 'loading_status': 'project already exists'}
  90. # 创建数据集
  91. from .dataset.dataset import create_dataset
  92. results = create_dataset(dataset_params, workspace)
  93. dataset_id = results['did']
  94. # 导入数据集
  95. from .dataset.dataset import import_dataset
  96. data = {'did': dataset_id, 'path': source_dataset_path}
  97. import_dataset(data, workspace, monitored_processes, load_demo_proc_dict)
  98. # 创建项目
  99. from .project.project import create_project
  100. results = create_project(proj_params, workspace)
  101. pid = results['pid']
  102. # 绑定数据集
  103. from .workspace import set_attr
  104. attr_dict = {'did': dataset_id}
  105. params = {'struct': 'project', 'id': pid, 'attr_dict': attr_dict}
  106. set_attr(params, workspace)
  107. # 创建任务
  108. task_params = PARAMS_CLASS_LIST[prj_type.value]()
  109. for k, v in train_params.items():
  110. if hasattr(task_params, k):
  111. setattr(task_params, k, v)
  112. task_params = CustomEncoder().encode(task_params)
  113. from .project.task import create_task
  114. params = {'pid': pid, 'train': task_params}
  115. create_task(params, workspace)
  116. load_demo_proj_data_dict[prj_type] = (pid, dataset_id)
  117. return {'status': 1, 'did': dataset_id, 'pid': pid}
  118. def get_download_demo_progress(data, workspace):
  119. """查询样例工程的下载进度
  120. Args:
  121. data为dict, key包括
  122. 'prj_type' 样例类型(ProjectType)
  123. """
  124. if isinstance(data['prj_type'], str):
  125. target_path = osp.join(workspace.path, "demo_datasets",
  126. data['prj_type'])
  127. else:
  128. prj_type = ProjectType(data['prj_type'])
  129. target_path = osp.join(workspace.path, "demo_datasets", prj_type.name)
  130. status, message = get_folder_status(target_path, True)
  131. if status == DownloadStatus.XDDOWNLOADING:
  132. if isinstance(data['prj_type'], str):
  133. from .dataset.operate import dataset_url_dict
  134. url = dataset_url_dict[data['prj_type']]
  135. else:
  136. from .dataset.operate import dataset_url_list
  137. url = dataset_url_list[prj_type.value]
  138. fname = osp.split(url)[-1] + "_tmp"
  139. fullname = osp.join(target_path, fname)
  140. total_size = int(message)
  141. download_size = osp.getsize(fullname)
  142. message = download_size * 100 / total_size
  143. if status is not None:
  144. attr = {'status': status.value, 'progress': message}
  145. else:
  146. attr = {'status': status, 'progress': message}
  147. return {'status': 1, 'attr': attr}
  148. def stop_import_demo(data, workspace, load_demo_proc_dict,
  149. load_demo_proj_data_dict):
  150. """停止样例工程的导入进度
  151. Args:
  152. request(comm.Request): 其中request.params为dict, key包括
  153. 'prj_type' 样例类型(ProjectType)
  154. """
  155. if isinstance(data['prj_type'], str):
  156. prj_type = ProjectType(prj_type_list.index(data['prj_type']))
  157. else:
  158. prj_type = ProjectType(data['prj_type'])
  159. for proc in load_demo_proc_dict[prj_type]:
  160. if proc.is_alive():
  161. proc.terminate()
  162. # 只删除未完成导入的样例项目
  163. if prj_type in load_demo_proj_data_dict:
  164. pid, did = load_demo_proj_data_dict[prj_type]
  165. params = {'did': did}
  166. from .dataset.dataset import get_dataset_status
  167. results = get_dataset_status(params, workspace)
  168. dataset_status = DatasetStatus(results['dataset_status'])
  169. if dataset_status not in [
  170. DatasetStatus.XCOPYDONE, DatasetStatus.XSPLITED
  171. ]:
  172. params = {'pid': pid}
  173. from .project.project import delete_project
  174. delete_project(params, workspace)
  175. from .dataset.dataset import delete_dataset
  176. params = {'did': did}
  177. delete_dataset(params, workspace)
  178. return {'status': 1}