Browse Source

fix_layout_class_mode

Sunting78 8 months ago
parent
commit
a4c925703a
1 changed files with 10 additions and 9 deletions
  1. 10 9
      paddlex/inference/models/object_detection/predictor.py

+ 10 - 9
paddlex/inference/models/object_detection/predictor.py

@@ -51,7 +51,7 @@ class DetPredictor(BasicPredictor):
         threshold: Optional[Union[float, dict]] = None,
         layout_nms: Optional[bool] = None,
         layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-        layout_merge_bboxes_mode: Optional[str] = None,
+        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
         **kwargs,
     ):
         """Initializes DetPredictor.
@@ -66,7 +66,7 @@ class DetPredictor(BasicPredictor):
                 If it's a single number, then both width and height are used.
                 If it's a tuple of two numbers, then they are used separately for width and height respectively.
                 If it's None, then no unclipping will be performed.
-            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
+            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging bounding boxes. Defaults to None.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
         super().__init__(*args, **kwargs)
@@ -97,11 +97,12 @@ class DetPredictor(BasicPredictor):
                 )
 
         if layout_merge_bboxes_mode is not None:
-            assert layout_merge_bboxes_mode in [
-                "union",
-                "large",
-                "small",
-            ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
+            if isinstance(layout_merge_bboxes_mode, str):
+                assert layout_merge_bboxes_mode in [
+                    "union",
+                    "large",
+                    "small",
+                ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'] or a dict, but got {layout_merge_bboxes_mode}"
 
         self.img_size = img_size
         self.threshold = threshold
@@ -209,7 +210,7 @@ class DetPredictor(BasicPredictor):
         threshold: Optional[Union[float, dict]] = None,
         layout_nms: bool = False,
         layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-        layout_merge_bboxes_mode: Optional[str] = None,
+        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
     ):
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
@@ -219,7 +220,7 @@ class DetPredictor(BasicPredictor):
             threshold (Optional[float, dict], optional): The threshold for filtering out low-confidence predictions.
             layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to None.
             layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
-            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
+            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging bounding boxes. Defaults to None.
 
         Returns:
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names