Просмотр исходного кода

Merge pull request #449 from syyxsxx/restful

add export inference_model with default params epoch and quant
Jason 4 лет назад
Родитель
Сommit
55c35b2975

+ 16 - 1
docs/gui/restful_api.md

@@ -203,6 +203,7 @@ methods=='GET':获取某个数据集的详细信息
 			'test_files'(list): 测试集文件列表,相对于据集地址的相对路径
 			'class_train_file_list(dict)':类别与训练集映射表,key为类别、value为训练图片相对于据集地址的相对路径
 			'class_val_file_list(dict)':类别与评估集映射表,key为类别、value为评估图片相对于据集地址的相对路径
+			'class_test_file_list(dict)':类别与测试集映射表,key为类别、value为测试图片相对于据集地址的相对路径
 		}
 ```
 
@@ -328,6 +329,7 @@ methods=='GET':#获取某个任务的信息或者所有任务的信息
 		if 'tid' in Args:
 			task_status(int):任务状态(TaskStatus)枚举变量的值
 			message(str):任务状态信息
+			type(str):任务类型包括{'classification', 'detection', 'segmentation', 'instance_segmentation'}
 			resumable(bool):仅Args中存在resume时返回,任务训练是否可以恢复
 			max_saved_epochs(int):仅Args中存在resume时返回,当前训练模型保存的最大epoch
 		else:
@@ -339,6 +341,18 @@ methods=='GET':#获取某个任务的信息或者所有任务的信息
 	Example2:
 		#获取所有任务的信息
 		ret = requests.get(url + '/project/task')
+	Ruturn中的自定数据结构:
+		所有任务属性(tasks),任务属性attr(dict)的list
+		attr{
+			'id'(str): 任务id
+			'name'(str): 任务名字
+			'desc'(str): 任务详细描述
+			'pid'(str): 任务所属的项目id
+			'path'(str): 任务在工作空间的路径
+			'create_time'(str): 任务创建时间
+			'status(int)':任务状态(TaskStatus)枚举变量的值
+			'type(str)':任务类型包括{'classification', 'detection', 'segmentation', 'instance_segmentation'}
+		}
 
 methods=='POST':#创建任务(训练或者裁剪)
 	Args:
@@ -637,7 +651,8 @@ methods=='POST':#导出inference模型或者导出lite模型
 		tid(str):任务id
 		type(str):保存模型的类别[infer,lite],支持inference模型导出和lite的模型导出
 		save_dir(str):保存模型的路径
-		quant(bool,optional)可选,type为infer有效,是否导出量化后的模型
+		epoch(str,optional)可选,指定导出的epoch数默认为评估效果最好的epoch
+		quant(bool,optional)可选,type为infer有效,是否导出量化后的模型,默认为False
 		model_path(str,optional)可选,type为lite时有效,inference模型的地址
 	Return:
 		status

+ 3 - 1
paddlex/restful/app.py

@@ -302,6 +302,7 @@ def task():
             if 'tid' in Args:
                 task_status(int):任务状态(TaskStatus)枚举变量的值
                 message(str):任务状态信息
+                type:任务类型包括{'classification', 'detection', 'segmentation', 'instance_segmentation'}
                 resumable(bool):仅Args中存在resume时返回,任务训练是否可以恢复
                 max_saved_epochs(int):仅Args中存在resume时返回,当前训练模型保存的最大epoch
             else:
@@ -630,7 +631,8 @@ def task_export():
             tid(str):任务id
             type(str):保存模型的类别[infer,lite],支持inference模型导出和lite的模型导出
             save_dir(str):保存模型的路径
-            quant(bool,optional)可选,type为infer有效,是否导出量化后的模型
+            epoch(str,optional)可选,指定导出的epoch数默认为评估效果最好的epoch
+            quant(bool,optional)可选,type为infer有效,是否导出量化后的模型,默认为False
             model_path(str,optional)可选,type为lite时有效,inference模型的地址
         Return:
             status

+ 12 - 6
paddlex/restful/project/operate.py

@@ -335,8 +335,11 @@ def _call_paddlex_export_infer(task_path, save_dir, export_status_path, epoch):
     import os
     os.environ['CUDA_VISIBLE_DEVICES'] = ''
     import paddlex as pdx
-    model_dir = "epoch_{}".format(epoch)
-    model_path = osp.join(task_path, 'output', model_dir)
+    if epoch is not None:
+        model_dir = "epoch_{}".format(epoch)
+        model_path = osp.join(task_path, 'output', model_dir)
+    else:
+        model_path = osp.join(task_path, 'output', 'best_model')
     model = pdx.load_model(model_path)
     model.export_inference_model(save_dir)
     set_folder_status(export_status_path, TaskStatus.XEXPORTED)
@@ -353,8 +356,11 @@ def _call_paddlex_export_quant(task_path, params, save_dir, export_status_path,
     task_type = params['task_type']
     os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
     import paddlex as pdx
-    model_dir = "epoch_{}".format(epoch)
-    model_path = osp.join(task_path, 'output', model_dir)
+    if epoch is not None:
+        model_dir = "epoch_{}".format(epoch)
+        model_path = osp.join(task_path, 'output', model_dir)
+    else:
+        model_path = osp.join(task_path, 'output', 'best_model')
     model = pdx.load_model(model_path)
     if task_type == "classification":
         train_file_list = osp.join(dataset_path, 'train_list.txt')
@@ -823,7 +829,7 @@ def get_export_status(task_path):
     return status, message
 
 
-def export_quant_model(task_path, save_dir, epoch):
+def export_quant_model(task_path, save_dir, epoch=None):
     """导出量化模型
 
     Args:
@@ -850,7 +856,7 @@ def export_quant_model(task_path, save_dir, epoch):
     return p
 
 
-def export_noquant_model(task_path, save_dir, epoch):
+def export_noquant_model(task_path, save_dir, epoch=None):
     """导出inference模型
 
     Args:

+ 15 - 5
paddlex/restful/project/task.py

@@ -141,6 +141,7 @@ def list_tasks(data, workspace):
         task_pid = workspace.tasks[key].pid
         task_path = workspace.tasks[key].path
         task_create_time = workspace.tasks[key].create_time
+        task_type = workspace.projects[task_pid].type
         from .operate import get_task_status
         path = workspace.tasks[task_id].path
         status, message = get_task_status(path)
@@ -155,7 +156,8 @@ def list_tasks(data, workspace):
             "pid": task_pid,
             "path": task_path,
             "create_time": task_create_time,
-            "status": status.value
+            "status": status.value,
+            'type': task_type
         }
         task_list.append(attr)
     return {'status': 1, 'tasks': task_list}
@@ -250,6 +252,8 @@ def get_task_status(data, workspace):
     assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
     path = workspace.tasks[tid].path
     status, message = get_task_status(path)
+    task_pid = workspace.tasks[tid].pid
+    task_type = workspace.projects[task_pid].type
     if 'resume' in data:
         max_saved_epochs = get_task_max_saved_epochs(path)
         params = {'tid': tid}
@@ -261,10 +265,16 @@ def get_task_status(data, workspace):
             'task_status': status.value,
             'message': message,
             'resumable': resumable,
-            'max_saved_epochs': max_saved_epochs
+            'max_saved_epochs': max_saved_epochs,
+            'type': task_type
         }
 
-    return {'status': 1, 'task_status': status.value, 'message': message}
+    return {
+        'status': 1,
+        'task_status': status.value,
+        'message': message,
+        'type': task_type
+    }
 
 
 def get_train_metrics(data, workspace):
@@ -703,8 +713,8 @@ def export_infer_model(data, workspace, monitored_processes):
     from .operate import export_noquant_model, export_quant_model
     tid = data['tid']
     save_dir = data['save_dir']
-    epoch = data['epoch']
-    quant = data['quant']
+    epoch = data['epoch'] if 'epoch' in data else None
+    quant = data['quant'] if 'quant' in data else False
     assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
     path = workspace.tasks[tid].path
     if quant: