Przeglądaj źródła

add task type for [GET]/project/task

wangsiyuan06 5 lat temu
rodzic
commit
1ffa87c4ce
3 zmienionych plików z 27 dodań i 3 usunięć
  1. 13 0
      docs/gui/restful_api.md
  2. 1 0
      paddlex/restful/app.py
  3. 13 3
      paddlex/restful/project/task.py

+ 13 - 0
docs/gui/restful_api.md

@@ -329,6 +329,7 @@ methods=='GET':#获取某个任务的信息或者所有任务的信息
 		if 'tid' in Args:
 			task_status(int):任务状态(TaskStatus)枚举变量的值
 			message(str):任务状态信息
+			type(str):任务类型包括{'classification', 'detection', 'segmentation', 'instance_segmentation'}
 			resumable(bool):仅Args中存在resume时返回,任务训练是否可以恢复
 			max_saved_epochs(int):仅Args中存在resume时返回,当前训练模型保存的最大epoch
 		else:
@@ -340,6 +341,18 @@ methods=='GET':#获取某个任务的信息或者所有任务的信息
 	Example2:
 		#获取所有任务的信息
 		ret = requests.get(url + '/project/task')
+	Ruturn中的自定数据结构:
+		所有任务属性(tasks),任务属性attr(dict)的list
+		attr{
+			'id'(str): 任务id
+			'name'(str): 任务名字
+			'desc'(str): 任务详细描述
+			'pid'(str): 任务所属的项目id
+			'path'(str): 任务在工作空间的路径
+			'create_time'(str): 任务创建时间
+			'status(int)':任务状态(TaskStatus)枚举变量的值
+			'type(str)':任务类型包括{'classification', 'detection', 'segmentation', 'instance_segmentation'}
+		}
 
 methods=='POST':#创建任务(训练或者裁剪)
 	Args:

+ 1 - 0
paddlex/restful/app.py

@@ -302,6 +302,7 @@ def task():
             if 'tid' in Args:
                 task_status(int):任务状态(TaskStatus)枚举变量的值
                 message(str):任务状态信息
+                type:任务类型包括{'classification', 'detection', 'segmentation', 'instance_segmentation'}
                 resumable(bool):仅Args中存在resume时返回,任务训练是否可以恢复
                 max_saved_epochs(int):仅Args中存在resume时返回,当前训练模型保存的最大epoch
             else:

+ 13 - 3
paddlex/restful/project/task.py

@@ -141,6 +141,7 @@ def list_tasks(data, workspace):
         task_pid = workspace.tasks[key].pid
         task_path = workspace.tasks[key].path
         task_create_time = workspace.tasks[key].create_time
+        task_type = workspace.projects[task_pid].type
         from .operate import get_task_status
         path = workspace.tasks[task_id].path
         status, message = get_task_status(path)
@@ -155,7 +156,8 @@ def list_tasks(data, workspace):
             "pid": task_pid,
             "path": task_path,
             "create_time": task_create_time,
-            "status": status.value
+            "status": status.value,
+            'type': task_type
         }
         task_list.append(attr)
     return {'status': 1, 'tasks': task_list}
@@ -250,6 +252,8 @@ def get_task_status(data, workspace):
     assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
     path = workspace.tasks[tid].path
     status, message = get_task_status(path)
+    task_pid = workspace.tasks[tid].pid
+    task_type = workspace.projects[task_pid].type
     if 'resume' in data:
         max_saved_epochs = get_task_max_saved_epochs(path)
         params = {'tid': tid}
@@ -261,10 +265,16 @@ def get_task_status(data, workspace):
             'task_status': status.value,
             'message': message,
             'resumable': resumable,
-            'max_saved_epochs': max_saved_epochs
+            'max_saved_epochs': max_saved_epochs,
+            'type': task_type
         }
 
-    return {'status': 1, 'task_status': status.value, 'message': message}
+    return {
+        'status': 1,
+        'task_status': status.value,
+        'message': message,
+        'type': task_type
+    }
 
 
 def get_train_metrics(data, workspace):