Kaynağa Gözat

support input threshold with class

zhangyubo0722 10 ay önce
ebeveyn
işleme
99fca5ca13

+ 3 - 1
paddlex/inference/models_new/image_multilabel_classification/processors.py

@@ -41,9 +41,11 @@ class MultiLabelThreshOutput:
                 raise ValueError(
                     "If using dictionary format, please specify default threshold explicitly with key 'default'."
                 )
-            default_threshold = threshold.pop("default")
+            default_threshold = threshold.get("default")
             threshold_list = [default_threshold for _ in range(num_classes)]
             for k, v in threshold.items():
+                if k == "default":
+                    continue
                 if isinstance(k, str):
                     assert (
                         k.isdigit()

+ 10 - 4
paddlex/inference/pipelines_new/attribute_recognition/pipeline.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union, List
 
 import pickle
 from pathlib import Path
@@ -53,9 +53,15 @@ class AttributeRecPipeline(BasePipeline):
         )
         self.img_reader = ReadImage(format="BGR")
 
-    def predict(self, input, **kwargs):
-        det_threshold = kwargs.pop("det_threshold", self.det_threshold)
-        cls_threshold = kwargs.pop("cls_threshold", self.cls_threshold)
+    def predict(
+        self,
+        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
+        det_threshold: float = None,
+        cls_threshold: Union[float, dict, list, None] = None,
+        **kwargs
+    ):
+        det_threshold = self.det_threshold if det_threshold is None else det_threshold
+        cls_threshold = self.cls_threshold if cls_threshold is None else cls_threshold
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
             raw_imgs = self.img_reader(batch_data)
             all_det_res = list(self.det_model(raw_imgs, threshold=det_threshold))