# visualizer_3d.py (4.5 KB)

  1. import os
  2. import open3d
  3. import numpy as np
  4. import argparse
  5. class Visualizer3D:
  6. def __init__(self):
  7. self.vis = open3d.visualization.Visualizer() # initialize visualizer
  8. def boxes_to_lines(self, box: np.ndarray) -> open3d.geometry.LineSet:
  9. """
  10. 4-------- 6
  11. /| /|
  12. 5 -------- 3 .
  13. | | | |
  14. . 7 -------- 1
  15. |/ |/
  16. 2 -------- 0
  17. """
  18. center = box[0:3]
  19. lwh = box[3:6]
  20. angles = np.array([0, 0, box[6] + 1e-10])
  21. rot = open3d.geometry.get_rotation_matrix_from_axis_angle(angles)
  22. box3d = open3d.geometry.OrientedBoundingBox(center, rot, lwh)
  23. return open3d.geometry.LineSet.create_from_oriented_bounding_box(box3d)
  24. def draw_results(self, points: np.ndarray, result: dict, score_threshold: float) -> None:
  25. scores = result["scores"]
  26. bbox3d = result["bbox3d"]
  27. label_preds = result["labels"]
  28. num_bbox3d, bbox3d_dims = bbox3d.shape
  29. result_boxes = []
  30. for box_idx in range(num_bbox3d):
  31. if scores[box_idx] < score_threshold:
  32. continue
  33. if bbox3d_dims == 9:
  34. print(
  35. "Score: {} Label: {} Box(x_c, y_c, z_c, w, l, h, vec_x, vec_y, -rot): {} {} {} {} {} {} {} {} {}"
  36. .format(scores[box_idx], label_preds[box_idx],
  37. bbox3d[box_idx, 0], bbox3d[box_idx, 1],
  38. bbox3d[box_idx, 2], bbox3d[box_idx, 3],
  39. bbox3d[box_idx, 4], bbox3d[box_idx, 5],
  40. bbox3d[box_idx, 6], bbox3d[box_idx, 7],
  41. bbox3d[box_idx, 8]))
  42. elif bbox3d_dims == 7:
  43. print(
  44. "Score: {} Label: {} Box(x_c, y_c, z_c, w, l, h, -rot): {} {} {} {} {} {} {}"
  45. .format(scores[box_idx], label_preds[box_idx],
  46. bbox3d[box_idx, 0], bbox3d[box_idx, 1],
  47. bbox3d[box_idx, 2], bbox3d[box_idx, 3],
  48. bbox3d[box_idx, 4], bbox3d[box_idx, 5],
  49. bbox3d[box_idx, 6]))
  50. # draw result
  51. result_boxes.append([
  52. bbox3d[box_idx, 0], bbox3d[box_idx, 1],
  53. bbox3d[box_idx, 2], bbox3d[box_idx, 3],
  54. bbox3d[box_idx, 4], bbox3d[box_idx, 5],
  55. bbox3d[box_idx, -1]
  56. ])
  57. # config
  58. self.vis.create_window()
  59. self.vis.get_render_option().point_size = 1.0
  60. self.vis.get_render_option().background_color = [0, 0, 0]
  61. pc_color = [1, 1, 1]
  62. num_points = len(points)
  63. pc_colors = np.tile(pc_color, (num_points, 1))
  64. # raw point cloud
  65. pts = open3d.geometry.PointCloud()
  66. pts.points = open3d.utility.Vector3dVector(points[:, :3])
  67. pts.colors = open3d.utility.Vector3dVector(pc_colors)
  68. self.vis.add_geometry(pts)
  69. # result_boxes
  70. obs_color = [1, 0, 0]
  71. result_boxes = np.array(result_boxes)
  72. for i in range(result_boxes.shape[0]):
  73. lines = self.boxes_to_lines(result_boxes[i])
  74. # show different colors for different classes
  75. if label_preds[i] <= 4:
  76. obs_color = [0, 1, 0] # 'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
  77. elif (label_preds[i] <= 6):
  78. obs_color = [0, 0, 1] # 'bicycle', 'motorcycle'
  79. elif (label_preds[i] <= 7):
  80. obs_color = [1, 0, 0] # 'pedestrian'
  81. else:
  82. obs_color = [1, 0, 1] # 'traffic_cone','barrier'
  83. lines.paint_uniform_color(obs_color)
  84. self.vis.add_geometry(lines)
  85. self.vis.run()
  86. self.vis.poll_events()
  87. self.vis.update_renderer()
  88. # self.vis.capture_screen_image("result.png")
  89. self.vis.destroy_window()
  90. if __name__ == "__main__":
  91. parser = argparse.ArgumentParser(description='Visualizer 3d')
  92. parser.add_argument(
  93. '--save_path',
  94. type=str,
  95. default=None)
  96. args = parser.parse_args()
  97. save_path = args.save_path
  98. if save_path is None:
  99. raise ValueError("Please specify the path to the saved results.")
  100. points = np.load(os.path.join(save_path, "points.npy"), allow_pickle=True)
  101. result = np.load(os.path.join(save_path, "results.npy"), allow_pickle=True).item()
  102. score_threshold = 0.25
  103. vis = Visualizer3D()
  104. vis.draw_results(points, result, score_threshold)