Эх сурвалжийг харах

Merge pull request #506 from FlyingQianMM/develop_qh

add color setting for pdx.det.visualize
Jason 4 жил өмнө
parent
commit
6467426753

+ 2 - 1
docs/apis/visualize.md

@@ -5,7 +5,7 @@ PaddleX提供了一系列模型预测和结果分析的可视化函数。
 ## paddlex.det.visualize
 > **目标检测/实例分割预测结果可视化**  
 ```
-paddlex.det.visualize(image, result, threshold=0.5, save_dir='./')
+paddlex.det.visualize(image, result, threshold=0.5, save_dir='./', color=None)
 ```
 将目标检测/实例分割模型预测得到的Box框和Mask在原图上进行可视化。
 
@@ -14,6 +14,7 @@ paddlex.det.visualize(image, result, threshold=0.5, save_dir='./')
 > * **result** (str): 模型预测结果。
 > * **threshold**(float): score阈值,将Box置信度低于该阈值的框过滤不进行可视化。默认0.5
 > * **save_dir**(str): 可视化结果保存路径。若为None,则表示不保存,该函数将可视化的结果以np.ndarray的形式返回;若设为目录路径,则将可视化结果保存至该目录下。默认值为'./'。
+> * **color**(list|tuple|np.array): 各类别的BGR颜色值组成的数组,形状为Nx3(N为类别数量),数值范围为[0, 255]。例如针对2个类别的[[255, 0, 0], [0, 255, 0]]。若为None,则自动生成各类别的颜色。默认值为None。
 
 ### 使用示例
 > 点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)

+ 20 - 5
paddlex/cv/models/utils/visualize.py

@@ -23,7 +23,11 @@ from .detection_eval import fixed_linspace, backup_linspace, loadRes
 from paddlex.cv.datasets.dataset import is_pic
 
 
-def visualize_detection(image, result, threshold=0.5, save_dir='./'):
+def visualize_detection(image,
+                        result,
+                        threshold=0.5,
+                        save_dir='./',
+                        color=None):
     """
         Visualize bbox and mask results
     """
@@ -34,7 +38,7 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
         image_name = os.path.split(image)[-1]
         image = cv2.imread(image)
 
-    image = draw_bbox_mask(image, result, threshold=threshold)
+    image = draw_bbox_mask(image, result, threshold=threshold, color_map=color)
     if save_dir is not None:
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
@@ -159,7 +163,7 @@ def clip_bbox(bbox):
     return xmin, ymin, xmax, ymax
 
 
-def draw_bbox_mask(image, results, threshold=0.5):
+def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
     import matplotlib
     matplotlib.use('Agg')
     import matplotlib as mpl
@@ -201,7 +205,18 @@ def draw_bbox_mask(image, results, threshold=0.5):
     for dt in np.array(results):
         if dt['category'] not in labels:
             labels.append(dt['category'])
-    color_map = get_color_map_list(256)
+
+    if color_map is None:
+        color_map = get_color_map_list(len(labels) + 2)[2:]
+    else:
+        color_map = np.asarray(color_map)
+        if color_map.shape[0] != len(labels) or color_map.shape[1] != 3:
+            raise Exception(
+                "The shape for color_map is required to be {}x3, but recieved shape is {}x{}.".
+                format(len(labels), color_map.shape))
+        if np.max(color_map) > 255 or np.min(color_map) < 0:
+            raise ValueError(
+                " The values in color_map should be within 0-255 range.")
 
     keep_results = []
     areas = []
@@ -222,7 +237,7 @@ def draw_bbox_mask(image, results, threshold=0.5):
         xmax = xmin + w
         ymax = ymin + h
 
-        color = tuple(color_map[labels.index(cname) + 2])
+        color = tuple(color_map[labels.index(cname)])
         color = [c / 255. for c in color]
         # draw bbox
         ax.add_patch(