瀏覽代碼

Merge pull request #472 from syyxsxx/develop

add generate excel
Jason 4 年之前
父節點
當前提交
a050a14939
共有 6 個文件被更改,包括 201 次插入10 次删除
  1. 2 0
      docs/restful/index.rst
  2. 6 4
      docs/restful/introduction.md
  3. 20 0
      docs/restful/restful_api.md
  4. 25 5
      paddlex/restful/app.py
  5. 78 1
      paddlex/restful/project/task.py
  6. 70 0
      paddlex/restful/utils.py

+ 2 - 0
docs/restful/index.rst

@@ -10,6 +10,8 @@ PaddleX RESTful是基于PaddleX开发的RESTful API。
 
 **paddlex --start_restful --port [端口号] --workspace_dir [工作空间地址]**
 
+**注意:请确保启动RESTful的端口未被防火墙限制**
+
 支持RESTful版本的GUI
 ---------------------------------------
 

+ 6 - 4
docs/restful/introduction.md

@@ -2,7 +2,12 @@
 PaddleX RESTful是基于PaddleX开发的RESTful API。  
 
 对于开发者来说可以通过如下指令启动PaddleX RESTful服务  
-**paddlex --start_restful --port [端口号] --workspace_dir [工作空间地址]**
+**paddlex --start_restful --port [端口号] --workspace_dir [工作空间地址]**  
+
+对于设置workspace在HOME目录的wk文件夹下,RESTful服务端口为8080的命令参考如下:
+![](./img/start_restful.png)  
+
+**注意:请确保启动RESTful的端口未被防火墙限制**
 
 开启RESTful服务后可以实现如下功能:
 
@@ -11,9 +16,6 @@ PaddleX RESTful是基于PaddleX开发的RESTful API。
 - 根据RESTful API来开发您自己个性化的可视化界面。  
 
 
-同样您还可以根据RESTful API来开发自己的可视化界面。  
-
-**paddlex --start_restful --port [端口号] --workspace_dir [工作空间地址]**
 
 ## PaddleX Remote GUI
 PaddleX Remote GUI是针对PaddleX RESTful开发的可视化客户端。开发者可以通过客户端连接开启RESTful服务的服务端,通过GUI实现深度学习全流程:**数据处理** 、 **超参配置** 、 **模型训练及优化** 、 **模型发布**,无需开发一行代码,即可得到高性深度学习推理模型。  

+ 20 - 0
docs/restful/restful_api.md

