# visualizer_3d.py — view a saved point cloud and its 3D detection results with open3d.
  1. import os
  2. import numpy as np
  3. import argparse
  4. import importlib.util
  5. import sys
  6. class LazyLoader:
  7. def __init__(self, module_name):
  8. self.module_name = module_name
  9. self._module = None
  10. def __getattr__(self, item):
  11. if self._module is None:
  12. self._module = importlib.import_module(self.module_name)
  13. return getattr(self._module, item)
  14. open3d = LazyLoader('open3d')
  15. class Visualizer3D:
  16. def __init__(self):
  17. self.vis = open3d.visualization.Visualizer() # initialize visualizer
  18. def boxes_to_lines(self, box: np.ndarray):
  19. """
  20. 4-------- 6
  21. /| /|
  22. 5 -------- 3 .
  23. | | | |
  24. . 7 -------- 1
  25. |/ |/
  26. 2 -------- 0
  27. """
  28. center = box[0:3]
  29. lwh = box[3:6]
  30. angles = np.array([0, 0, box[6] + 1e-10])
  31. rot = open3d.geometry.get_rotation_matrix_from_axis_angle(angles)
  32. box3d = open3d.geometry.OrientedBoundingBox(center, rot, lwh)
  33. return open3d.geometry.LineSet.create_from_oriented_bounding_box(box3d)
  34. def draw_results(self, points: np.ndarray, result: dict, score_threshold: float) -> None:
  35. scores = result["scores"]
  36. bbox3d = result["bbox3d"]
  37. label_preds = result["labels"]
  38. num_bbox3d, bbox3d_dims = bbox3d.shape
  39. result_boxes = []
  40. for box_idx in range(num_bbox3d):
  41. if scores[box_idx] < score_threshold:
  42. continue
  43. if bbox3d_dims == 9:
  44. print(
  45. "Score: {} Label: {} Box(x_c, y_c, z_c, w, l, h, vec_x, vec_y, -rot): {} {} {} {} {} {} {} {} {}"
  46. .format(scores[box_idx], label_preds[box_idx],
  47. bbox3d[box_idx, 0], bbox3d[box_idx, 1],
  48. bbox3d[box_idx, 2], bbox3d[box_idx, 3],
  49. bbox3d[box_idx, 4], bbox3d[box_idx, 5],
  50. bbox3d[box_idx, 6], bbox3d[box_idx, 7],
  51. bbox3d[box_idx, 8]))
  52. elif bbox3d_dims == 7:
  53. print(
  54. "Score: {} Label: {} Box(x_c, y_c, z_c, w, l, h, -rot): {} {} {} {} {} {} {}"
  55. .format(scores[box_idx], label_preds[box_idx],
  56. bbox3d[box_idx, 0], bbox3d[box_idx, 1],
  57. bbox3d[box_idx, 2], bbox3d[box_idx, 3],
  58. bbox3d[box_idx, 4], bbox3d[box_idx, 5],
  59. bbox3d[box_idx, 6]))
  60. # draw result
  61. result_boxes.append([
  62. bbox3d[box_idx, 0], bbox3d[box_idx, 1],
  63. bbox3d[box_idx, 2], bbox3d[box_idx, 3],
  64. bbox3d[box_idx, 4], bbox3d[box_idx, 5],
  65. bbox3d[box_idx, -1]
  66. ])
  67. # config
  68. self.vis.create_window()
  69. self.vis.get_render_option().point_size = 1.0
  70. self.vis.get_render_option().background_color = [0, 0, 0]
  71. pc_color = [1, 1, 1]
  72. num_points = len(points)
  73. pc_colors = np.tile(pc_color, (num_points, 1))
  74. # raw point cloud
  75. pts = open3d.geometry.PointCloud()
  76. pts.points = open3d.utility.Vector3dVector(points[:, :3])
  77. pts.colors = open3d.utility.Vector3dVector(pc_colors)
  78. self.vis.add_geometry(pts)
  79. # result_boxes
  80. obs_color = [1, 0, 0]
  81. result_boxes = np.array(result_boxes)
  82. for i in range(result_boxes.shape[0]):
  83. lines = self.boxes_to_lines(result_boxes[i])
  84. # show different colors for different classes
  85. if label_preds[i] <= 4:
  86. obs_color = [0, 1, 0] # 'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
  87. elif (label_preds[i] <= 6):
  88. obs_color = [0, 0, 1] # 'bicycle', 'motorcycle'
  89. elif (label_preds[i] <= 7):
  90. obs_color = [1, 0, 0] # 'pedestrian'
  91. else:
  92. obs_color = [1, 0, 1] # 'traffic_cone','barrier'
  93. lines.paint_uniform_color(obs_color)
  94. self.vis.add_geometry(lines)
  95. self.vis.run()
  96. self.vis.poll_events()
  97. self.vis.update_renderer()
  98. # self.vis.capture_screen_image("result.png")
  99. self.vis.destroy_window()
  100. if __name__ == "__main__":
  101. parser = argparse.ArgumentParser(description='Visualizer 3d')
  102. parser.add_argument(
  103. '--save_path',
  104. type=str,
  105. default=None)
  106. args = parser.parse_args()
  107. save_path = args.save_path
  108. if save_path is None:
  109. raise ValueError("Please specify the path to the saved results.")
  110. points = np.load(os.path.join(save_path, "points.npy"), allow_pickle=True)
  111. result = np.load(os.path.join(save_path, "results.npy"), allow_pickle=True).item()
  112. score_threshold = 0.25
  113. vis = Visualizer3D()
  114. vis.draw_results(points, result, score_threshold)