|
|
@@ -26,14 +26,13 @@ class ClasPredictor(BasePredictor):
|
|
|
|
|
|
entities = MODELS
|
|
|
|
|
|
- INPUT_KEYS = "x"
|
|
|
- OUTPUT_KEYS = "topk_res"
|
|
|
- DEAULT_INPUTS = {"x": "x"}
|
|
|
- DEAULT_OUTPUTS = {"topk_res": "topk_res"}
|
|
|
-
|
|
|
_FUNC_MAP = {}
|
|
|
register = FuncRegister(_FUNC_MAP)
|
|
|
|
|
|
+ def _check_args(self, kwargs):
|
|
|
+ assert set(kwargs.keys()).issubset(set(["batch_size"]))
|
|
|
+ return kwargs
|
|
|
+
|
|
|
def _build_components(self):
|
|
|
ops = {}
|
|
|
ops["ReadImage"] = ReadImage(batch_size=self.kwargs.get("batch_size", 1))
|
|
|
@@ -44,12 +43,10 @@ class ClasPredictor(BasePredictor):
|
|
|
op = func(self, **args) if args else func(self)
|
|
|
ops[tf_key] = op
|
|
|
|
|
|
- kernel_option = PaddlePredictorOption()
|
|
|
- kernel_option.set_device(self.device)
|
|
|
predictor = ImagePredictor(
|
|
|
model_dir=self.model_dir,
|
|
|
model_prefix=self.MODEL_FILE_PREFIX,
|
|
|
- option=kernel_option,
|
|
|
+ option=self.pp_option,
|
|
|
)
|
|
|
ops["predictor"] = predictor
|
|
|
|
|
|
@@ -62,7 +59,10 @@ class ClasPredictor(BasePredictor):
|
|
|
return ops
|
|
|
|
|
|
@register("ResizeImage")
|
|
|
- def build_resize(self, resize_short=None, size=None):
|
|
|
+ # TODO(gaotingquan): backend & interpolation
|
|
|
+ def build_resize(
|
|
|
+ self, resize_short=None, size=None, backend="cv2", interpolation="LINEAR"
|
|
|
+ ):
|
|
|
assert resize_short or size
|
|
|
if resize_short:
|
|
|
op = ResizeByShort(
|
|
|
@@ -97,9 +97,13 @@ class ClasPredictor(BasePredictor):
|
|
|
def build_topk(self, topk, label_list=None):
|
|
|
return Topk(topk=int(topk), class_ids=label_list)
|
|
|
|
|
|
+ @register("MultiLabelThreshOutput")
|
|
|
+ def build_threshoutput(self, threshold, label_list=None):
|
|
|
+ return MultiLabelThreshOutput(threshold=float(threshold), class_ids=label_list)
|
|
|
+
|
|
|
@batchable_method
|
|
|
def _pack_res(self, data):
|
|
|
keys = ["img_path", "class_ids", "scores"]
|
|
|
if "label_names" in data:
|
|
|
keys.append("label_names")
|
|
|
- return {"topk_res": TopkResult({key: data[key] for key in keys})}
|
|
|
+ return {"result": TopkResult({key: data[key] for key in keys})}
|