# Copyright (c) Facebook, Inc. and its affiliates.
import colorsys
import logging
import math
import numpy as np
from enum import Enum, unique
import cv2
import matplotlib as mpl
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import pycocotools.mask as mask_util
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image

from detectron2.data import MetadataCatalog
from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
from detectron2.utils.file_io import PathManager
from detectron2.utils.colormap import random_color
import pdb

logger = logging.getLogger(__name__)

__all__ = ["ColorMode", "VisImage", "Visualizer"]


_SMALL_OBJECT_AREA_THRESH = 1000
_LARGE_MASK_AREA_THRESH = 120000
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
_BLACK = (0, 0, 0)
_RED = (1.0, 0, 0)

_KEYPOINT_THRESHOLD = 0.05

# CLASS_NAMES = ["footnote", "footer", "header"]


@unique
class ColorMode(Enum):
    """
    Enum of different color modes to use for instance visualizations.
    """

    IMAGE = 0
    """
    Picks a random color for every instance and overlay segmentations with low opacity.
    """
    SEGMENTATION = 1
    """
    Let instances of the same category have similar colors
    (from metadata.thing_colors), and overlay them with
    high opacity. This draws more attention to the quality of segmentation.
    """
    IMAGE_BW = 2
    """
    Same as IMAGE, but convert all areas without masks to gray-scale.
    Only available for drawing per-instance mask predictions.
    """


class GenericMask:
    """
    Attribute:
        polygons (list[ndarray]): polygons for this mask.
            Each ndarray has format [x, y, x, y, ...]
        mask (ndarray): a binary mask
    """

    def __init__(self, mask_or_polygons, height, width):
        self._mask = self._polygons = self._has_holes = None
        self.height = height
        self.width = width

        m = mask_or_polygons
        if isinstance(m, dict):
            # RLEs
            assert "counts" in m and "size" in m
            if isinstance(m["counts"], list):  # uncompressed RLEs
                h, w = m["size"]
                assert h == height and w == width
                m = mask_util.frPyObjects(m, h, w)
            self._mask = mask_util.decode(m)[:, :]
            return

        if isinstance(m, list):  # list[ndarray]
            self._polygons = [np.asarray(x).reshape(-1) for x in m]
            return

        if isinstance(m, np.ndarray):  # assumed to be a binary mask
            assert m.shape[1] != 2, m.shape
            assert m.shape == (
                height,
                width,
            ), f"mask shape: {m.shape}, target dims: {height}, {width}"
            self._mask = m.astype("uint8")
            return

        raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))

    @property
    def mask(self):
        if self._mask is None:
            self._mask = self.polygons_to_mask(self._polygons)
        return self._mask

    @property
    def polygons(self):
        if self._polygons is None:
            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
        return self._polygons

    @property
    def has_holes(self):
        if self._has_holes is None:
            if self._mask is not None:
                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
            else:
                self._has_holes = False  # if the original format is polygons, it does not have holes
        return self._has_holes

    def mask_to_polygons(self, mask):
        # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
        # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
        # Internal contours (holes) are placed in hierarchy-2.
        # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
        mask = np.ascontiguousarray(mask)  # some versions of cv2 do not support non-contiguous arrays
        res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        hierarchy = res[-1]
        if hierarchy is None:  # empty mask
            return [], False
        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
        res = res[-2]
        res = [x.flatten() for x in res]
        # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
        # We add 0.5 to turn them into real-value coordinate space. A better solution
        # would be to first +0.5 and then dilate the returned polygon by 0.5.
        res = [x + 0.5 for x in res if len(x) >= 6]
        return res, has_holes

    def polygons_to_mask(self, polygons):
        rle = mask_util.frPyObjects(polygons, self.height, self.width)
        rle = mask_util.merge(rle)
        return mask_util.decode(rle)[:, :]

    def area(self):
        return self.mask.sum()

    def bbox(self):
        p = mask_util.frPyObjects(self.polygons, self.height, self.width)
        p = mask_util.merge(p)
        bbox = mask_util.toBbox(p)
        bbox[2] += bbox[0]
        bbox[3] += bbox[1]
        return bbox
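

# A minimal sketch of how GenericMask converts between formats (the toy 4x4 mask
# below is illustrative only, not taken from the codebase):
#
#     toy = np.zeros((4, 4), dtype="uint8")
#     toy[1:3, 1:3] = 1
#     gm = GenericMask(toy, height=4, width=4)
#     gm.polygons   # list of flat [x0, y0, x1, y1, ...] arrays from cv2.findContours
#     gm.area()     # 4, the number of foreground pixels
#     gm.bbox()     # XYXY box around the foreground, computed via pycocotools RLEs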


class _PanopticPrediction:
    """
    Unify different panoptic annotation/prediction formats
    """

    def __init__(self, panoptic_seg, segments_info, metadata=None):
        if segments_info is None:
            assert metadata is not None
            # If "segments_info" is None, we assume "panoptic_img" is a
            # H*W int32 image storing the panoptic_id in the format of
            # category_id * label_divisor + instance_id. We reserve -1 for
            # VOID label.
            label_divisor = metadata.label_divisor
            segments_info = []
            for panoptic_label in np.unique(panoptic_seg.numpy()):
                if panoptic_label == -1:
                    # VOID region.
                    continue
                pred_class = panoptic_label // label_divisor
                isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
                segments_info.append(
                    {
                        "id": int(panoptic_label),
                        "category_id": int(pred_class),
                        "isthing": bool(isthing),
                    }
                )
        del metadata

        self._seg = panoptic_seg

        self._sinfo = {s["id"]: s for s in segments_info}  # seg id -> seg info
        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
        areas = areas.numpy()
        sorted_idxs = np.argsort(-areas)
        self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
        self._seg_ids = self._seg_ids.tolist()
        for sid, area in zip(self._seg_ids, self._seg_areas):
            if sid in self._sinfo:
                self._sinfo[sid]["area"] = float(area)

    def non_empty_mask(self):
        """
        Returns:
            (H, W) array, a mask for all pixels that have a prediction
        """
        empty_ids = []
        for id in self._seg_ids:
            if id not in self._sinfo:
                empty_ids.append(id)
        if len(empty_ids) == 0:
            return np.zeros(self._seg.shape, dtype=np.uint8)
        assert (
            len(empty_ids) == 1
        ), ">1 ids correspond to no labels. This is currently not supported"
        return (self._seg != empty_ids[0]).numpy().astype(bool)

    def semantic_masks(self):
        for sid in self._seg_ids:
            sinfo = self._sinfo.get(sid)
            if sinfo is None or sinfo["isthing"]:
                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
                continue
            yield (self._seg == sid).numpy().astype(bool), sinfo

    def instance_masks(self):
        for sid in self._seg_ids:
            sinfo = self._sinfo.get(sid)
            if sinfo is None or not sinfo["isthing"]:
                continue
            mask = (self._seg == sid).numpy().astype(bool)
            if mask.sum() > 0:
                yield mask, sinfo


def _create_text_labels(classes, scores, class_names, is_crowd=None):
    """
    Args:
        classes (list[int] or None):
        scores (list[float] or None):
        class_names (list[str] or None):
        is_crowd (list[bool] or None):

    Returns:
        list[str] or None
    """
    # class_names = CLASS_NAMES
    labels = None
    if classes is not None:
        if class_names is not None and len(class_names) > 0:
            labels = [class_names[i] for i in classes]
        else:
            labels = [str(i) for i in classes]
    if scores is not None:
        if labels is None:
            labels = ["{:.0f}%".format(s * 100) for s in scores]
        else:
            labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
    if labels is not None and is_crowd is not None:
        labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
    return labels
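

# A small sketch of the label strings this helper produces (values are illustrative):
#
#     _create_text_labels([0, 1], [0.9, 0.5], ["cat", "dog"])
#     # -> ["cat 90%", "dog 50%"]
#     _create_text_labels([0], None, None, is_crowd=[True])
#     # -> ["0|crowd"]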


class VisImage:
    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): scale the input image
        """
        self.img = img
        self.scale = scale
        self.width, self.height = img.shape[1], img.shape[0]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Args:
            Same as in :meth:`__init__()`.

        Returns:
            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
        """
        fig = mplfigure.Figure(frameon=False)
        self.dpi = fig.get_dpi()
        # add a small 1e-2 to avoid precision loss due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        fig.set_size_inches(
            (self.width * self.scale + 1e-2) / self.dpi,
            (self.height * self.scale + 1e-2) / self.dpi,
        )
        self.canvas = FigureCanvasAgg(fig)
        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")
        self.fig = fig
        self.ax = ax
        self.reset_image(img)

    def reset_image(self, img):
        """
        Args:
            img: same as in __init__
        """
        img = img.astype("uint8")
        self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")

    def save(self, filepath):
        """
        Args:
            filepath (str): a string that contains the absolute path, including the file name, where
                the visualized image will be saved.
        """
        self.fig.savefig(filepath)

    def get_image(self):
        """
        Returns:
            ndarray:
                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The shape is scaled w.r.t the input image using the given `scale` argument.
        """
        canvas = self.canvas
        s, (width, height) = canvas.print_to_buffer()
        # buf = io.BytesIO()  # works for cairo backend
        # canvas.print_rgba(buf)
        # width, height = self.width, self.height
        # s = buf.getvalue()
        buffer = np.frombuffer(s, dtype="uint8")
        img_rgba = buffer.reshape(height, width, 4)
        rgb, alpha = np.split(img_rgba, [3], axis=2)
        return rgb.astype("uint8")
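

# A minimal sketch of using VisImage on its own (the output path is only illustrative):
#
#     canvas = VisImage(np.zeros((480, 640, 3), dtype=np.uint8), scale=1.0)
#     canvas.save("/tmp/blank.png")   # rasterize the figure and write it to disk
#     arr = canvas.get_image()        # or fetch the rendered (H, W, 3) uint8 array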


class Visualizer:
    """
    Visualizer that draws data about detection/segmentation on images.

    It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
    that draw primitive objects to images, as well as high-level wrappers like
    `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
    that draw composite data in some pre-defined style.

    Note that the exact visualization style for the high-level wrappers is subject to change.
    Style such as color, opacity, label contents, visibility of labels, or even the visibility
    of objects themselves (e.g. when the object is too small) may change according
    to different heuristics, as long as the results still look visually reasonable.

    To obtain a consistent style, you can implement custom drawing functions with the
    abovementioned primitive methods instead. If you need more customized visualization
    styles, you can process the data yourself following their format documented in
    tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
    intend to satisfy everyone's preference on drawing styles.

    This visualizer focuses on high rendering quality rather than performance. It is not
    designed to be used for real-time applications.
    """

    # TODO implement a fast, rasterized version using OpenCV

    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
        """
        Args:
            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            metadata (Metadata): dataset metadata (e.g. class names and colors)
            instance_mode (ColorMode): defines one of the pre-defined style for drawing
                instances on an image.
        """
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        if metadata is None:
            metadata = MetadataCatalog.get("__nonexist__")
        self.metadata = metadata
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # too small texts are useless, therefore clamp to a minimum of 10 // scale
        self._default_font_size = max(
            np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
        )
        self._instance_mode = instance_mode
        self.keypoint_threshold = _KEYPOINT_THRESHOLD

    def draw_instance_predictions(self, predictions):
        """
        Draw instance-level prediction results on an image.

        Args:
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").

        Returns:
            output (VisImage): image object with visualizations.
        """
        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None

        if predictions.has("pred_masks"):
            masks = np.asarray(predictions.pred_masks)
            masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
        else:
            masks = None

        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
            ]
            alpha = 0.8
        else:
            colors = None
            alpha = 0.5

        if self._instance_mode == ColorMode.IMAGE_BW:
            self.output.reset_image(
                self._create_grayscale_image(
                    (predictions.pred_masks.any(dim=0) > 0).numpy()
                    if predictions.has("pred_masks")
                    else None
                )
            )
            alpha = 0.3

        self.overlay_instances(
            masks=masks,
            boxes=boxes,
            labels=labels,
            keypoints=keypoints,
            assigned_colors=colors,
            alpha=alpha,
        )
        return self.output

    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
        """
        Draw semantic segmentation predictions/labels.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel.
            area_threshold (int): segments with area less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            sem_seg = sem_seg.numpy()
        labels, areas = np.unique(sem_seg, return_counts=True)
        sorted_idxs = np.argsort(-areas).tolist()
        labels = labels[sorted_idxs]
        for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                mask_color = None

            binary_mask = (sem_seg == label).astype(np.uint8)
            text = self.metadata.stuff_classes[label]
            self.draw_binary_mask(
                binary_mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output

    def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
        """
        Draw panoptic prediction annotations or results.

        Args:
            panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
                segment.
            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
                If it is a ``list[dict]``, each dict contains keys "id", "category_id".
                If None, category id of each pixel is computed by
                ``pixel // metadata.label_divisor``.
            area_threshold (int): stuff segments with area less than `area_threshold` are not drawn.

        Returns:
            output (VisImage): image object with visualizations.
        """
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                mask_color = None

            text = self.metadata.stuff_classes[category_idx]
            self.draw_binary_mask(
                mask,
                color=mask_color,
                edge_color=_OFF_WHITE,
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )

        # draw mask for all instances second
        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return self.output
        masks, sinfo = list(zip(*all_instances))
        category_ids = [x["category_id"] for x in sinfo]

        try:
            scores = [x["score"] for x in sinfo]
        except KeyError:
            scores = None
        labels = _create_text_labels(
            category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
        )

        try:
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
            ]
        except AttributeError:
            colors = None
        self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)

        return self.output

    draw_panoptic_seg_predictions = draw_panoptic_seg  # backward compatibility

    def draw_dataset_dict(self, dic):
        """
        Draw annotations/segmentations in Detectron2 Dataset format.

        Args:
            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.

        Returns:
            output (VisImage): image object with visualizations.
        """
        annos = dic.get("annotations", None)
        if annos:
            if "segmentation" in annos[0]:
                masks = [x["segmentation"] for x in annos]
            else:
                masks = None
            if "keypoints" in annos[0]:
                keypts = [x["keypoints"] for x in annos]
                keypts = np.array(keypts).reshape(len(annos), -1, 3)
            else:
                keypts = None

            boxes = [
                BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
                if len(x["bbox"]) == 4
                else x["bbox"]
                for x in annos
            ]

            colors = None
            category_ids = [x["category_id"] for x in annos]
            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
                colors = [
                    self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                    for c in category_ids
                ]
            names = self.metadata.get("thing_classes", None)
            labels = _create_text_labels(
                category_ids,
                scores=None,
                class_names=names,
                is_crowd=[x.get("iscrowd", 0) for x in annos],
            )
            self.overlay_instances(
                labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
            )

        sem_seg = dic.get("sem_seg", None)
        if sem_seg is None and "sem_seg_file_name" in dic:
            with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
                sem_seg = Image.open(f)
                sem_seg = np.asarray(sem_seg, dtype="uint8")
        if sem_seg is not None:
            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)

        pan_seg = dic.get("pan_seg", None)
        if pan_seg is None and "pan_seg_file_name" in dic:
            with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
                pan_seg = Image.open(f)
                pan_seg = np.asarray(pan_seg)
                from panopticapi.utils import rgb2id

                pan_seg = rgb2id(pan_seg)
        if pan_seg is not None:
            segments_info = dic["segments_info"]
            pan_seg = torch.tensor(pan_seg)
            self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
        return self.output

    def overlay_instances(
        self,
        *,
        boxes=None,
        labels=None,
        masks=None,
        keypoints=None,
        assigned_colors=None,
        alpha=0.5,
    ):
        """
        Args:
            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
                or a :class:`RotatedBoxes`,
                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
                for the N objects in a single image,
            labels (list[str]): the text to be displayed for each instance.
            masks (masks-like object): Supported types are:

                * :class:`detectron2.structures.PolygonMasks`,
                  :class:`detectron2.structures.BitMasks`.
                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
                  The first level of the list corresponds to individual instances. The second
                  level to all the polygons that compose the instance, and the third level
                  to the polygon coordinates. The third level should have the format of
                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
                * list[ndarray]: each ndarray is a binary mask of shape (H, W).
                * list[dict]: each dict is a COCO-style RLE.
            keypoints (Keypoints or array-like): an array-like object of shape (N, K, 3),
                where N is the number of instances and K is the number of keypoints.
                The last dimension corresponds to (x, y, visibility or score).
            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
                for full list of formats that the colors are accepted in.

        Returns:
            output (VisImage): image object with visualizations.
        """
        num_instances = 0
        if boxes is not None:
            boxes = self._convert_boxes(boxes)
            num_instances = len(boxes)
        if masks is not None:
            masks = self._convert_masks(masks)
            if num_instances:
                assert len(masks) == num_instances
            else:
                num_instances = len(masks)
        if keypoints is not None:
            if num_instances:
                assert len(keypoints) == num_instances
            else:
                num_instances = len(keypoints)
            keypoints = self._convert_keypoints(keypoints)
        if labels is not None:
            assert len(labels) == num_instances
        if assigned_colors is None:
            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
        if num_instances == 0:
            return self.output
        if boxes is not None and boxes.shape[1] == 5:
            return self.overlay_rotated_instances(
                boxes=boxes, labels=labels, assigned_colors=assigned_colors
            )

        # Display in largest to smallest order to reduce occlusion.
        areas = None
        if boxes is not None:
            areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
        elif masks is not None:
            areas = np.asarray([x.area() for x in masks])

        if areas is not None:
            sorted_idxs = np.argsort(-areas).tolist()
            # Re-order overlapped instances in descending order.
            boxes = boxes[sorted_idxs] if boxes is not None else None
            labels = [labels[k] for k in sorted_idxs] if labels is not None else None
            masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
            assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
            keypoints = keypoints[sorted_idxs] if keypoints is not None else None

        for i in range(num_instances):
            color = assigned_colors[i]
            if boxes is not None:
                self.draw_box(boxes[i], edge_color=color)

            if masks is not None:
                for segment in masks[i].polygons:
                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)

            if labels is not None:
                # first get a box
                if boxes is not None:
                    x0, y0, x1, y1 = boxes[i]
                    text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.
                    horiz_align = "left"
                elif masks is not None:
                    # skip small masks without polygons
                    if len(masks[i].polygons) == 0:
                        continue

                    x0, y0, x1, y1 = masks[i].bbox()

                    # draw text in the center (defined by median) when box is not drawn
                    # median is less sensitive to outliers.
                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
                    horiz_align = "center"
                else:
                    continue  # drawing the box confidence for keypoints isn't very useful.
                # for small objects, draw text at the side to avoid occlusion
                instance_area = (y1 - y0) * (x1 - x0)
                if (
                    instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
                    or y1 - y0 < 40 * self.output.scale
                ):
                    if y1 >= self.output.height - 5:
                        text_pos = (x1, y0)
                    else:
                        text_pos = (x0, y1)

                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
                font_size = (
                    np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
                    * 0.5
                    * self._default_font_size
                )
                self.draw_text(
                    labels[i],
                    text_pos,
                    color=lighter_color,
                    horizontal_alignment=horiz_align,
                    font_size=font_size,
                )

        # draw keypoints
        if keypoints is not None:
            for keypoints_per_instance in keypoints:
                self.draw_and_connect_keypoints(keypoints_per_instance)
        return self.output
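
    # A minimal sketch of calling overlay_instances directly with plain numpy inputs
    # (the coordinates and label below are made up for illustration):
    #
    #     v = Visualizer(img)  # img: (H, W, 3) uint8 RGB array
    #     v.overlay_instances(
    #         boxes=np.array([[10, 10, 100, 80]]),  # one XYXY_ABS box
    #         labels=["example 97%"],
    #     )
    #     v.get_output().save("/tmp/overlay.png")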

    def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
        """
        Args:
            boxes (ndarray): an Nx5 numpy array of
                (x_center, y_center, width, height, angle_degrees) format
                for the N objects in a single image.
            labels (list[str]): the text to be displayed for each instance.
            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
                for full list of formats that the colors are accepted in.

        Returns:
            output (VisImage): image object with visualizations.
        """
        num_instances = len(boxes)

        if assigned_colors is None:
            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
        if num_instances == 0:
            return self.output

        # Display in largest to smallest order to reduce occlusion.
        if boxes is not None:
            areas = boxes[:, 2] * boxes[:, 3]

        sorted_idxs = np.argsort(-areas).tolist()
        # Re-order overlapped instances in descending order.
        boxes = boxes[sorted_idxs]
        labels = [labels[k] for k in sorted_idxs] if labels is not None else None
        colors = [assigned_colors[idx] for idx in sorted_idxs]

        for i in range(num_instances):
            self.draw_rotated_box_with_label(
                boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
            )

        return self.output

    def draw_and_connect_keypoints(self, keypoints):
        """
        Draws keypoints of an instance and follows the rules for keypoint connections
        to draw lines between appropriate keypoints. This follows color heuristics for
        line color.

        Args:
            keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
                and the last dimension corresponds to (x, y, probability).

        Returns:
            output (VisImage): image object with visualizations.
        """
        visible = {}
        keypoint_names = self.metadata.get("keypoint_names")
        for idx, keypoint in enumerate(keypoints):

            # draw keypoint
            x, y, prob = keypoint
            if prob > self.keypoint_threshold:
                self.draw_circle((x, y), color=_RED)
                if keypoint_names:
                    keypoint_name = keypoint_names[idx]
                    visible[keypoint_name] = (x, y)

        if self.metadata.get("keypoint_connection_rules"):
            for kp0, kp1, color in self.metadata.keypoint_connection_rules:
                if kp0 in visible and kp1 in visible:
                    x0, y0 = visible[kp0]
                    x1, y1 = visible[kp1]
                    color = tuple(x / 255.0 for x in color)
                    self.draw_line([x0, x1], [y0, y1], color=color)

        # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
        # Note that this strategy is specific to person keypoints.
        # For other keypoints, it should just do nothing
        try:
            ls_x, ls_y = visible["left_shoulder"]
            rs_x, rs_y = visible["right_shoulder"]
            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
        except KeyError:
            pass
        else:
            # draw line from nose to mid-shoulder
            nose_x, nose_y = visible.get("nose", (None, None))
            if nose_x is not None:
                self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)

            try:
                # draw line from mid-shoulder to mid-hip
                lh_x, lh_y = visible["left_hip"]
                rh_x, rh_y = visible["right_hip"]
            except KeyError:
                pass
            else:
                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
                self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
        return self.output

    """
    Primitive drawing functions:
    """

    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Args:
            text (str): class label
            position (tuple): a tuple of the x and y coordinates to place text on image.
            font_size (int, optional): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color: color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`
            rotation: rotation angle in degrees CCW

        Returns:
            output (VisImage): image object with text drawn.
        """
        if not font_size:
            font_size = self._default_font_size

        # since the text background is dark, we don't want the text to be dark
        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        color[np.argmax(color)] = max(0.8, np.max(color))

        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
            rotation=rotation,
        )
        return self.output

    def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
        """
        Args:
            box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
                are the coordinates of the box's top left corner. x1 and y1 are the
                coordinates of the box's bottom right corner.
            alpha (float): blending coefficient. Smaller values lead to more transparent masks.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.

        Returns:
            output (VisImage): image object with box drawn.
        """
        x0, y0, x1, y1 = box_coord
        width = x1 - x0
        height = y1 - y0

        linewidth = max(self._default_font_size / 4, 1)

        self.output.ax.add_patch(
            mpl.patches.Rectangle(
                (x0, y0),
                width,
                height,
                fill=False,
                edgecolor=edge_color,
                linewidth=linewidth * self.output.scale,
                alpha=alpha,
                linestyle=line_style,
            )
        )
        return self.output

    def draw_rotated_box_with_label(
        self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
    ):
        """
        Draw a rotated box with label on its top-left corner.

        Args:
            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
                where cnt_x and cnt_y are the center coordinates of the box.
                w and h are the width and height of the box. angle represents how
                many degrees the box is rotated CCW with regard to the 0-degree box.
            alpha (float): blending coefficient. Smaller values lead to more transparent masks.
            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
                for full list of formats that are accepted.
            line_style (string): the string to use to create the outline of the boxes.
            label (string): label for rotated box. It will not be rendered when set to None.

        Returns:
            output (VisImage): image object with box drawn.
        """
        cnt_x, cnt_y, w, h, angle = rotated_box
        area = w * h
        # use thinner lines when the box is small
        linewidth = self._default_font_size / (
            6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
        )

        theta = angle * math.pi / 180.0
        c = math.cos(theta)
        s = math.sin(theta)
        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
        # x: left->right ; y: top->down
        rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
        for k in range(4):
            j = (k + 1) % 4
            self.draw_line(
                [rotated_rect[k][0], rotated_rect[j][0]],
                [rotated_rect[k][1], rotated_rect[j][1]],
                color=edge_color,
                linestyle="--" if k == 1 else line_style,
                linewidth=linewidth,
            )

        if label is not None:
            text_pos = rotated_rect[1]  # top-left corner
            height_ratio = h / np.sqrt(self.output.height * self.output.width)
            label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
            font_size = (
                np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
            )
            self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)

        return self.output

    def draw_circle(self, circle_coord, color, radius=3):
        """
        Args:
            circle_coord (list(int) or tuple(int)): contains the x and y coordinates
                of the center of the circle.
            color: color of the circle. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            radius (int): radius of the circle.

        Returns:
            output (VisImage): image object with circle drawn.
        """
        x, y = circle_coord
        self.output.ax.add_patch(
            mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
        )
        return self.output

    def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
        """
        Args:
            x_data (list[int]): a list containing x values of all the points being drawn.
                Length of list should match the length of y_data.
            y_data (list[int]): a list containing y values of all the points being drawn.
                Length of list should match the length of x_data.
            color: color of the line. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
                for a full list of formats that are accepted.
            linewidth (float or None): width of the line. When it's None,
                a default value will be computed and used.

        Returns:
            output (VisImage): image object with line drawn.
        """
        if linewidth is None:
            linewidth = self._default_font_size / 3
        linewidth = max(linewidth, 1)
        self.output.ax.add_line(
            mpl.lines.Line2D(
                x_data,
                y_data,
                linewidth=linewidth * self.output.scale,
                color=color,
                linestyle=linestyle,
            )
        )
        return self.output

    def draw_binary_mask(
        self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0
    ):
        """
        Args:
            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
                W is the image width. Each value in the array is either a 0 or 1 value of uint8
                type.
            color: color of the mask. Refer to `matplotlib.colors` for a full list of
                formats that are accepted. If None, will pick a random color.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted.
            text (str): if not None, will be drawn at the object's center of mass.
            alpha (float): blending coefficient. Smaller values lead to more transparent masks.
            area_threshold (float): a connected component smaller than this area will not be shown.

        Returns:
            output (VisImage): image object with mask drawn.
        """
        if color is None:
            color = random_color(rgb=True, maximum=1)
        color = mplc.to_rgb(color)

        has_valid_segment = False
        binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
        mask = GenericMask(binary_mask, self.output.height, self.output.width)
        shape2d = (binary_mask.shape[0], binary_mask.shape[1])

        if not mask.has_holes:
            # draw polygons for regular masks
            for segment in mask.polygons:
                area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
                if area < (area_threshold or 0):
                    continue
                has_valid_segment = True
                segment = segment.reshape(-1, 2)
                self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
        else:
            # TODO: Use Path/PathPatch to draw vector graphics:
            # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
            rgba = np.zeros(shape2d + (4,), dtype="float32")
            rgba[:, :, :3] = color
            rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
            has_valid_segment = True
            self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))

        if text is not None and has_valid_segment:
            # TODO sometimes drawn on wrong objects. the heuristics here can improve.
            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
            _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
            largest_component_id = np.argmax(stats[1:, -1]) + 1

            # draw text on the largest component, as well as other very large components.
            for cid in range(1, _num_cc):
                if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
                    # median is more stable than centroid
                    # center = centroids[largest_component_id]
                    center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
                    self.draw_text(text, center, color=lighter_color)
        return self.output

    def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
        """
        Args:
            segment: numpy array of shape Nx2, containing all the points in the polygon.
            color: color of the polygon. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
                full list of formats that are accepted. If not provided, a darker shade
                of the polygon color will be used instead.
            alpha (float): blending coefficient. Smaller values lead to more transparent masks.

        Returns:
            output (VisImage): image object with polygon drawn.
        """
        if edge_color is None:
            # make edge color darker than the polygon color
            if alpha > 0.8:
                edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
            else:
                edge_color = color
        edge_color = mplc.to_rgb(edge_color) + (1,)

        polygon = mpl.patches.Polygon(
            segment,
            fill=True,
            facecolor=mplc.to_rgb(color) + (alpha,),
            edgecolor=edge_color,
            linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
        )
        self.output.ax.add_patch(polygon)
        return self.output
  992. """
  993. Internal methods:
  994. """
  995. def _jitter(self, color):
  996. """
  997. Randomly modifies given color to produce a slightly different color than the color given.
  998. Args:
  999. color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
  1000. picked. The values in the list are in the [0.0, 1.0] range.
  1001. Returns:
  1002. jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
  1003. color after being jittered. The values in the list are in the [0.0, 1.0] range.
  1004. """
  1005. color = mplc.to_rgb(color)
  1006. vec = np.random.rand(3)
  1007. # better to do it in another color space
  1008. vec = vec / np.linalg.norm(vec) * 0.5
  1009. res = np.clip(vec + color, 0, 1)
  1010. return tuple(res)

    def _create_grayscale_image(self, mask=None):
        """
        Create a grayscale version of the original image.
        The colors in the masked area, if given, will be kept.
        """
        img_bw = self.img.astype("f4").mean(axis=2)
        img_bw = np.stack([img_bw] * 3, axis=2)
        if mask is not None:
            img_bw[mask] = self.img[mask]
        return img_bw

    def _change_color_brightness(self, color, brightness_factor):
        """
        Depending on the brightness_factor, gives a lighter or darker color, i.e. a color with
        higher or lower lightness than the original color.

        Args:
            color: color of the polygon. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
                0 will correspond to no change, a factor in [-1.0, 0) range will result in
                a darker color and a factor in (0, 1.0] range will result in a lighter color.

        Returns:
            modified_color (tuple[double]): a tuple containing the RGB values of the
                modified color. Each value in the tuple is in the [0.0, 1.0] range.
        """
        assert brightness_factor >= -1.0 and brightness_factor <= 1.0
        color = mplc.to_rgb(color)
        polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
        modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
        modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
        modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
        modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
        return modified_color

    def _convert_boxes(self, boxes):
        """
        Convert different formats of boxes to an NxB array, where B = 4 or 5 is the box dimension.
        """
        if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
            return boxes.tensor.detach().numpy()
        else:
            return np.asarray(boxes)

    def _convert_masks(self, masks_or_polygons):
        """
        Convert different formats of masks or polygons to a list of :class:`GenericMask`.

        Returns:
            list[GenericMask]:
        """
        m = masks_or_polygons
        if isinstance(m, PolygonMasks):
            m = m.polygons
        if isinstance(m, BitMasks):
            m = m.tensor.numpy()
        if isinstance(m, torch.Tensor):
            m = m.numpy()
        ret = []
        for x in m:
            if isinstance(x, GenericMask):
                ret.append(x)
            else:
                ret.append(GenericMask(x, self.output.height, self.output.width))
        return ret

    def _convert_keypoints(self, keypoints):
        if isinstance(keypoints, Keypoints):
            keypoints = keypoints.tensor
        keypoints = np.asarray(keypoints)
        return keypoints

    def get_output(self):
        """
        Returns:
            output (VisImage): the image output containing the visualizations added
                to the image.
        """
        return self.output
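

# A minimal usage sketch, guarded so that importing this module stays side-effect free.
# The blank demo image and output file name below are only illustrative assumptions.
if __name__ == "__main__":
    # Draw a single labeled box on a blank image and save the result to disk.
    demo_img = np.zeros((480, 640, 3), dtype=np.uint8)
    vis = Visualizer(demo_img, metadata=MetadataCatalog.get("__nonexist__"), scale=1.0)
    vis.overlay_instances(boxes=np.array([[50, 60, 300, 240]]), labels=["example"])
    vis.get_output().save("visualizer_demo.png")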