@@ -525,6 +525,26 @@ methods=='POST':#异步,创建一个评估任务
 		ret = requests.post(url + '/project/task/evaluate',json=params)
 ```
 
+### /project/task/evaluate/file [GET]
+评估结果生成excel表格  
+- GET请求:评估完成的情况下,在服务器端生成评估结果的excel表格
+
+```
+methods=='GET':#评估结果生成excel表格
+	Args:
+		tid(str):任务id
+	Return:
+		path(str):评估结果excel表格在服务器端的路径
+		message(str):提示信息
+		status
+	Example:
+		#任务id为T0001的任务在服务器端生成评估excel表格
+		params = {'tid': 'T0001'}
+		ret = requests.get(url + '/project/task/evaluate/file',json=params)
+		#显示保存路径
+		print(ret.json()['path'])
+```
+
 ### /project/task/metrics [GET]
 获取训练、评估、剪裁的日志和敏感度与模型裁剪率关系图
 - GET请求:通过type来确定需要获取的内容  

+ 25 - 5
paddlex/restful/app.py

@@ -29,11 +29,10 @@ def init(dirname, logger):
     get_system_info(SD.machine_info)
 
 
-'''@app.errorhandler(Exception)
+@app.errorhandler(Exception)
 def handle_exception(e):
     ret = {"status": -1, 'message': repr(e)}
     return ret
-'''
 
 
 @app.route('/workspace', methods=['GET', 'PUT'])
@@ -555,6 +554,7 @@ def task_evaluate():
                 ret['result']['Confusion_Matrix'] = ret['result'][
                     'Confusion_Matrix'].tolist()
             ret['result'] = CustomEncoder().encode(ret['result'])
+            ret['result'] = json.loads(ret['result'])
         ret['evaluate_status'] = ret['evaluate_status'].value
         return ret
     if request.method == 'POST':
@@ -567,8 +567,25 @@ def task_evaluate():
 def task_evaluate_file():
     data = request.get_json()
     if request.method == 'GET':
-        ret = data['path']
-        return send_file(ret)
+        if 'path' in data:
+            ret = data['path']
+            return send_file(ret)
+        else:
+            from .project.task import get_evaluate_result
+            from .project.task import import_evaluate_excel
+            ret = get_evaluate_result(data, SD.workspace)
+            if ret['evaluate_status'] == TaskStatus.XEVALUATED and ret[
+                    'result'] is not None:
+                result = ret['result']
+                excel_ret = dict()
+                excel_ret = import_evaluate_excel(data, result, SD.workspace)
+                return excel_ret
+            else:
+                excel_ret = dict()
+                excel_ret['path'] = None
+                excel_ret['status'] = -1
+                excel_ret['message'] = "评估尚未完成或评估失败"
+                return excel_ret
 
 
 @app.route('/project/task/predict', methods=['GET', 'POST', 'PUT'])
@@ -905,4 +922,7 @@ def run(port, workspace_dir):
             os.makedirs(dirname)
     logger = get_logger(osp.join(dirname, "mcessages.log"))
     init(dirname, logger)
-    app.run(host='0.0.0.0', port=port, threaded=True, debug=True)
+    try:
+        app.run(host='0.0.0.0', port=port, threaded=True)
+    except:
+        print("请确保端口号:{}未被防火墙限制".format(port))

+ 78 - 1
paddlex/restful/project/task.py

@@ -20,7 +20,9 @@ import time
 import pickle
 import json
 import multiprocessing as mp
-from ..utils import set_folder_status, TaskStatus, get_folder_status, is_available, get_ip
+import xlwt
+import numpy as np
+from ..utils import set_folder_status, TaskStatus, get_folder_status, is_available, get_ip, trans_name
 from .train.params import ClsParams, DetParams, SegParams
 
 
@@ -583,6 +585,81 @@ def get_evaluate_result(data, workspace):
     }
 
 
+def import_evaluate_excel(data, result, workspace):
+    excel_ret = dict()
+    workbook = xlwt.Workbook()
+    labels = None
+    START_ROW = 0
+    sheet = workbook.add_sheet("评估报告")
+    if 'label_list' not in result:
+        pass
+    else:
+        labels = result['label_list']
+    for k, v in result.items():
+        if k == 'label_list':
+            continue
+        if type(v) == np.ndarray:
+            sheet.write(START_ROW + 0, 0, trans_name(k))
+            sheet.write(START_ROW + 1, 1, trans_name("Class"))
+            if labels is None:
+                labels = ["{}".format(x) for x in range(len(v))]
+            for i in range(len(labels)):
+                sheet.write(START_ROW + 1, 2 + i, labels[i])
+                sheet.write(START_ROW + 2 + i, 1, labels[i])
+            for i in range(len(labels)):
+                for j in range(len(labels)):
+                    sheet.write(START_ROW + 2 + i, 2 + j, str(v[i, j]))
+            START_ROW = (START_ROW + 4 + len(labels))
+
+        if type(v) == dict:
+            sheet.write(START_ROW + 0, 0, trans_name(k))
+            multi_row = False
+            Cols = ["Class"]
+            for k1, v1 in v.items():
+                if type(v1) == dict:
+                    multi_row = True
+                    for sub_k, sub_v in v1.items():
+                        Cols.append(sub_k)
+                else:
+                    Cols.append(k)
+                break
+            for i in range(len(Cols)):
+                sheet.write(START_ROW + 1, 1 + i, trans_name(Cols[i]))
+
+            index = 2
+            for k1, v1 in v.items():
+                sheet.write(START_ROW + index, 1, k1)
+                if multi_row:
+                    for sub_k, sub_v in v1.items():
+                        sheet.write(START_ROW + index,
+                                    Cols.index(sub_k) + 1, "nan"
+                                    if (sub_v is None) or sub_v == -1 else
+                                    "{:.4f}".format(sub_v))
+                else:
+                    sheet.write(START_ROW + index, 2, "{}".format(v1))
+                index += 1
+            START_ROW = (START_ROW + index + 2)
+        if type(v) in [float, np.float, np.float32, np.float64, type(None)]:
+            front_str = "{}".format(trans_name(k))
+            if k == "Acck":
+                if "topk" in data:
+                    front_str = front_str.format(data["topk"])
+                else:
+                    front_str = front_str.format(5)
+            sheet.write(START_ROW + 0, 0, front_str)
+            sheet.write(START_ROW + 1, 1, "{:.4f}".format(v)
+                        if v is not None else "nan")
+            START_ROW = (START_ROW + 2 + 2)
+    tid = data['tid']
+    path = workspace.tasks[tid].path
+    final_save = os.path.join(path, 'report-task{}.xls'.format(tid))
+    workbook.save(final_save)
+    excel_ret['status'] = 1
+    excel_ret['path'] = final_save
+    excel_ret['message'] = "成功导出结果到excel"
+    return excel_ret
+
+
 def get_predict_status(data, workspace):
     from .operate import get_predict_status
     tid = data['tid']

+ 70 - 0
paddlex/restful/utils.py

@@ -76,6 +76,60 @@ ExportedModelType = Enum(
                           'XQUANTSERVER', 'XPRUNESERVER', 'XTRAINSERVER'),
     start=0)
 
+translate_chinese_table = {
+    "Confusion_matrix": "各个类别之间的混淆矩阵",
+    "Precision": "精准率",
+    "Accuracy": "准确率",
+    "Recall": "召回率",
+    "Class": "类别",
+    "Topk": "K取值",
+    "Auc": "AUC",
+    "Per_ap": "类别平均精准率",
+    "Map": "类别平均精准率(AP)的均值(mAP)",
+    "Mean_iou": "平均交并比",
+    "Mean_acc": "平均准确率",
+    "Category_iou": "各类别交并比",
+    "Category_acc": "各类别准确率",
+    "Ap": "平均精准率",
+    "F1": "F1-score",
+    "Iou": "交并比"
+}
+
+translate_chinese = {
+    "Confusion_matrix": "混淆矩阵",
+    "Mask_confusion_matrix": "Mask混淆矩阵",
+    "Bbox_confusion_matrix": "Bbox混淆矩阵",
+    "Precision": "精准率(Precision)",
+    "Accuracy": "准确率(Accuracy)",
+    "Recall": "召回率(Recall)",
+    "Class": "类别(Class)",
+    "PRF1": "整体分类评估结果",
+    "PRF1_TOPk": "TopK评估结果",
+    "Topk": "K取值",
+    "AUC": "Area Under Curve",
+    "Auc": "Area Under Curve",
+    "F1": "F1-score",
+    "Iou": "交并比(IoU)",
+    "Per_ap": "各类别的平均精准率(AP)",
+    "mAP": "平均精准率的均值(mAP)",
+    "Mask_mAP": "Mask的平均精准率的均值(mAP)",
+    "BBox_mAP": "Bbox的平均精准率的均值(mAP)",
+    "Mean_iou": "平均交并比(mIoU)",
+    "Mean_acc": "平均准确率(mAcc)",
+    "Ap": "平均精准率(Average Precision)",
+    "Category_iou": "各类别的交并比(IoU)",
+    "Category_acc": "各类别的准确率(Accuracy)",
+    "PRAP": "整体检测评估结果",
+    "BBox_PRAP": "Bbox评估结果",
+    "Mask_PRAP": "Mask评估结果",
+    "Overall": "整体平均指标",
+    "PRF1_average": "整体平均指标",
+    "overall_det": "整体平均指标",
+    "PRIoU": "整体平均指标",
+    "Acc1": "预测Top1的准确率",
+    "Acck": "预测Top{}的准确率"
+}
+
 process_pool = Queue(1000)
 
 
@@ -283,6 +337,22 @@ def download(url, target_path):
     return fullname
 
 
+def trans_name(key, in_table=False):
+    if in_table:
+        if key in translate_chinese_table:
+            key = "{}".format(translate_chinese_table[key])
+        if key.capitalize() in translate_chinese_table:
+            key = "{}".format(translate_chinese_table[key.capitalize()])
+        return key
+    else:
+        if key in translate_chinese:
+            key = "{}".format(translate_chinese[key])
+        if key.capitalize() in translate_chinese:
+            key = "{}".format(translate_chinese[key.capitalize()])
+        return key
+    return key
+
+
 def is_pic(filename):
     suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
     suffix = filename.strip().split('.')[-1]