|
|
@@ -18,7 +18,7 @@ from ..utils import (set_folder_status, get_folder_status, DatasetStatus,
|
|
|
|
|
|
from threading import Thread
|
|
|
import random
|
|
|
-from .utils import copy_directory
|
|
|
+from .utils import copy_directory, get_label_count
|
|
|
import traceback
|
|
|
import shutil
|
|
|
import psutil
|
|
|
@@ -28,6 +28,7 @@ import os.path as osp
|
|
|
import time
|
|
|
import json
|
|
|
import base64
|
|
|
+import cv2
|
|
|
from .. import workspace_pb2 as w
|
|
|
|
|
|
|
|
|
@@ -224,7 +225,7 @@ def split_dataset(data, workspace):
|
|
|
return {'status': 1}
|
|
|
|
|
|
|
|
|
-def img_base64(data):
|
|
|
+def img_base64(data, workspace=None):
|
|
|
"""将数据集切分为训练集、验证集和测试集
|
|
|
|
|
|
Args:
|
|
|
@@ -232,7 +233,42 @@ def img_base64(data):
|
|
|
'path':图片绝对路径
|
|
|
"""
|
|
|
path = data['path']
|
|
|
- print(path)
|
|
|
+ path = '/'.join(path.split('\\'))
|
|
|
+ if 'did' in data:
|
|
|
+ did = data['did']
|
|
|
+ lable_type = workspace.datasets[did].type
|
|
|
+ ds_path = workspace.datasets[did].path
|
|
|
+
|
|
|
+ ret = get_dataset_details(data, workspace)
|
|
|
+ dataset_details = ret['details']
|
|
|
+ ds_label_count = get_label_count(dataset_details['label_info'])
|
|
|
+ image_path = 'JPEGImages/' + path.split('/')[-1]
|
|
|
+ anno = osp.join(ds_path, dataset_details["file_info"][image_path])
|
|
|
+
|
|
|
+ if lable_type == 'detection':
|
|
|
+ from ..project.visualize import plot_det_label
|
|
|
+ labels = list(ds_label_count.keys())
|
|
|
+ img = plot_det_label(path, anno, labels)
|
|
|
+ base64_str = base64.b64encode(cv2.imencode('.png', img)[1]).decode(
|
|
|
+ )
|
|
|
+ return {'status': 1, 'img_data': base64_str}
|
|
|
+ elif lable_type == 'segmentation' or lable_type == 'remote_segmentation':
|
|
|
+ from ..project.visualize import plot_seg_label
|
|
|
+ im = plot_seg_label(anno)
|
|
|
+ img = cv2.imread(path)
|
|
|
+ im = cv2.addWeighted(img, 0.5, im, 0.5, 0).astype('uint8')
|
|
|
+ base64_str = base64.b64encode(cv2.imencode('.png', im)[1]).decode()
|
|
|
+ return {'status': 1, 'img_data': base64_str}
|
|
|
+ elif lable_type == 'instance_segmentation':
|
|
|
+ labels = list(ds_label_count.keys())
|
|
|
+ from ..project.visualize import plot_insseg_label
|
|
|
+ img = plot_insseg_label(path, anno, labels)
|
|
|
+ base64_str = base64.b64encode(cv2.imencode('.png', img)[1]).decode(
|
|
|
+ )
|
|
|
+ return {'status': 1, 'img_data': base64_str}
|
|
|
+ else:
|
|
|
+ raise Exception("数据集类型{}目前暂不支持".format(lable_type))
|
|
|
+
|
|
|
with open(path, 'rb') as f:
|
|
|
base64_data = base64.b64encode(f.read())
|
|
|
base64_str = str(base64_data, 'utf-8')
|