gaotingquan 1 年之前
父節點
當前提交
432c73ffba

+ 1 - 1
paddlex/inference/models/general_recognition.py

@@ -33,7 +33,7 @@ class ShiTuRecPredictor(CVPredictor):
         self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             self._add_component(op)

+ 1 - 1
paddlex/inference/models/image_classification.py

@@ -34,7 +34,7 @@ class ClasPredictor(CVPredictor):
         self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             self._add_component(op)

+ 1 - 1
paddlex/inference/models/instance_segmentation.py

@@ -30,7 +30,7 @@ class InstanceSegPredictor(DetPredictor):
         self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["Preprocess"]:
             tf_key = cfg["type"]
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             cfg.pop("type")
             args = cfg
             op = func(self, **args) if args else func(self)

+ 1 - 1
paddlex/inference/models/object_detection.py

@@ -33,7 +33,7 @@ class DetPredictor(CVPredictor):
         self._add_component(ReadImage(format="RGB"))
         for cfg in self.config["Preprocess"]:
             tf_key = cfg["type"]
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             cfg.pop("type")
             args = cfg
             op = func(self, **args) if args else func(self)

+ 1 - 1
paddlex/inference/models/semantic_segmentation.py

@@ -34,7 +34,7 @@ class SegPredictor(CVPredictor):
         self._add_component(ToCHWImage())
         for cfg in self.config["Deploy"]["transforms"]:
             tf_key = cfg["type"]
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             cfg.pop("type")
             args = cfg
             op = func(self, **args) if args else func(self)

+ 5 - 3
paddlex/inference/models/table_recognition.py

@@ -34,7 +34,7 @@ class TablePredictor(CVPredictor):
     def _build_components(self):
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op:
@@ -60,8 +60,10 @@ class TablePredictor(CVPredictor):
             raise Exception()
 
     @register("DecodeImage")
-    def build_readimg(self, *args, **kwargs):
-        return ReadImage(*args, **kwargs)
+    def build_readimg(self, channel_first=False, img_mode="BGR"):
+        assert channel_first is False
+        assert img_mode == "BGR"
+        return ReadImage()
 
     @register("TableLabelEncode")
     def foo(self, *args, **kwargs):

+ 1 - 1
paddlex/inference/models/text_detection.py

@@ -32,7 +32,7 @@ class TextDetPredictor(CVPredictor):
     def _build_components(self):
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op:

+ 1 - 1
paddlex/inference/models/text_recognition.py

@@ -33,7 +33,7 @@ class TextRecPredictor(CVPredictor):
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
             assert tf_key in self._FUNC_MAP
-            func = self._FUNC_MAP.get(tf_key)
+            func = self._FUNC_MAP[tf_key]
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op: