Browse Source

pipeline support topk input (#2799)

zhangyubo0722 10 months ago
parent
commit
6781884187

+ 1 - 1
api_examples/pipelines/test_image_classification.py

@@ -16,7 +16,7 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="image_classification")
 
-output = pipeline.predict("./test_samples/general_image_classification_001.jpg")
+output = pipeline.predict("./test_samples/general_image_classification_001.jpg", topk=5)
 
 # output = pipeline.predict("./test_samples/财报1.pdf")
 

+ 1 - 1
docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.en.md

@@ -48,7 +48,7 @@ Pedestrian attribute recognition is a key function in computer vision systems, u
 <thead>
 <tr>
 <th>Model</th><th>Model Download Link</th>
-<th>mA (%)</th>
+<th>mAP (%)</th>
 <th>GPU Inference Time (ms)</th>
 <th>CPU Inference Time (ms)</th>
 <th>Model Size (M)</th>

+ 1 - 1
docs/pipeline_usage/tutorials/cv_pipelines/pedestrian_attribute_recognition.md

@@ -48,7 +48,7 @@ comments: true
 <thead>
 <tr>
 <th>模型</th><th>模型下载链接</th>
-<th>mA(%)</th>
+<th>mAP(%)</th>
 <th>GPU推理耗时(ms)</th>
 <th>CPU推理耗时 (ms)</th>
 <th>模型存储大小(M)</th>

+ 1 - 1
docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.en.md

@@ -44,7 +44,7 @@ Vehicle attribute recognition is a crucial component in computer vision systems.
 <thead>
 <tr>
 <th>Model</th><th>Model Download Link</th>
-<th>mA (%)</th>
+<th>mAP (%)</th>
 <th>GPU Inference Time (ms)</th>
 <th>CPU Inference Time (ms)</th>
 <th>Model Size (M)</th>

+ 1 - 1
docs/pipeline_usage/tutorials/cv_pipelines/vehicle_attribute_recognition.md

@@ -45,7 +45,7 @@ comments: true
 <thead>
 <tr>
 <th>模型</th><th>模型下载链接</th>
-<th>mA(%)</th>
+<th>mAP(%)</th>
 <th>GPU推理耗时(ms)</th>
 <th>CPU推理耗时 (ms)</th>
 <th>模型存储大小(M)</th>

+ 3 - 0
paddlex/inference/pipelines_new/image_classification/pipeline.py

@@ -57,6 +57,7 @@ class ImageClassificationPipeline(BasePipeline):
         self.image_classification_model = self.create_model(
             image_classification_model_config, **model_kwargs
         )
+        self.topk = image_classification_model_config.get("topk", 5)
 
     def predict(
         self, input: str | list[str] | np.ndarray | list[np.ndarray], topk=None
@@ -70,4 +71,6 @@ class ImageClassificationPipeline(BasePipeline):
         Returns:
             TopkResult: The predicted top k results.
         """
+      
+        topk = kwargs.pop("topk", self.topk)
         yield from self.image_classification_model(input, topk=topk)