Browse Source

add label visualize

wangsiyuan06 5 years ago
parent
commit
0e19184785

+ 24 - 0
docs/restful/data_struct.md

@@ -123,6 +123,22 @@ TaskStatus = Enum('TaskStatus',
 ),start=0)
 
 ```
+### ProjectType(项目类型)
+ProjectType = Enum('ProjectType',
+('classification',#分类
+'detection',#检测
+'segmentation',#分割
+'instance_segmentation',#实例分割
+'remote_segmentation'#摇杆分割
+),start=0)
+
+### DownloadStatus(下载状态变量)
+DownloadStatus = Enum('DownloadStatus',
+('XDDOWNLOADING',#下载中
+'XDDOWNLOADFAIL',#下载失败
+'XDDOWNLOADDONE',下载完成
+'XDDECOMPRESSED'解压完成
+),start=0)
 
 ### PruneStatus(裁剪状态变量)
 ```
@@ -143,6 +159,14 @@ PredictStatus = Enum('PredictStatus',
 'XPREFAIL'#预测失败
 ), start=0)
 ```
+### PretrainedModelStatus(预训练模型状态变量)
+PretrainedModelStatus = Enum('PretrainedModelStatus',
+('XPINIT', #初始化
+'XPSAVING', #正在保存
+'XPSAVEFAIL',#保存失败
+'XPSAVEDONE' #保存完成
+),start=0)
+
 ### ExportedModelType(模型导出状态变量)
 ```
 ExportedModelType = Enum('ExportedModelType',

BIN
docs/restful/img/restful_api.png


+ 4 - 4
docs/restful/quick_start.md

@@ -1,9 +1,9 @@
 # 快速开始
 
 ## 环境依赖  
-paddlepaddle-gpu/paddlepaddle
-paddlex
-pycocotools
+- paddlepaddle-gpu/paddlepaddle  
+- paddlex  
+- pycocotools  
 
 
 ## 服务端启动PaddleX Restful服务
@@ -39,7 +39,7 @@ url = "https://127.0.0.1:5000"
 params = {"name": "我的第一个数据集", "desc": "这里是数据集的描述文字", "dataset_type": "detection"}  
 ret = requests.post(url+"/dataset", json=params)  
 #获取数据集id
-did = ret.json()['did']
+did = ret.json()['id']
 ```
 
 #### 导入数据集

+ 11 - 0
docs/restful/restful_api.md

@@ -193,6 +193,7 @@ methods=='PUT':切分某个数据集
 methods=='GET':获取服务端的文件,目前支持图片、xml格式文件
 	Args:
 		'path'(str):文件在服务端的路径
+		'did'(str, optional):可选,数据集id仅在文件为图片时有效。若存在返回图片带label可视化。注意当前不支持分类数据集数据的标注可视化
 	Return:
 		#数据为图片
 		img_data(str): base64图片数据
@@ -708,6 +709,16 @@ methods=='POST':#创建一个模型
 			pmid(str):预训练模型id
 		if type == 'exported':
 			emid(str):inference模型id
+	Exampe:
+		#创建一个预训练模型
+		params={
+			pid : 'P0001',
+			tid : 'T0001',
+			name : 'Pretrain_model',
+			type : 'pretrained',
+			source_path : '/path/to/pretrian_model',
+		}
+		ret = requests.post(url + 'model', json=params)
 
 methods=='DELETE':删除一个模型
 	Args:

+ 1 - 1
paddlex/restful/app.py

@@ -231,7 +231,7 @@ def get_file():
             return {'status': -1}
         if is_pic(path):
             from .dataset.dataset import img_base64
-            ret = img_base64(data)
+            ret = img_base64(data, SD.workspace)
             return ret
         file_type = path[(path.rfind('.') + 1):]
         if file_type in ['xml', 'npy', 'log']:

+ 39 - 3
paddlex/restful/dataset/dataset.py

@@ -18,7 +18,7 @@ from ..utils import (set_folder_status, get_folder_status, DatasetStatus,
 
 from threading import Thread
 import random
-from .utils import copy_directory
+from .utils import copy_directory, get_label_count
 import traceback
 import shutil
 import psutil
@@ -28,6 +28,7 @@ import os.path as osp
 import time
 import json
 import base64
+import cv2
 from .. import workspace_pb2 as w
 
 
@@ -224,7 +225,7 @@ def split_dataset(data, workspace):
     return {'status': 1}
 
 
-def img_base64(data):
+def img_base64(data, workspace=None):
     """将数据集切分为训练集、验证集和测试集
 
     Args:
@@ -232,7 +233,42 @@ def img_base64(data):
         'path':图片绝对路径
     """
     path = data['path']
-    print(path)
+    path = '/'.join(path.split('\\'))
+    if 'did' in data:
+        did = data['did']
+        lable_type = workspace.datasets[did].type
+        ds_path = workspace.datasets[did].path
+
+        ret = get_dataset_details(data, workspace)
+        dataset_details = ret['details']
+        ds_label_count = get_label_count(dataset_details['label_info'])
+        image_path = 'JPEGImages/' + path.split('/')[-1]
+        anno = osp.join(ds_path, dataset_details["file_info"][image_path])
+
+        if lable_type == 'detection':
+            from ..project.visualize import plot_det_label
+            labels = list(ds_label_count.keys())
+            img = plot_det_label(path, anno, labels)
+            base64_str = base64.b64encode(cv2.imencode('.png', img)[1]).decode(
+            )
+            return {'status': 1, 'img_data': base64_str}
+        elif lable_type == 'segmentation' or lable_type == 'remote_segmentation':
+            from ..project.visualize import plot_seg_label
+            im = plot_seg_label(anno)
+            img = cv2.imread(path)
+            im = cv2.addWeighted(img, 0.5, im, 0.5, 0).astype('uint8')
+            base64_str = base64.b64encode(cv2.imencode('.png', im)[1]).decode()
+            return {'status': 1, 'img_data': base64_str}
+        elif lable_type == 'instance_segmentation':
+            labels = list(ds_label_count.keys())
+            from ..project.visualize import plot_insseg_label
+            img = plot_insseg_label(path, anno, labels)
+            base64_str = base64.b64encode(cv2.imencode('.png', img)[1]).decode(
+            )
+            return {'status': 1, 'img_data': base64_str}
+        else:
+            raise Exception("数据集类型{}目前暂不支持".format(lable_type))
+
     with open(path, 'rb') as f:
         base64_data = base64.b64encode(f.read())
         base64_str = str(base64_data, 'utf-8')

+ 12 - 0
paddlex/restful/dataset/utils.py

@@ -210,6 +210,18 @@ def get_npy_from_coco_json(coco, npy_path, files):
         np.save(osp.join(npy_path, npy_name), anno_dict)
 
 
+def get_label_count(label_info):
+    """ 根据存储的label_info字段,计算label_count字段
+
+    Args:
+        label_info: 存储的label_info
+    """
+    label_count = dict()
+    for key in sorted(label_info):
+        label_count[key] = len(label_info[key])
+    return label_count
+
+
 class MyEncoder(json.JSONEncoder):
     # 调整json文件存储形式
     def default(self, obj):