瀏覽代碼

feat(3d_bev_detection):Add visualization for 3d bev detection pipeline. (#3291)

* modify 3d bev detection pipeline docs : add pipeline output description.

* bugfix(3d_bev_detection):Avoid compiling ops when paddlex is initialized.

* bugfix(3d_bev_detection):Avoid compiling ops when paddlex is initialized.

* feat(3d_bev_detection):Add visualization for 3d bev detection pipeline.

* feat(3d_bev_detection):Add visualization for 3d bev detection pipeline.

* feat(3d_bev_detection):Add open3d package to requirements.txt.

* feat(3d_bev_detection):rm open3d package to requirements.txt.

* feat(3d_bev_detection):Add visualization for 3d bev detection pipeline.

* feat(3d_bev_detection):Add visualization for 3d bev detection pipeline.

* feat(3d_bev_detection):Add visualization for 3d bev detection pipeline.

---------

Co-authored-by: cuicheng01 <45199522+cuicheng01@users.noreply.github.com>
Jonathans575 9 月之前
父節點
當前提交
7c75d82d85

+ 9 - 0
docs/module_usage/tutorials/cv_modules/3d_bev_detection.md

@@ -48,8 +48,17 @@ output = pipeline.predict("nuscenes_demo_infer.tar")
 for res in output:
     res.print()  ## 打印预测的结构化输出
     res.save_to_json("./output/")  ## 保存结果到json文件
+    res.visualize(save_path="./output/", show=True) ## 3d结果可视化,如果运行环境有图形界面设置show=True,否则设置为False
 ```
 
+<b>注:</b> 如果运行环境没有图形界面,则无法可视化,但不影响结果的保存,可以在支持图形界面的环境下运行脚本,对保存的结果进行可视化:
+```bash
+python paddlex/inference/models/3d_bev_detection/visualizer_3d.py --save_path="./output/"
+```
+
+<img src="https://raw.githubusercontent.com/cuicheng01/PaddleX_doc_images/refs/heads/main/images/images/pipelines/3d_bev_detection/02.png">
+
+
 运行后,得到的结果为:
 ```bash
 {"res":

+ 9 - 0
docs/pipeline_usage/tutorials/cv_pipelines/3d_bev_detection.en.md

@@ -139,8 +139,16 @@ output = pipeline.predict("nuscenes_demo_infer.tar")
 for res in output:
     res.print()  ## Print the structured output of the prediction
     res.save_to_json("./output/")  ## Save the results to a json file
+    res.visualize(save_path="./output/", show=True) ## 3D result visualization. If the runtime environment has a graphical interface, set `show=True`; otherwise, set it to `False`.
 ```
 
+<b>Note: </b> If the runtime environment does not have a graphical interface, visualization will not be possible, but the results will still be saved. You can run the script in an environment that supports a graphical interface to visualize the saved results:
+```bash
+python paddlex/inference/models/3d_bev_detection/visualizer_3d.py --save_path="./output/"
+```
+
+<img src="https://raw.githubusercontent.com/cuicheng01/PaddleX_doc_images/refs/heads/main/images/images/pipelines/3d_bev_detection/02.png">
+
 In the above Python script, the following steps are executed:
 
 (1)Call `create_pipeline` to instantiate the 3D multi-modal fusion detection pipeline object. Specific parameter descriptions are as follows:
@@ -243,6 +251,7 @@ output = pipeline.predict("nuscenes_demo_infer.tar")
 for res in output:
     res.print()  ## Print the structured output of the prediction
     res.save_to_json("./output/")  ## Save the results to a json file
+    res.visualize(save_path="./output/", show=True) ## 3D result visualization. If the runtime environment has a graphical interface, set `show=True`; otherwise, set it to `False`.
 ```
 
 <b>Note: </b>The parameters in the configuration file are pipeline initialization parameters. If you want to change the 3D multi-modal fusion detection pipeline initialization parameters, you can directly modify the parameters in the configuration file and load the configuration file for prediction. At the same time, CLI prediction also supports passing in a configuration file by specifying the path with `--pipeline`.

+ 10 - 0
docs/pipeline_usage/tutorials/cv_pipelines/3d_bev_detection.md

@@ -131,8 +131,17 @@ output = pipeline.predict("nuscenes_demo_infer.tar")
 for res in output:
     res.print()  ## 打印预测的结构化输出
     res.save_to_json("./output/")  ## 保存结果到json文件
+    res.visualize(save_path="./output/", show=True) ## 3d结果可视化,如果运行环境有图形界面设置show=True,否则设置为False
 ```
 
+<b>注:</b> 如果运行环境没有图形界面,则无法可视化,但不影响结果的保存,可以在支持图形界面的环境下运行脚本,对保存的结果进行可视化:
+```bash
+python paddlex/inference/models/3d_bev_detection/visualizer_3d.py --save_path="./output/"
+```
+
+<img src="https://raw.githubusercontent.com/cuicheng01/PaddleX_doc_images/refs/heads/main/images/images/pipelines/3d_bev_detection/02.png">
+
+
 在上述 Python 脚本中,执行了如下几个步骤:
 
 (1)调用 `create_pipeline` 实例化 3D多模态融合检测 产线对象:具体参数说明如下:
@@ -235,6 +244,7 @@ output = pipeline.predict("nuscenes_demo_infer.tar")
 for res in output:
     res.print()  ## 打印预测的结构化输出
     res.save_to_json("./output/")  ## 保存结果到json文件
+    res.visualize(save_path="./output/", show=True) ## 3d结果可视化,如果运行环境有图形界面设置show=True,否则设置为False
 ```
 
 <b>注:</b> 配置文件中的参数为产线初始化参数,如果希望更改3D多模态融合检测产线初始化参数,可以直接修改配置文件中的参数,并加载配置文件进行预测。同时,CLI 预测也支持传入配置文件,`--pipeline` 指定配置文件的路径即可。

+ 34 - 1
paddlex/inference/models/3d_bev_detection/result.py

@@ -13,7 +13,9 @@
 # limitations under the License.
 
 from ...common.result import BaseResult, StrMixin, JsonMixin
-
+import numpy as np
+import os
+from .visualizer_3d import Visualizer3D
 
 class BEV3DDetResult(BaseResult):
     """Base class for computer vision results."""
@@ -30,3 +32,34 @@ class BEV3DDetResult(BaseResult):
         """
 
         super().__init__(data)
+
+    def visualize(self, save_path: str, show: bool) -> None:
+        # input point cloud
+        assert 'input_path' in self.keys(), 'input_path is not found in the data'
+        points = np.fromfile(self['input_path'], dtype=np.float32)
+        points = points.reshape(-1, 5)
+        points = points[:, :4]
+
+        # detection result
+        result = dict()
+        assert 'boxes_3d' in self.keys(), 'boxes_3d is not found in the data'
+        result["bbox3d"] = self["boxes_3d"]
+        assert 'scores_3d' in self.keys(), 'scores_3d is not found in the data'
+        result["scores"] = self["scores_3d"]
+        assert 'labels_3d' in self.keys(), 'labels_3d is not found in the data'
+        result["labels"] = self["labels_3d"]
+
+        if save_path is not None:
+            # save result for local visualization
+            if not os.path.exists(save_path):
+                os.makedirs(save_path)
+            np.save(os.path.join(save_path, 'results.npy'), result)
+            np.save(os.path.join(save_path, 'points.npy'), points)
+        
+        if show:
+            # visualize
+            score_threshold = 0.25
+            vis = Visualizer3D()
+            vis.draw_results(points, result, score_threshold)
+
+        return

+ 118 - 0
paddlex/inference/models/3d_bev_detection/visualizer_3d.py

@@ -0,0 +1,118 @@
+import os
+import open3d
+import numpy as np
+import argparse
+
+class Visualizer3D:
+    def __init__(self):
+        self.vis = open3d.visualization.Visualizer() # initialize visualizer
+
+    def boxes_to_lines(self, box: np.ndarray) -> open3d.geometry.LineSet:
+        """
+           4-------- 6
+         /|         /|
+        5 -------- 3 .
+        | |        | |
+        . 7 -------- 1
+        |/         |/
+        2 -------- 0
+        """
+        center = box[0:3]
+        lwh = box[3:6]
+        angles = np.array([0, 0, box[6] + 1e-10])
+        rot = open3d.geometry.get_rotation_matrix_from_axis_angle(angles)
+        box3d = open3d.geometry.OrientedBoundingBox(center, rot, lwh)
+        return open3d.geometry.LineSet.create_from_oriented_bounding_box(box3d)
+
+    def draw_results(self, points: np.ndarray, result: dict, score_threshold: float) -> None:
+        scores = result["scores"]
+        bbox3d = result["bbox3d"]
+        label_preds = result["labels"]
+
+        num_bbox3d, bbox3d_dims = bbox3d.shape
+        result_boxes = []
+        for box_idx in range(num_bbox3d):
+            if scores[box_idx] < score_threshold:
+                continue
+            if bbox3d_dims == 9:
+                print(
+                    "Score: {} Label: {} Box(x_c, y_c, z_c, w, l, h, vec_x, vec_y, -rot): {} {} {} {} {} {} {} {} {}"
+                    .format(scores[box_idx], label_preds[box_idx],
+                            bbox3d[box_idx, 0], bbox3d[box_idx, 1],
+                            bbox3d[box_idx, 2], bbox3d[box_idx, 3],
+                            bbox3d[box_idx, 4], bbox3d[box_idx, 5],
+                            bbox3d[box_idx, 6], bbox3d[box_idx, 7],
+                            bbox3d[box_idx, 8]))
+            elif bbox3d_dims == 7:
+                print(
+                    "Score: {} Label: {} Box(x_c, y_c, z_c, w, l, h, -rot): {} {} {} {} {} {} {}"
+                    .format(scores[box_idx], label_preds[box_idx],
+                            bbox3d[box_idx, 0], bbox3d[box_idx, 1],
+                            bbox3d[box_idx, 2], bbox3d[box_idx, 3],
+                            bbox3d[box_idx, 4], bbox3d[box_idx, 5],
+                            bbox3d[box_idx, 6]))
+            # draw result
+            result_boxes.append([
+                bbox3d[box_idx, 0], bbox3d[box_idx, 1],
+                bbox3d[box_idx, 2], bbox3d[box_idx, 3],
+                bbox3d[box_idx, 4], bbox3d[box_idx, 5],
+                bbox3d[box_idx, -1]
+            ])
+
+        # config
+        self.vis.create_window()
+        self.vis.get_render_option().point_size = 1.0
+        self.vis.get_render_option().background_color = [0, 0, 0]
+        pc_color = [1, 1, 1]
+        num_points = len(points)
+        pc_colors = np.tile(pc_color, (num_points, 1))
+
+        # raw point cloud
+        pts = open3d.geometry.PointCloud()
+        pts.points = open3d.utility.Vector3dVector(points[:, :3])
+        pts.colors = open3d.utility.Vector3dVector(pc_colors)
+        self.vis.add_geometry(pts)
+
+        # result_boxes
+        obs_color = [1, 0, 0]
+        result_boxes = np.array(result_boxes)
+        for i in range(result_boxes.shape[0]):
+            lines = self.boxes_to_lines(result_boxes[i])
+            # show different colors for different classes
+            if label_preds[i] <= 4:
+                obs_color = [0, 1, 0] # 'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
+            elif (label_preds[i] <= 6):
+                obs_color = [0, 0, 1] # 'bicycle', 'motorcycle'
+            elif (label_preds[i] <= 7):
+                obs_color = [1, 0, 0] # 'pedestrian'
+            else:
+                obs_color = [1, 0, 1] # 'traffic_cone','barrier'
+            lines.paint_uniform_color(obs_color)
+            self.vis.add_geometry(lines)
+
+        self.vis.run()
+        self.vis.poll_events()
+        self.vis.update_renderer()
+        # self.vis.capture_screen_image("result.png")
+        self.vis.destroy_window()
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='Visualizer 3d')
+    parser.add_argument(
+        '--save_path',
+        type=str,
+        default=None)
+    
+    args = parser.parse_args()
+    save_path = args.save_path
+    if save_path is None:
+        raise ValueError("Please specify the path to the saved results.")
+    
+    points = np.load(os.path.join(save_path, "points.npy"), allow_pickle=True)
+    result = np.load(os.path.join(save_path, "results.npy"), allow_pickle=True).item()
+    
+    score_threshold = 0.25
+    vis = Visualizer3D()
+    vis.draw_results(points, result, score_threshold)

+ 2 - 1
requirements.txt

@@ -2,7 +2,7 @@ prettytable # only for benchmark
 py-cpuinfo # LaTeX_OCR_rec only support MKLDNN on CPU
 imagesize
 colorlog
-PyYAML
+PyYAML==6.0.2
 filelock
 ftfy
 ruamel.yaml
@@ -14,6 +14,7 @@ albumentations==1.4.10
 opencv-python==4.5.5.64
 opencv-python-headless==4.10.0.84
 opencv-contrib-python==4.10.0.84
+open3d
 chinese_calendar
 scikit-learn
 pycocotools