Prechádzať zdrojové kódy

bugfix:

1. rename x to input;
2. fix error msg when arguments of BasicPredictor.set_predictor() are illegal;
gaotingquan 1 rok pred
rodič
commit
448830cf05

+ 5 - 1
paddlex/inference/components/base.py

@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import inspect
+from abc import ABC, abstractmethod
 from copy import deepcopy
-from abc import ABC
 from types import GeneratorType
 
 from ...utils import logging
@@ -260,6 +260,10 @@ class BaseComponent(ABC):
     def name(self):
         return getattr(self, "NAME", self.__class__.__name__)
 
+    @abstractmethod
+    def apply(self, input):
+        raise NotImplementedError
+
 
 class ComponentsEngine(object):
     def __init__(self, ops):

+ 3 - 3
paddlex/inference/models/base/base_predictor.py

@@ -28,8 +28,8 @@ class BasePredictor(BaseComponent):
     KEEP_INPUT = False
     YIELD_BATCH = False
 
-    INPUT_KEYS = "x"
-    DEAULT_INPUTS = {"x": "x"}
+    INPUT_KEYS = "input"
+    DEAULT_INPUTS = {"input": "input"}
     OUTPUT_KEYS = "result"
     DEAULT_OUTPUTS = {"result": "result"}
 
@@ -57,7 +57,7 @@ class BasePredictor(BaseComponent):
         return self.config["Global"]["model_name"]
 
     @abstractmethod
-    def apply(self, x):
+    def apply(self, input):
         raise NotImplementedError
 
     @abstractmethod

+ 8 - 6
paddlex/inference/models/base/basic_predictor.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from abc import abstractmethod
+import inspect
 
 from ....utils.subclass_register import AutoRegisterABCMetaClass
 from ....utils import logging
@@ -46,9 +47,9 @@ class BasicPredictor(
         self.engine = ComponentsEngine(self.components)
         logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
 
-    def apply(self, x):
+    def apply(self, input):
         """predict"""
-        yield from self._generate_res(self.engine(x))
+        yield from self._generate_res(self.engine(input))
 
     @generatorable_method
     def _generate_res(self, batch_data):
@@ -78,18 +79,19 @@ class BasicPredictor(
                 setattr(self, k, kwargs[k])
             else:
                 raise Exception(
-                    f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._get_settable_attributes}"
+                    f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._get_settable_attributes()}"
                 )
 
     def _has_setter(self, attr):
         prop = getattr(self.__class__, attr, None)
         return isinstance(prop, property) and prop.fset is not None
 
-    def _get_settable_attributes(self):
+    @classmethod
+    def _get_settable_attributes(cls):
         return [
             name
-            for name, prop in vars(self.__class__).items()
-            if isinstance(prop, property) and prop.fset is not None
+            for name, obj in inspect.getmembers(cls, lambda o: isinstance(o, property))
+            if obj.fset is not None
         ]
 
     @abstractmethod