浏览代码

modify visualize to support np.ndarray

jiangjiajun 5 年之前
父节点
当前提交
c861b55194
共有 1 个文件被更改,包括 14 次插入4 次删除
  1. 14 4
      paddlex/cv/models/utils/visualize.py

+ 14 - 4
paddlex/cv/models/utils/visualize.py

@@ -16,6 +16,7 @@ import os
 import cv2
 import colorsys
 import numpy as np
+import time
 import paddlex.utils.logging as logging
 from .detection_eval import fixed_linspace, backup_linspace, loadRes
 
@@ -25,8 +26,12 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
         Visualize bbox and mask results
     """
 
-    image_name = os.path.split(image)[-1]
-    image = cv2.imread(image)
+    if isinstance(image, np.ndarray):
+        image_name = str(int(time.time())) + '.jpg'
+    else:
+        image = cv2.imread(image)
+        image_name = os.path.split(image)[-1]
+
     image = draw_bbox_mask(image, result, threshold=threshold)
     if save_dir is not None:
         if not os.path.exists(save_dir):
@@ -56,13 +61,18 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
     c3 = cv2.LUT(label_map, color_map[:, 2])
     pseudo_img = np.dstack((c1, c2, c3))
 
-    im = cv2.imread(image)
+    if isinstance(image, np.ndarray):
+        im = image
+        image_name = str(int(time.time())) + '.jpg'
+    else:
+        image = cv2.imread(image)
+        image_name = os.path.split(image)[-1]
+
     vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
 
     if save_dir is not None:
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
-        image_name = os.path.split(image)[-1]
         out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
         cv2.imwrite(out_path, vis_result)
         logging.info('The visualized result is saved as {}'.format(out_path))