瀏覽代碼

fix export infermodel with params epoch

wangsiyuan06 5 年之前
父節點
當前提交
5ea6751778
共有 4 個文件被更改,包括 19 次插入10 次删除
  1. 3 1
      docs/gui/restful_api.md
  2. 2 1
      paddlex/restful/app.py
  3. 12 6
      paddlex/restful/project/operate.py
  4. 2 2
      paddlex/restful/project/task.py

+ 3 - 1
docs/gui/restful_api.md

@@ -203,6 +203,7 @@ methods=='GET':获取某个数据集的详细信息
 			'test_files'(list): 测试集文件列表,相对于据集地址的相对路径
 			'test_files'(list): 测试集文件列表,相对于据集地址的相对路径
 			'class_train_file_list(dict)':类别与训练集映射表,key为类别、value为训练图片相对于据集地址的相对路径
 			'class_train_file_list(dict)':类别与训练集映射表,key为类别、value为训练图片相对于据集地址的相对路径
 			'class_val_file_list(dict)':类别与评估集映射表,key为类别、value为评估图片相对于据集地址的相对路径
 			'class_val_file_list(dict)':类别与评估集映射表,key为类别、value为评估图片相对于据集地址的相对路径
+			'class_test_file_list(dict)':类别与测试集映射表,key为类别、value为测试图片相对于据集地址的相对路径
 		}
 		}
 ```
 ```
 
 
@@ -637,7 +638,8 @@ methods=='POST':#导出inference模型或者导出lite模型
 		tid(str):任务id
 		tid(str):任务id
 		type(str):保存模型的类别[infer,lite],支持inference模型导出和lite的模型导出
 		type(str):保存模型的类别[infer,lite],支持inference模型导出和lite的模型导出
 		save_dir(str):保存模型的路径
 		save_dir(str):保存模型的路径
-		quant(bool,optional)可选,type为infer有效,是否导出量化后的模型
+		epoch(str,optional)可选,指定导出的epoch数默认为评估效果最好的epoch
+		quant(bool,optional)可选,type为infer有效,是否导出量化后的模型,默认为False
 		model_path(str,optional)可选,type为lite时有效,inference模型的地址
 		model_path(str,optional)可选,type为lite时有效,inference模型的地址
 	Return:
 	Return:
 		status
 		status

+ 2 - 1
paddlex/restful/app.py

@@ -630,7 +630,8 @@ def task_export():
             tid(str):任务id
             tid(str):任务id
             type(str):保存模型的类别[infer,lite],支持inference模型导出和lite的模型导出
             type(str):保存模型的类别[infer,lite],支持inference模型导出和lite的模型导出
             save_dir(str):保存模型的路径
             save_dir(str):保存模型的路径
-            quant(bool,optional)可选,type为infer有效,是否导出量化后的模型
+            epoch(str,optional)可选,指定导出的epoch数默认为评估效果最好的epoch
+            quant(bool,optional)可选,type为infer有效,是否导出量化后的模型,默认为False
             model_path(str,optional)可选,type为lite时有效,inference模型的地址
             model_path(str,optional)可选,type为lite时有效,inference模型的地址
         Return:
         Return:
             status
             status

+ 12 - 6
paddlex/restful/project/operate.py

@@ -335,8 +335,11 @@ def _call_paddlex_export_infer(task_path, save_dir, export_status_path, epoch):
     import os
     import os
     os.environ['CUDA_VISIBLE_DEVICES'] = ''
     os.environ['CUDA_VISIBLE_DEVICES'] = ''
     import paddlex as pdx
     import paddlex as pdx
-    model_dir = "epoch_{}".format(epoch)
-    model_path = osp.join(task_path, 'output', model_dir)
+    if epoch is not None:
+        model_dir = "epoch_{}".format(epoch)
+        model_path = osp.join(task_path, 'output', model_dir)
+    else:
+        model_path = osp.join(task_path, 'output', 'best_model')
     model = pdx.load_model(model_path)
     model = pdx.load_model(model_path)
     model.export_inference_model(save_dir)
     model.export_inference_model(save_dir)
     set_folder_status(export_status_path, TaskStatus.XEXPORTED)
     set_folder_status(export_status_path, TaskStatus.XEXPORTED)
@@ -353,8 +356,11 @@ def _call_paddlex_export_quant(task_path, params, save_dir, export_status_path,
     task_type = params['task_type']
     task_type = params['task_type']
     os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
     os.environ['CUDA_VISIBLE_DEVICES'] = params['train'].cuda_visible_devices
     import paddlex as pdx
     import paddlex as pdx
-    model_dir = "epoch_{}".format(epoch)
-    model_path = osp.join(task_path, 'output', model_dir)
+    if epoch is not None:
+        model_dir = "epoch_{}".format(epoch)
+        model_path = osp.join(task_path, 'output', model_dir)
+    else:
+        model_path = osp.join(task_path, 'output', 'best_model')
     model = pdx.load_model(model_path)
     model = pdx.load_model(model_path)
     if task_type == "classification":
     if task_type == "classification":
         train_file_list = osp.join(dataset_path, 'train_list.txt')
         train_file_list = osp.join(dataset_path, 'train_list.txt')
@@ -823,7 +829,7 @@ def get_export_status(task_path):
     return status, message
     return status, message
 
 
 
 
-def export_quant_model(task_path, save_dir, epoch):
+def export_quant_model(task_path, save_dir, epoch=None):
     """导出量化模型
     """导出量化模型
 
 
     Args:
     Args:
@@ -850,7 +856,7 @@ def export_quant_model(task_path, save_dir, epoch):
     return p
     return p
 
 
 
 
-def export_noquant_model(task_path, save_dir, epoch):
+def export_noquant_model(task_path, save_dir, epoch=None):
     """导出inference模型
     """导出inference模型
 
 
     Args:
     Args:

+ 2 - 2
paddlex/restful/project/task.py

@@ -703,8 +703,8 @@ def export_infer_model(data, workspace, monitored_processes):
     from .operate import export_noquant_model, export_quant_model
     from .operate import export_noquant_model, export_quant_model
     tid = data['tid']
     tid = data['tid']
     save_dir = data['save_dir']
     save_dir = data['save_dir']
-    epoch = data['epoch']
-    quant = data['quant']
+    epoch = data['epoch'] if 'epoch' in data else None
+    quant = data['quant'] if 'quant' in data else False
     assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
     assert tid in workspace.tasks, "任务ID'{}'不存在".format(tid)
     path = workspace.tasks[tid].path
     path = workspace.tasks[tid].path
     if quant:
     if quant: