|
|
@@ -57,6 +57,7 @@ class ImageClassificationPipeline(BasePipeline):
|
|
|
self.image_classification_model = self.create_model(
|
|
|
image_classification_model_config, **model_kwargs
|
|
|
)
|
|
|
+ self.topk = image_classification_model_config.get("topk", 5)
|
|
|
|
|
|
def predict(
|
|
|
self, input: str | list[str] | np.ndarray | list[np.ndarray], topk=None
|
|
|
@@ -70,4 +71,6 @@ class ImageClassificationPipeline(BasePipeline):
|
|
|
Returns:
|
|
|
TopkResult: The predicted top k results.
|
|
|
"""
|
|
|
+
|
|
|
+ topk = kwargs.pop("topk", self.topk)
|
|
|
yield from self.image_classification_model(input, topk=topk)
|