|
|
@@ -5,6 +5,10 @@ from .utils import DownloadStatus, DatasetStatus, ProjectType, get_folder_status
|
|
|
from .project.train.params import PARAMS_CLASS_LIST
|
|
|
from .utils import CustomEncoder
|
|
|
|
|
|
+prj_type_list = [
|
|
|
+ 'classification', 'detection', 'segmentation', 'instance_segmentation'
|
|
|
+]
|
|
|
+
|
|
|
|
|
|
def download_demo_dataset(data, workspace, load_demo_proc_dict):
|
|
|
"""下载样例工程
|
|
|
@@ -13,7 +17,10 @@ def download_demo_dataset(data, workspace, load_demo_proc_dict):
|
|
|
data为dict, key包括
|
|
|
'prj_type' 样例类型(ProjectType)
|
|
|
"""
|
|
|
- prj_type = ProjectType(data['prj_type'])
|
|
|
+ if isinstance(data['prj_type'], str):
|
|
|
+ prj_type = ProjectType(prj_type_list.index(data['prj_type']))
|
|
|
+ else:
|
|
|
+ prj_type = ProjectType(data['prj_type'])
|
|
|
assert prj_type.value >= 0 and prj_type.value <= 4, "不支持此样例类型的导入(type:{})".format(
|
|
|
prj_type)
|
|
|
target_path = osp.join(workspace.path, "demo_datasets")
|
|
|
@@ -118,12 +125,20 @@ def get_download_demo_progress(data, workspace):
|
|
|
data为dict, key包括
|
|
|
'prj_type' 样例类型(ProjectType)
|
|
|
"""
|
|
|
- prj_type = ProjectType(data['prj_type'])
|
|
|
- target_path = osp.join(workspace.path, "demo_datasets", prj_type.name)
|
|
|
+ if isinstance(data['prj_type'], str):
|
|
|
+ target_path = osp.join(workspace.path, "demo_datasets",
|
|
|
+ data['prj_type'])
|
|
|
+ else:
|
|
|
+ prj_type = ProjectType(data['prj_type'])
|
|
|
+ target_path = osp.join(workspace.path, "demo_datasets", prj_type.name)
|
|
|
status, message = get_folder_status(target_path, True)
|
|
|
if status == DownloadStatus.XDDOWNLOADING:
|
|
|
- from .dataset.operate import dataset_url_list
|
|
|
- url = dataset_url_list[prj_type.value]
|
|
|
+ if isinstance(data['prj_type'], str):
|
|
|
+ from .dataset.operate import dataset_url_dict
|
|
|
+ url = dataset_url_dict[data['prj_type']]
|
|
|
+ else:
|
|
|
+ from .dataset.operate import dataset_url_list
|
|
|
+ url = dataset_url_list[prj_type.value]
|
|
|
fname = osp.split(url)[-1] + "_tmp"
|
|
|
fullname = osp.join(target_path, fname)
|
|
|
total_size = int(message)
|