浏览代码

fix_layout_class_mode

Sunting78 8 月之前
父节点
当前提交
a4c925703a
共有 1 个文件被更改，包括 10 次插入和 9 次删除
  1. 10 9
      paddlex/inference/models/object_detection/predictor.py

+ 10 - 9
paddlex/inference/models/object_detection/predictor.py

@@ -51,7 +51,7 @@ class DetPredictor(BasicPredictor):
         threshold: Optional[Union[float, dict]] = None,
         threshold: Optional[Union[float, dict]] = None,
         layout_nms: Optional[bool] = None,
         layout_nms: Optional[bool] = None,
         layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
         layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-        layout_merge_bboxes_mode: Optional[str] = None,
+        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
         **kwargs,
         **kwargs,
     ):
     ):
         """Initializes DetPredictor.
         """Initializes DetPredictor.
@@ -66,7 +66,7 @@ class DetPredictor(BasicPredictor):
                 If it's a single number, then both width and height are used.
                 If it's a single number, then both width and height are used.
                 If it's a tuple of two numbers, then they are used separately for width and height respectively.
                 If it's a tuple of two numbers, then they are used separately for width and height respectively.
                 If it's None, then no unclipping will be performed.
                 If it's None, then no unclipping will be performed.
-            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
+            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging bounding boxes. Defaults to None.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
         """
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
@@ -97,11 +97,12 @@ class DetPredictor(BasicPredictor):
                 )
                 )
 
 
         if layout_merge_bboxes_mode is not None:
         if layout_merge_bboxes_mode is not None:
-            assert layout_merge_bboxes_mode in [
-                "union",
-                "large",
-                "small",
-            ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
+            if isinstance(layout_merge_bboxes_mode, str):
+                assert layout_merge_bboxes_mode in [
+                    "union",
+                    "large",
+                    "small",
+                ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'] or a dict, but got {layout_merge_bboxes_mode}"
 
 
         self.img_size = img_size
         self.img_size = img_size
         self.threshold = threshold
         self.threshold = threshold
@@ -209,7 +210,7 @@ class DetPredictor(BasicPredictor):
         threshold: Optional[Union[float, dict]] = None,
         threshold: Optional[Union[float, dict]] = None,
         layout_nms: bool = False,
         layout_nms: bool = False,
         layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
         layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-        layout_merge_bboxes_mode: Optional[str] = None,
+        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
     ):
     ):
         """
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
         Process a batch of data through the preprocessing, inference, and postprocessing.
@@ -219,7 +220,7 @@ class DetPredictor(BasicPredictor):
             threshold (Optional[Union[float, dict]], optional): The threshold for filtering out low-confidence predictions.
             threshold (Optional[Union[float, dict]], optional): The threshold for filtering out low-confidence predictions.
             layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
             layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
             layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
             layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
-            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
+            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging bounding boxes. Defaults to None.
 
 
         Returns:
         Returns:
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names