|
@@ -31,24 +31,8 @@ class SegPredictor(BasePredictor):
|
|
|
|
|
|
|
|
entities = MODELS
|
|
entities = MODELS
|
|
|
|
|
|
|
|
- def __init__(
|
|
|
|
|
- self,
|
|
|
|
|
- model_name,
|
|
|
|
|
- model_dir,
|
|
|
|
|
- kernel_option,
|
|
|
|
|
- output,
|
|
|
|
|
- pre_transforms=None,
|
|
|
|
|
- post_transforms=None,
|
|
|
|
|
- has_prob_map=False,
|
|
|
|
|
- ):
|
|
|
|
|
- super().__init__(
|
|
|
|
|
- model_name=model_name,
|
|
|
|
|
- model_dir=model_dir,
|
|
|
|
|
- kernel_option=kernel_option,
|
|
|
|
|
- output=output,
|
|
|
|
|
- pre_transforms=pre_transforms,
|
|
|
|
|
- post_transforms=post_transforms,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ def __init__(self, has_prob_map=False, *args, **kwargs):
|
|
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
self.has_prob_map = has_prob_map
|
|
self.has_prob_map = has_prob_map
|
|
|
|
|
|
|
|
def load_other_src(self):
|
|
def load_other_src(self):
|