
add layout post params (#2878)

* add layout post params

* add layout post params

* add layout models

* Update predictor.py

* add layout models
Sunflower7788 10 months ago
commit 310b9acd6d

+ 43 - 0
api_examples/pipelines/test_object_detection.py

@@ -0,0 +1,43 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="object_detection")
+
+output = pipeline.predict(
+    "./test_samples/general_layout.png",
+    threshold={0: 0.45, 2: 0.48, 7: 0.4},
+    layout_nms=True,
+    layout_merge_bboxes_mode="large",
+    layout_unclip_ratio=(1.0, 1.0)
+)
+
+# output = pipeline.predict(
+#     "./test_samples/general_layout.png",
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/general_layout.png",
+#     threshold={0: 0.45, 2: 0.48, 7: 0.4},
+#     layout_nms=False,
+#     layout_merge_bboxes_mode="small", 
+#     layout_unclip_ratio=1.1
+# )
+
+for res in output:
+    print(res)
+    res.print()  ## print the structured prediction output
+    res.save_to_img("./output/")  ## save the visualized result image
+    res.save_to_json("./output/")  ## save the structured prediction output

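A note on the `threshold` argument above: when it is a dict, keys are integer category ids and values are per-category score thresholds; any category not listed falls back to 0.5 (see `DetPostProcess.apply` below). A minimal standalone sketch of the equivalent filtering, with illustrative boxes in the `(class_id, score, x1, y1, x2, y2)` layout used throughout this patch:

import numpy as np

boxes = np.array([
    [0, 0.50, 10, 10, 100, 40],    # kept:    0.50 > 0.45 (class 0)
    [2, 0.40, 10, 50, 100, 80],    # dropped: 0.40 < 0.48 (class 2)
    [7, 0.42, 10, 90, 100, 120],   # kept:    0.42 > 0.40 (class 7)
    [3, 0.49, 10, 130, 100, 160],  # dropped: class 3 falls back to 0.5
])
threshold = {0: 0.45, 2: 0.48, 7: 0.4}

keep = np.array([score > threshold.get(int(cls), 0.5) for cls, score in boxes[:, :2]])
print(boxes[keep][:, 0])  # -> [0. 7.]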
+ 40 - 0
paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-DocLayout-L
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/layout/det_layout_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 11
+  epochs_iters: 100
+  batch_size: 1
+  learning_rate: 0.0001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-DocLayout-L_pretrain.pdparams
+  warmup_steps: 100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-DocLayout-L_pretrain.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout.jpg"
+  kernel_option:
+    run_mode: paddle

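This module config also makes the model addressable by name through the single-model API. A minimal sketch, assuming the standard `create_model` entry point (`layout.jpg` is a placeholder path):

from paddlex import create_model

# Load by name; inference weights resolve via the PP-DocLayout-*
# entries added to official_models.py later in this commit.
model = create_model("PP-DocLayout-L")
for res in model.predict("layout.jpg", batch_size=1):
    res.print()
    res.save_to_img("./output/")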
+ 40 - 0
paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-DocLayout-M
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/layout/det_layout_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 11
+  epochs_iters: 100
+  batch_size: 1
+  learning_rate: 0.0001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-DocLayout-M_pretrain.pdparams
+  warmup_steps: 100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-DocLayout-M_pretrain.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout.jpg"
+  kernel_option:
+    run_mode: paddle

+ 40 - 0
paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-DocLayout-S
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/layout/det_layout_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 11
+  epochs_iters: 100
+  batch_size: 1
+  learning_rate: 0.0001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-DocLayout-S_pretrain.pdparams
+  warmup_steps: 100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-DocLayout-S_pretrain.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout.jpg"
+  kernel_option:
+    run_mode: paddle

+ 1 - 1
paddlex/inference/models/object_detection.py

@@ -43,7 +43,7 @@ class DetPredictor(BasicPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        model_names = ["DETR", "RCNN", "YOLOv3", "CenterNet"]
+        model_names = ["DETR", "RCNN", "YOLOv3", "CenterNet", "PP-DocLayout-L"]
         if any(name in self.model_name for name in model_names):
             predictor.set_inputs(
                 {

+ 56 - 7
paddlex/inference/models_new/object_detection/predictor.py

@@ -48,7 +48,10 @@ class DetPredictor(BasicPredictor):
         self,
         *args,
         img_size: Optional[Union[int, Tuple[int, int]]] = None,
-        threshold: Optional[float] = None,
+        threshold: Optional[Union[float, dict]] = None,
+        layout_nms: bool = False,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
         **kwargs,
     ):
         """Initializes DetPredictor.
@@ -57,6 +60,13 @@ class DetPredictor(BasicPredictor):
             img_size (Optional[Union[int, Tuple[int, int]]], optional): The input image size (w, h). Defaults to None.
-            threshold (Optional[float], optional): The threshold for filtering out low-confidence predictions.
+            threshold (Optional[Union[float, dict]], optional): The threshold for filtering out low-confidence predictions.
                 Defaults to None.
+            layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
+            layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
+                Defaults to None.
+                If it's a single number, then both width and height are used.
+                If it's a tuple of two numbers, then they are used separately for width and height respectively.
+                If it's None, then no unclipping will be performed.
+            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging overlapping bounding boxes,
+                one of "union", "large", or "small". Defaults to None.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
         super().__init__(*args, **kwargs)
@@ -73,8 +83,26 @@ class DetPredictor(BasicPredictor):
                 raise ValueError(
                     f"The type of `img_size` must be int or Tuple[int, int], but got {type(img_size)}."
                 )
+
+        if layout_unclip_ratio is not None:
+            if isinstance(layout_unclip_ratio, float):
+                layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
+            elif isinstance(layout_unclip_ratio, (tuple, list)):
+                assert len(layout_unclip_ratio) == 2, "The length of `layout_unclip_ratio` should be 2."
+            else:
+                raise ValueError(
+                    f"The type of `layout_unclip_ratio` must be float or Tuple[float, float], but got {type(layout_unclip_ratio)}."
+                )
+
+        if layout_merge_bboxes_mode is not None:
+            assert layout_merge_bboxes_mode in ["union", "large", "small"], \
+                f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
+
         self.img_size = img_size
         self.threshold = threshold
+        self.layout_nms = layout_nms
+        self.layout_unclip_ratio = layout_unclip_ratio
+        self.layout_merge_bboxes_mode = layout_merge_bboxes_mode
         self.pre_ops, self.infer, self.post_op = self._build()
 
     def _build_batch_sampler(self):
@@ -170,12 +198,22 @@ class DetPredictor(BasicPredictor):
         else:
             return [{"boxes": np.array(res)} for res in pred_box]
 
-    def process(self, batch_data: List[Any], threshold: Optional[float] = None):
+    def process(
+        self,
+        batch_data: List[Any],
+        threshold: Optional[Union[float, dict]] = None,
+        layout_nms: bool = False,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
+    ):
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
 
         Args:
             batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+            threshold (Optional[Union[float, dict]], optional): The threshold for filtering out low-confidence predictions.
+            layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
+            layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
+            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes, one of "union", "large", or "small". Defaults to None.
 
         Returns:
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
@@ -194,10 +232,14 @@ class DetPredictor(BasicPredictor):
 
         # process a batch of predictions into a list of single image result
         preds_list = self._format_output(batch_preds)
-
         # postprocess
         boxes = self.post_op(
-            preds_list, datas, threshold if threshold is not None else self.threshold
+            preds_list,
+            datas,
+            threshold=threshold if threshold is not None else self.threshold,
+            layout_nms=layout_nms or self.layout_nms,
+            layout_unclip_ratio=layout_unclip_ratio or self.layout_unclip_ratio,
+            layout_merge_bboxes_mode=layout_merge_bboxes_mode or self.layout_merge_bboxes_mode,
         )
 
         return {
@@ -268,6 +310,7 @@ class DetPredictor(BasicPredictor):
             "CenterNet",
             "BlazeFace",
             "BlazeFace-FPN-SSH",
+            "PP-DocLayout-L",
         ]
         if any(name in self.model_name for name in models_required_imgsize):
             ordered_required_keys = (
@@ -281,8 +324,14 @@ class DetPredictor(BasicPredictor):
         return ToBatch(ordered_required_keys=ordered_required_keys)
 
     def build_postprocess(self):
+        if self.threshold is None:
+            self.threshold = self.config.get("draw_threshold", 0.5)
+        if not self.layout_nms:
+            self.layout_nms = self.config.get("layout_nms", False)
+        if self.layout_unclip_ratio is None:
+            self.layout_unclip_ratio = self.config.get("layout_unclip_ratio", None)
+        if self.layout_merge_bboxes_mode is None:
+            self.layout_merge_bboxes_mode = self.config.get("layout_merge_bboxes_mode", None)
         return DetPostProcess(
-            threshold=self.config["draw_threshold"],
-            labels=self.config["label_list"],
-            layout_postprocess=self.config.get("layout_postprocess", False),
+            labels=self.config["label_list"]
         )

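The constructor normalizes `layout_unclip_ratio` so that downstream code always sees a `(width_ratio, height_ratio)` pair. A standalone sketch of that check (`normalize_unclip_ratio` is a hypothetical helper; the predictor inlines this logic):

from typing import Optional, Tuple, Union

def normalize_unclip_ratio(
    ratio: Optional[Union[float, Tuple[float, float]]]
) -> Optional[Tuple[float, float]]:
    # A float is broadcast to both axes; a 2-tuple/list is kept as-is.
    if ratio is None:
        return None
    if isinstance(ratio, float):
        return (ratio, ratio)
    if isinstance(ratio, (tuple, list)):
        assert len(ratio) == 2, "The length of `layout_unclip_ratio` should be 2."
        return tuple(ratio)
    raise ValueError(
        f"The type of `layout_unclip_ratio` must be float or Tuple[float, float], but got {type(ratio)}."
    )

assert normalize_unclip_ratio(1.1) == (1.1, 1.1)
assert normalize_unclip_ratio((1.0, 1.2)) == (1.0, 1.2)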
+ 161 - 154
paddlex/inference/models_new/object_detection/processors.py

@@ -424,52 +424,6 @@ class WarpAffine:
 
         return datas
 
-
-def compute_iou(box1: List[Number], box2: List[Number]) -> float:
-    """Compute the Intersection over Union (IoU) of two bounding boxes.
-
-    Args:
-        box1 (List[Number]): Coordinates of the first bounding box in format [x1, y1, x2, y2].
-        box2 (List[Number]): Coordinates of the second bounding box in format [x1, y1, x2, y2].
-
-    Returns:
-        float: The IoU of the two bounding boxes.
-    """
-    x1 = max(box1[0], box2[0])
-    y1 = max(box1[1], box2[1])
-    x2 = min(box1[2], box2[2])
-    y2 = min(box1[3], box2[3])
-    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
-    box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
-    box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
-    iou = inter_area / float(box1_area + box2_area - inter_area)
-    return iou
-
-
-def is_box_mostly_inside(
-    inner_box: List[Number], outer_box: List[Number], threshold: float = 0.9
-) -> bool:
-    """Determine if one bounding box is mostly inside another bounding box.
-
-    Args:
-        inner_box (List[Number]): Coordinates of the inner bounding box in format [x1, y1, x2, y2].
-        outer_box (List[Number]): Coordinates of the outer bounding box in format [x1, y1, x2, y2].
-        threshold (float): The threshold for determining if the inner box is mostly inside the outer box (default is 0.9).
-
-    Returns:
-        bool: True if the ratio of the intersection area to the inner box area is greater than or equal to the threshold, False otherwise.
-    """
-    x1 = max(inner_box[0], outer_box[0])
-    y1 = max(inner_box[1], outer_box[1])
-    x2 = min(inner_box[2], outer_box[2])
-    y2 = min(inner_box[3], outer_box[3])
-    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
-    inner_box_area = (inner_box[2] - inner_box[0] + 1) * (
-        inner_box[3] - inner_box[1] + 1
-    )
-    return (inter_area / inner_box_area) >= threshold
-
-
 def restructured_boxes(
     boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
 ) -> Boxes:
@@ -544,48 +498,125 @@ def restructured_rotated_boxes(
 
     return box_list
 
-
-def non_max_suppression(
-    boxes: ndarray, scores: ndarray, iou_threshold: float
-) -> List[int]:
+def unclip_boxes(boxes, unclip_ratio=None):
     """
-    Perform non-maximum suppression to remove redundant overlapping boxes with
-    lower scores. This function is commonly used in object detection tasks.
-
+    Expand the (x1, y1, x2, y2) coordinates of each bounding box by an unclip ratio.
+
     Parameters:
-    boxes (ndarray): An array of shape (N, 4) representing the bounding boxes.
-        Each row is in the format [x1, y1, x2, y2].
-    scores (ndarray): An array of shape (N,) containing the scores for each box.
-    iou_threshold (float): The Intersection over Union (IoU) threshold to use
-        for suppressing overlapping boxes.
-
+    - boxes: np.ndarray of shape (N, 6), where each row is (class_id, score, x1, y1, x2, y2).
+    - unclip_ratio: tuple of (width_ratio, height_ratio), optional.
+
     Returns:
-    List[int]: A list of indices representing the indices of the boxes to keep.
+    - expanded_boxes: np.ndarray of shape (N, 6), in the same (class_id, score, x1, y1, x2, y2) layout.
     """
-    if len(boxes) == 0:
-        return []
-    x1 = boxes[:, 0]
-    y1 = boxes[:, 1]
-    x2 = boxes[:, 2]
-    y2 = boxes[:, 3]
-    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
-    order = scores.argsort()[::-1]
-    keep = []
-    while order.size > 0:
-        i = order[0]
-        keep.append(i)
-        xx1 = np.maximum(x1[i], x1[order[1:]])
-        yy1 = np.maximum(y1[i], y1[order[1:]])
-        xx2 = np.minimum(x2[i], x2[order[1:]])
-        yy2 = np.minimum(y2[i], y2[order[1:]])
-
-        w = np.maximum(0.0, xx2 - xx1 + 1)
-        h = np.maximum(0.0, yy2 - yy1 + 1)
-        inter = w * h
-        iou = inter / (areas[i] + areas[order[1:]] - inter)
-        inds = np.where(iou <= iou_threshold)[0]
-        order = order[inds + 1]
-    return keep
+    if unclip_ratio is None:
+        return boxes
+    
+    widths = boxes[:, 4] - boxes[:, 2]
+    heights = boxes[:, 5] - boxes[:, 3]
+
+    new_w = widths * unclip_ratio[0]
+    new_h = heights * unclip_ratio[1]
+    center_x = boxes[:, 2] + widths / 2
+    center_y = boxes[:, 3] + heights / 2
+
+    new_x1 = center_x - new_w / 2
+    new_y1 = center_y - new_h / 2
+    new_x2 = center_x + new_w / 2
+    new_y2 = center_y + new_h / 2
+    expanded_boxes = np.column_stack((boxes[:, 0], boxes[:, 1], new_x1, new_y1, new_x2, new_y2))
+
+    return expanded_boxes
+
+
+def iou(box1, box2):
+    """Compute the Intersection over Union (IoU) of two bounding boxes."""
+    x1, y1, x2, y2 = box1
+    x1_p, y1_p, x2_p, y2_p = box2
+
+    # Compute the intersection coordinates
+    x1_i = max(x1, x1_p)
+    y1_i = max(y1, y1_p)
+    x2_i = min(x2, x2_p)
+    y2_i = min(y2, y2_p)
+
+    # Compute the area of intersection
+    inter_area = max(0, x2_i - x1_i + 1) * max(0, y2_i - y1_i + 1)
+
+    # Compute the area of both bounding boxes
+    box1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+    box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1)
+
+    # Compute the IoU
+    iou_value = inter_area / float(box1_area + box2_area - inter_area)
+
+    return iou_value
+
+def nms(boxes, iou_same=0.6, iou_diff=0.95):
+    """Perform Non-Maximum Suppression (NMS) with different IoU thresholds for same and different classes."""
+    # Extract class scores
+    scores = boxes[:, 1]
+
+    # Sort indices by scores in descending order
+    indices = np.argsort(scores)[::-1]
+    selected_boxes = []
+
+    while len(indices) > 0:
+        current = indices[0]
+        current_box = boxes[current]
+        current_class = current_box[0]
+        current_score = current_box[1]
+        current_coords = current_box[2:]
+
+        selected_boxes.append(current)
+        indices = indices[1:]
+
+        filtered_indices = []
+        for i in indices:
+            box = boxes[i]
+            box_class = box[0]
+            box_coords = box[2:]
+            iou_value = iou(current_coords, box_coords)
+            threshold = iou_same if current_class == box_class else iou_diff
+
+            # If the IoU is below the threshold, keep the box
+            if iou_value < threshold:
+                filtered_indices.append(i)
+        indices = filtered_indices
+    return selected_boxes
+
+def is_contained(box1, box2):
+    """Check if box1 is contained within box2."""
+    _, _, x1, y1, x2, y2 = box1
+    _, _, x1_p, y1_p, x2_p, y2_p = box2
+    box1_area = (x2 - x1) * (y2 - y1)
+    xi1 = max(x1, x1_p)
+    yi1 = max(y1, y1_p)
+    xi2 = min(x2, x2_p)
+    yi2 = min(y2, y2_p)
+    inter_width = max(0, xi2 - xi1)
+    inter_height = max(0, yi2 - yi1)
+    intersect_area = inter_width * inter_height
+    containment_ratio = intersect_area / box1_area if box1_area > 0 else 0
+    return containment_ratio >= 0.9
+
+def check_containment(boxes, formula_index=None):
+    """Check containment relationships among boxes."""
+    n = len(boxes)
+    contains_other = np.zeros(n, dtype=int)
+    contained_by_other = np.zeros(n, dtype=int)
+
+    for i in range(n):
+        for j in range(n):
+            if i == j:
+                continue
+            if formula_index is not None:
+                if boxes[i][0] == formula_index and boxes[j][0] != formula_index:
+                    continue
+            if is_contained(boxes[i], boxes[j]):
+                contained_by_other[i] = 1
+                contains_other[j] = 1
+    return contains_other, contained_by_other
 
 
 class DetPostProcess:
@@ -598,9 +629,7 @@ class DetPostProcess:
 
     def __init__(
         self,
-        threshold: float = 0.5,
-        labels: Optional[List[str]] = None,
-        layout_postprocess: bool = False,
+        labels: Optional[List[str]] = None
     ) -> None:
         """Initialize the DetPostProcess class.
 
@@ -610,11 +639,16 @@ class DetPostProcess:
-            layout_postprocess (bool, optional): Whether to apply layout post-processing. Defaults to False.
         """
         super().__init__()
-        self.threshold = threshold
         self.labels = labels
-        self.layout_postprocess = layout_postprocess
 
-    def apply(self, boxes: ndarray, img_size, threshold: Union[float, dict]) -> Boxes:
+    def apply(
+        self,
+        boxes: ndarray,
+        img_size: Tuple[int, int],
+        threshold: Union[float, dict],
+        layout_nms: bool,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]],
+        layout_merge_bboxes_mode: Optional[str],
+    ) -> Boxes:
         """Apply post-processing to the detection boxes.
 
         Args:
@@ -631,9 +665,8 @@ class DetPostProcess:
             category_filtered_boxes = []
             for cat_id in np.unique(boxes[:, 0]):
                 category_boxes = boxes[boxes[:, 0] == cat_id]
-                category_scores = category_boxes[:, 1]
                 category_threshold = threshold.get(int(cat_id), 0.5)
-                selected_indices = category_scores > category_threshold
+                selected_indices = (category_boxes[:, 1] > category_threshold) & (category_boxes[:, 0] > -1)
                 category_filtered_boxes.append(category_boxes[selected_indices])
             boxes = (
                 np.vstack(category_filtered_boxes)
@@ -641,69 +674,37 @@ class DetPostProcess:
                 else np.array([])
             )
 
-        if self.layout_postprocess:
+        if layout_nms:
             filtered_boxes = []
             ### Layout postprocess for NMS
-            for cat_id in np.unique(boxes[:, 0]):
-                category_boxes = boxes[boxes[:, 0] == cat_id]
-                category_scores = category_boxes[:, 1]
-                if len(category_boxes) > 0:
-                    nms_indices = non_max_suppression(
-                        category_boxes[:, 2:], category_scores, 0.5
-                    )
-                    category_boxes = category_boxes[nms_indices]
-                    keep_boxes = []
-                    for i, box in enumerate(category_boxes):
-                        if all(
-                            not is_box_mostly_inside(box[2:], other_box[2:])
-                            for j, other_box in enumerate(category_boxes)
-                            if i != j
-                        ):
-                            keep_boxes.append(box)
-                    filtered_boxes.extend(keep_boxes)
-            boxes = np.array(filtered_boxes)
-            ### Layout postprocess for removing boxes inside image category box
-            if self.labels and "image" in self.labels:
-                image_cls_id = self.labels.index("image")
-                if len(boxes) > 0:
-                    image_boxes = boxes[boxes[:, 0] == image_cls_id]
-                    other_boxes = boxes[boxes[:, 0] != image_cls_id]
-                    to_keep = []
-                    for box in other_boxes:
-                        keep = True
-                        for img_box in image_boxes:
-                            if (
-                                box[2] >= img_box[2]
-                                and box[3] >= img_box[3]
-                                and box[4] <= img_box[4]
-                                and box[5] <= img_box[5]
-                            ):
-                                keep = False
-                                break
-                        if keep:
-                            to_keep.append(box)
-                    boxes = (
-                        np.vstack([image_boxes, to_keep]) if to_keep else image_boxes
-                    )
-            ### Layout postprocess for overlaps
-            final_boxes = []
-            while len(boxes) > 0:
-                current_box = boxes[0]
-                current_score = current_box[1]
-                overlaps = [current_box]
-                non_overlaps = []
-                for other_box in boxes[1:]:
-                    iou = compute_iou(current_box[2:], other_box[2:])
-                    if iou > 0.95:
-                        if other_box[1] > current_score:
-                            overlaps.append(other_box)
-                    else:
-                        non_overlaps.append(other_box)
-                best_box = max(overlaps, key=lambda x: x[1])
-                final_boxes.append(best_box)
-                boxes = np.array(non_overlaps)
-            boxes = np.array(final_boxes)
-
+            selected_indices = nms(boxes, iou_same=0.6, iou_diff=0.98)
+            boxes = np.array(boxes[selected_indices])
+
+        if layout_merge_bboxes_mode:
+            assert layout_merge_bboxes_mode in ["union", "large", "small"], \
+                f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
+
+            if layout_merge_bboxes_mode == "union":
+                pass
+            else:
+                formula_index = (
+                    self.labels.index("formula")
+                    if self.labels and "formula" in self.labels
+                    else None
+                )
+                contains_other, contained_by_other = check_containment(boxes, formula_index)
+                if layout_merge_bboxes_mode == "large":
+                    boxes = boxes[contained_by_other == 0]
+                elif layout_merge_bboxes_mode == "small":
+                    boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
+
+        if layout_unclip_ratio:
+            if isinstance(layout_unclip_ratio, float):
+                layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
+            elif isinstance(layout_unclip_ratio, (tuple, list)):
+                assert len(layout_unclip_ratio) == 2, "The length of `layout_unclip_ratio` should be 2."
+            else:
+                raise ValueError(
+                    f"The type of `layout_unclip_ratio` must be float or Tuple[float, float], but got {type(layout_unclip_ratio)}."
+                )
+            boxes = unclip_boxes(boxes, layout_unclip_ratio)
+
         if boxes.shape[1] == 6:
             """For Normal Object Detection"""
             boxes = restructured_boxes(boxes, self.labels, img_size)
@@ -722,6 +723,9 @@ class DetPostProcess:
         batch_outputs: List[dict],
         datas: List[dict],
         threshold: Optional[Union[float, dict]] = None,
+        layout_nms: bool = False,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
     ) -> List[Boxes]:
         """Apply the post-processing to a batch of outputs.
 
@@ -737,7 +741,10 @@ class DetPostProcess:
             boxes = self.apply(
                 output["boxes"],
                 data["ori_img_size"],
-                threshold if threshold is not None else self.threshold,
+                threshold,
+                layout_nms,
+                layout_unclip_ratio,
+                layout_merge_bboxes_mode
             )
             outputs.append(boxes)
         return outputs

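Two behaviors introduced here are easy to misread from the diff. First, `nms` uses class-aware thresholds: in `apply` it is called as `nms(boxes, iou_same=0.6, iou_diff=0.98)`, so boxes of the same class suppress each other from IoU 0.6 up, while cross-class suppression only triggers at near-total overlap. Second, `unclip_boxes` expands each box around its center. A worked numeric example of the latter, mirroring the arithmetic above (values are illustrative):

import numpy as np

# One box: class 0, score 0.9, corners (10, 20) -> (110, 70).
box = np.array([[0.0, 0.9, 10.0, 20.0, 110.0, 70.0]])
wr, hr = 1.0, 1.1  # layout_unclip_ratio=(1.0, 1.1): keep width, grow height 10%

w = box[:, 4] - box[:, 2]   # 100.0
h = box[:, 5] - box[:, 3]   # 50.0
cx = box[:, 2] + w / 2      # 60.0
cy = box[:, 3] + h / 2      # 45.0
expanded = np.column_stack((
    box[:, 0], box[:, 1],
    cx - w * wr / 2, cy - h * hr / 2,
    cx + w * wr / 2, cy + h * hr / 2,
))
print(expanded)  # [[  0.    0.9  10.   17.5 110.   72.5]]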
+ 3 - 0
paddlex/inference/models_new/object_detection/utils.py

@@ -62,4 +62,7 @@ STATIC_SHAPE_MODEL_LIST = [
     "YOLOX-S",
     "YOLOX-T",
     "YOLOX-X",
+    "PP-DocLayout-L",
+    "PP-DocLayout-M",
+    "PP-DocLayout-S",
 ]

+ 26 - 3
paddlex/inference/pipelines_new/object_detection/pipeline.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union, Tuple
 import numpy as np
 
 from ...utils.pp_option import PaddlePredictorOption
@@ -50,21 +50,44 @@ class ObjectDetectionPipeline(BasePipeline):
             model_kwargs["threshold"] = model_cfg["threshold"]
         if "img_size" in model_cfg:
             model_kwargs["img_size"] = model_cfg["img_size"]
+        if "layout_nms" in model_cfg:
+            model_kwargs["layout_nms"] = model_cfg["layout_nms"]
+        if "layout_unclip_ratio" in model_cfg:
+            model_kwargs["layout_unclip_ratio"] = model_cfg["layout_unclip_ratio"]
+        if "layout_merge_bboxes_mode" in model_cfg:
+            model_kwargs["layout_merge_bboxes_mode"] = model_cfg["layout_merge_bboxes_mode"]
         self.det_model = self.create_model(model_cfg, **model_kwargs)
 
     def predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        threshold: Optional[float] = None,
+        threshold: Optional[Union[float, dict]] = None,
+        layout_nms: bool = False,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
         **kwargs,
     ) -> DetResult:
         """Predicts object detection results for the given input.
 
         Args:
             input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
             threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
+            layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
+            layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
+                Defaults to None.
+                If it's a single number, then both width and height are used.
+                If it's a tuple of two numbers, then they are used separately for width and height respectively.
+                If it's None, then no unclipping will be performed.
+            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging overlapping bounding boxes,
+                one of "union", "large", or "small". Defaults to None.
             **kwargs: Additional keyword arguments that can be passed to the function.
         Returns:
             DetResult: The predicted detection results.
         """
-        yield from self.det_model(input, threshold=threshold, **kwargs)
+        yield from self.det_model(
+            input,
+            threshold=threshold,
+            layout_nms=layout_nms,
+            layout_unclip_ratio=layout_unclip_ratio,
+            layout_merge_bboxes_mode=layout_merge_bboxes_mode,
+            **kwargs,
+        )

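For pipeline users, the same knobs can be pinned in the pipeline config instead of being passed on every `predict()` call; the constructor above copies them into the model kwargs. A sketch of that mapping with an illustrative (hypothetical) parsed `model_cfg`:

# Hypothetical parsed SubModule config for the object_detection pipeline;
# key names match the `if "..." in model_cfg` checks in __init__ above.
model_cfg = {
    "model_name": "PP-DocLayout-L",
    "threshold": {0: 0.45, 2: 0.48, 7: 0.4},
    "layout_nms": True,
    "layout_unclip_ratio": [1.0, 1.0],
    "layout_merge_bboxes_mode": "large",
}

model_kwargs = {
    key: model_cfg[key]
    for key in ("threshold", "img_size", "layout_nms",
                "layout_unclip_ratio", "layout_merge_bboxes_mode")
    if key in model_cfg
}
# self.det_model = self.create_model(model_cfg, **model_kwargs)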
+ 3 - 0
paddlex/inference/utils/official_models.py

@@ -329,6 +329,9 @@ PP-LCNet_x1_0_vehicle_attribute_infer.tar",
     "GroundingDINO-T": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/GroundingDINO-T_infer.tar",
     "SAM-H_box": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/SAM-H_box_infer.tar",
     "SAM-H_point": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/SAM-H_point_infer.tar",
+    "PP-DocLayout-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-DocLayout-L_infer.tar",
+    "PP-DocLayout-M": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-DocLayout-M_infer.tar",
+    "PP-DocLayout-S": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-DocLayout-S_infer.tar",
 }
 
 

+ 3 - 0
paddlex/modules/object_detection/model_list.py

@@ -78,4 +78,7 @@ MODELS = [
     "Co-DINO-Swin-L",
     "RT-DETR-L_wired_table_cell_det",
     "RT-DETR-L_wireless_table_cell_det",
+    "PP-DocLayout-L",
+    "PP-DocLayout-M",
+    "PP-DocLayout-S",
 ]

+ 173 - 0
paddlex/repo_apis/PaddleDetection_api/configs/PP-DocLayout-L.yaml

@@ -0,0 +1,173 @@
+# Runtime
+epoch: 40
+log_iter: 10
+find_unused_parameters: true
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+use_ema: true
+ema_decay: 0.9999
+ema_decay_type: "exponential"
+ema_filter_no_grad: true
+save_dir: output
+snapshot_epoch: 1
+print_flops: false
+print_params: false
+eval_size: [640, 640]
+
+# Dataset
+metric: COCO
+num_classes: 80
+
+worker_num: 4
+
+TrainDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_train.json
+  dataset_dir: datasets/COCO
+  data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+  allow_empty: true
+
+TestDataset:
+  name: ImageFolder
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+
+TrainReader:
+  sample_transforms:
+    - Decode: {}
+    - RandomDistort: {prob: 0.8}
+    - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+    - RandomCrop: {prob: 0.8}
+    - RandomFlip: {}
+  batch_transforms:
+    - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - NormalizeBox: {}
+    - BboxXYXY2XYWH: {}
+    - Permute: {}
+  batch_size: 8
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+  use_shared_memory: true
+
+EvalReader:
+  sample_transforms:
+    - Decode: {}
+    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - Permute: {}
+  batch_size: 4
+  shuffle: false
+  drop_last: false
+
+TestReader:
+  inputs_def:
+    image_shape: [3, 640, 640]
+  sample_transforms:
+    - Decode: {}
+    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - Permute: {}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+
+# Model
+architecture: DETR
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/PPHGNetV2_L_ssld_pretrained.pdparams
+
+norm_type: sync_bn
+hidden_dim: 256
+use_focal_loss: True
+
+DETR:
+  backbone: PPHGNetV2
+  neck: HybridEncoder
+  transformer: RTDETRTransformer
+  detr_head: DINOHead
+  post_process: DETRPostProcess
+
+PPHGNetV2:
+  arch: 'L'
+  return_idx: [1, 2, 3]
+  freeze_stem_only: true
+  freeze_at: 0
+  freeze_norm: true
+  lr_mult_list: [0., 0.05, 0.05, 0.05, 0.05]
+
+HybridEncoder:
+  hidden_dim: 256
+  use_encoder_idx: [2]
+  num_encoder_layers: 1
+  encoder_layer:
+    name: TransformerLayer
+    d_model: 256
+    nhead: 8
+    dim_feedforward: 1024
+    dropout: 0.
+    activation: 'gelu'
+  expansion: 1.0
+
+RTDETRTransformer:
+  num_queries: 300
+  position_embed_type: sine
+  feat_strides: [8, 16, 32]
+  num_levels: 3
+  nhead: 8
+  num_decoder_layers: 6
+  dim_feedforward: 1024
+  dropout: 0.0
+  activation: relu
+  num_denoising: 100
+  label_noise_ratio: 0.5
+  box_noise_scale: 1.0
+  learnt_init_query: false
+
+DINOHead:
+  loss:
+    name: DINOLoss
+    loss_coeff: {class: 1, bbox: 5, giou: 2}
+    aux_loss: true
+    use_vfl: true
+    matcher:
+      name: HungarianMatcher
+      matcher_coeff: {class: 2, bbox: 5, giou: 2}
+
+DETRPostProcess:
+  num_top_queries: 300
+
+# Optimizer
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 1.0
+    milestones: [100]
+    use_warmup: true
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 100
+
+OptimizerBuilder:
+  clip_grad_by_norm: 0.1
+  regularizer: false
+  optimizer:
+    type: AdamW
+    weight_decay: 0.0001
+
+# Export
+export:
+  post_process: true
+  nms: true
+  benchmark: false
+  fuse_conv_bn: false

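One scheduling detail worth noting in this config: with `gamma: 1.0` the PiecewiseDecay milestone is a no-op, so training effectively runs a linear warmup into a constant learning rate. A sketch of the resulting curve (semantics inferred from the config values, not from PaddleDetection's scheduler source):

base_lr = 0.0001
start_factor, warmup_steps = 0.001, 100

def lr_at(step: int) -> float:
    if step < warmup_steps:
        # Linear ramp from base_lr * start_factor up to base_lr.
        alpha = step / warmup_steps
        return base_lr * (start_factor + (1 - start_factor) * alpha)
    # PiecewiseDecay with gamma=1.0 leaves the rate unchanged afterwards.
    return base_lr

print(lr_at(0), lr_at(50), lr_at(100))  # ~1e-07, ~5e-05, 1e-04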
+ 165 - 0
paddlex/repo_apis/PaddleDetection_api/configs/PP-DocLayout-M.yaml

@@ -0,0 +1,165 @@
+# Runtime
+epoch: 100
+log_iter: 10
+find_unused_parameters: true
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+use_ema: true
+save_dir: output
+snapshot_epoch: 10
+print_flops: false
+print_params: false
+
+# Dataset
+metric: COCO
+num_classes: 17
+
+worker_num: 6
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+
+TrainDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_train.json
+  dataset_dir: datasets/COCO
+  data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+  allow_empty: true
+
+TestDataset:
+  name: ImageFolder
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomCrop: {}
+  - RandomFlip: {prob: 0.5}
+  - RandomDistort: {}
+  batch_transforms:
+  - BatchRandomResize: {target_size: [576, 608, 640, 672, 704], random_size: True, random_interp: True, keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  - PadGT: {}
+  batch_size: 16
+  shuffle: true
+  drop_last: true
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 8
+  shuffle: false
+
+TestReader:
+  inputs_def:
+    image_shape: [3, *eval_height, *eval_width]
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_size: 1
+
+# Model
+architecture: PicoDet
+pretrain_weights: https://paddle-model-ecology.bj.bcebos.com/paddlex/pretrained/PicoDet-L_layout_pretrained_v1.pdparams
+
+PicoDet:
+  backbone: LCNet
+  neck: LCPAN
+  head: PicoHeadV2
+
+LCNet:
+  scale: 2.0
+  feature_maps: [3, 4, 5]
+
+LCPAN:
+  out_channels: 160
+  use_depthwise: true
+  num_features: 4
+
+PicoHeadV2:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 160
+    feat_out: 160
+    num_convs: 4
+    num_fpn_stride: 4
+    norm_type: bn
+    share_cls_reg: true
+    use_se: true
+  fpn_stride: [8, 16, 32, 64]
+  feat_in_chan: 160
+  prior_prob: 0.01
+  reg_max: 7
+  cell_offset: 0.5
+  grid_cell_scale: 5.0
+  static_assigner_epoch: 100
+  use_align_head: true
+  static_assigner:
+    name: ATSSAssigner
+    topk: 9
+    force_gt_matching: false
+  assigner:
+    name: TaskAlignedAssigner
+    topk: 13
+    alpha: 1.0
+    beta: 6.0
+  loss_class:
+    name: VarifocalLoss
+    use_sigmoid: false
+    iou_weighted: true
+    loss_weight: 1.0
+  loss_dfl:
+    name: DistributionFocalLoss
+    loss_weight: 0.5
+  loss_bbox:
+    name: GIoULoss
+    loss_weight: 2.5
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.025
+    nms_threshold: 0.6
+
+# Optimizer
+LearningRate:
+  base_lr: 0.06
+  schedulers:
+  - name: CosineDecay
+    max_epochs: 150
+  - name: LinearWarmup
+    start_factor: 0.1
+    steps: 300
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.00004
+    type: L2
+
+# Export
+export:
+  post_process: true
+  nms: true
+  benchmark: false
+  fuse_conv_bn: false

+ 165 - 0
paddlex/repo_apis/PaddleDetection_api/configs/PP-DocLayout-S.yaml

@@ -0,0 +1,165 @@
+# Runtime
+epoch: 100
+log_iter: 10
+find_unused_parameters: true
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+use_ema: true
+save_dir: output
+snapshot_epoch: 10
+print_flops: false
+print_params: false
+
+# Dataset
+metric: COCO
+num_classes: 17
+
+worker_num: 6
+eval_height: &eval_height 480
+eval_width: &eval_width 480
+eval_size: &eval_size [*eval_height, *eval_width]
+
+TrainDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_train.json
+  dataset_dir: datasets/COCO
+  data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+  allow_empty: true
+
+TestDataset:
+  name: ImageFolder
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomCrop: {}
+  - RandomFlip: {prob: 0.5}
+  - RandomDistort: {}
+  batch_transforms:
+  - BatchRandomResize: {target_size: [416, 448, 480, 512, 544], random_size: True, random_interp: True, keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  - PadGT: {}
+  batch_size: 16
+  shuffle: true
+  drop_last: true
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 8
+  shuffle: false
+
+TestReader:
+  inputs_def:
+    image_shape: [3, *eval_height, *eval_width]
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_size: 1
+
+# Model
+architecture: PicoDet
+pretrain_weights: https://paddle-model-ecology.bj.bcebos.com/paddlex/pretrained/PicoDet-S_layout_pretrained_17cls.pdparams
+
+PicoDet:
+  backbone: LCNet
+  neck: LCPAN
+  head: PicoHeadV2
+
+LCNet:
+  scale: 0.75
+  feature_maps: [3, 4, 5]
+
+LCPAN:
+  out_channels: 96
+  use_depthwise: true
+  num_features: 4
+
+PicoHeadV2:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 96
+    feat_out: 96
+    num_convs: 2
+    num_fpn_stride: 4
+    norm_type: bn
+    share_cls_reg: true
+    use_se: true
+  fpn_stride: [8, 16, 32, 64]
+  feat_in_chan: 96
+  prior_prob: 0.01
+  reg_max: 7
+  cell_offset: 0.5
+  grid_cell_scale: 5.0
+  static_assigner_epoch: 100
+  use_align_head: true
+  static_assigner:
+    name: ATSSAssigner
+    topk: 9
+    force_gt_matching: false
+  assigner:
+    name: TaskAlignedAssigner
+    topk: 13
+    alpha: 1.0
+    beta: 6.0
+  loss_class:
+    name: VarifocalLoss
+    use_sigmoid: false
+    iou_weighted: true
+    loss_weight: 1.0
+  loss_dfl:
+    name: DistributionFocalLoss
+    loss_weight: 0.5
+  loss_bbox:
+    name: GIoULoss
+    loss_weight: 2.5
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.025
+    nms_threshold: 0.6
+
+# Optimizer
+LearningRate:
+  base_lr: 0.08
+  schedulers:
+  - name: CosineDecay
+    max_epochs: 300
+  - name: LinearWarmup
+    start_factor: 0.1
+    steps: 100
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.00004
+    type: L2
+
+# Export
+export:
+  post_process: true
+  nms: true
+  benchmark: false
+  fuse_conv_bn: false

+ 45 - 0
paddlex/repo_apis/PaddleDetection_api/object_det/register.py

@@ -1059,3 +1059,48 @@ register_model_info(
         },
     }
 )
+
+register_model_info(
+    {
+        "model_name": "PP-DocLayout-L",
+        "suite": "Det",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-DocLayout-L.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["COCODetDataset"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "PP-DocLayout-M",
+        "suite": "Det",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-DocLayout-M.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["COCODetDataset"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "PP-DocLayout-S",
+        "suite": "Det",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-DocLayout-S.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["COCODetDataset"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+    }
+)