|
|
@@ -23,7 +23,11 @@ from .detection_eval import fixed_linspace, backup_linspace, loadRes
|
|
|
from paddlex.cv.datasets.dataset import is_pic
|
|
|
|
|
|
|
|
|
-def visualize_detection(image, result, threshold=0.5, save_dir='./'):
|
|
|
+def visualize_detection(image,
|
|
|
+ result,
|
|
|
+ threshold=0.5,
|
|
|
+ save_dir='./',
|
|
|
+ color=None):
|
|
|
"""
|
|
|
Visualize bbox and mask results
|
|
|
"""
|
|
|
@@ -34,7 +38,7 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
|
|
|
image_name = os.path.split(image)[-1]
|
|
|
image = cv2.imread(image)
|
|
|
|
|
|
- image = draw_bbox_mask(image, result, threshold=threshold)
|
|
|
+ image = draw_bbox_mask(image, result, threshold=threshold, color_map=color)
|
|
|
if save_dir is not None:
|
|
|
if not os.path.exists(save_dir):
|
|
|
os.makedirs(save_dir)
|
|
|
@@ -159,7 +163,7 @@ def clip_bbox(bbox):
|
|
|
return xmin, ymin, xmax, ymax
|
|
|
|
|
|
|
|
|
-def draw_bbox_mask(image, results, threshold=0.5):
|
|
|
+def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
|
|
|
import matplotlib
|
|
|
matplotlib.use('Agg')
|
|
|
import matplotlib as mpl
|
|
|
@@ -201,7 +205,18 @@ def draw_bbox_mask(image, results, threshold=0.5):
|
|
|
for dt in np.array(results):
|
|
|
if dt['category'] not in labels:
|
|
|
labels.append(dt['category'])
|
|
|
- color_map = get_color_map_list(256)
|
|
|
+
|
|
|
+ if color_map is None:
|
|
|
+ color_map = get_color_map_list(len(labels) + 2)[2:]
|
|
|
+ else:
|
|
|
+ color_map = np.asarray(color_map)
|
|
|
+ if color_map.shape[0] != len(labels) or color_map.shape[1] != 3:
|
|
|
+ raise Exception(
|
|
|
+ "The shape for color_map is required to be {}x3, but recieved shape is {}x{}.".
|
|
|
+ format(len(labels), color_map.shape))
|
|
|
+ if np.max(color_map) > 255 or np.min(color_map) < 0:
|
|
|
+ raise ValueError(
|
|
|
+ " The values in color_map should be within 0-255 range.")
|
|
|
|
|
|
keep_results = []
|
|
|
areas = []
|
|
|
@@ -222,7 +237,7 @@ def draw_bbox_mask(image, results, threshold=0.5):
|
|
|
xmax = xmin + w
|
|
|
ymax = ymin + h
|
|
|
|
|
|
- color = tuple(color_map[labels.index(cname) + 2])
|
|
|
+ color = tuple(color_map[labels.index(cname)])
|
|
|
color = [c / 255. for c in color]
|
|
|
# draw bbox
|
|
|
ax.add_patch(
|