Ver código fonte

support new ir blocklist

gaotingquan 1 ano atrás
pai
commit
33c7b59bf7

+ 1 - 1
paddlex/inference/components/paddle_predictor/predictor.py

@@ -30,7 +30,7 @@ class BasePaddlePredictor(BaseComponent, PPEngineMixin):
     DEAULT_OUTPUTS = {"pred": "pred"}
     ENABLE_BATCH = True
 
-    def __init__(self, model_dir, model_prefix, option: PaddlePredictorOption = None):
+    def __init__(self, model_dir, model_prefix, option):
         super().__init__()
         PPEngineMixin.__init__(self, option)
         self.model_dir = model_dir

+ 6 - 0
paddlex/inference/components/utils/mixin.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from abc import abstractmethod
+
 
 class BatchSizeMixin:
     NAME = "ReadCmp"
@@ -45,3 +47,7 @@ class PPEngineMixin:
         if value != self.option:
             self._option = value
             self._reset()
+
+    @abstractmethod
+    def _reset(self):
+        raise NotImplementedError

+ 1 - 1
paddlex/inference/models/__init__.py

@@ -60,11 +60,11 @@ def _create_hp_predictor(
 
 def create_predictor(
     model: str,
-    *args,
     device=None,
     pp_option=None,
     use_hpip=False,
     hpi_params=None,
+    *args,
     **kwargs,
 ) -> BasePredictor:
     model_dir = check_model(model)

+ 9 - 5
paddlex/inference/models/base/basic_predictor.py

@@ -34,7 +34,9 @@ class BasicPredictor(
 
     __is_base = True
 
-    def __init__(self, model_dir, config=None, device=None, pp_option=None):
+    def __init__(
+        self, model_dir, config=None, device=None, pp_option=None, **option_kwargs
+    ):
         super().__init__(model_dir=model_dir, config=config)
         self._pred_set_func_map = {}
         self._pred_set_register = FuncRegister(self._pred_set_func_map)
@@ -42,14 +44,16 @@ class BasicPredictor(
         self._pred_set_register("pp_option")(self.set_pp_option)
         self._pred_set_register("batch_size")(self.set_batch_size)
 
-        self.pp_option = pp_option if pp_option else PaddlePredictorOption()
+        self.pp_option = (
+            pp_option
+            if pp_option
+            else PaddlePredictorOption(model_name=self.model_name, **option_kwargs)
+        )
         self.pp_option.set_device(device)
         self.components = {}
         self._build_components()
         self.engine = ComponentsEngine(self.components)
-        logging.debug(
-            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}"
-        )
+        logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
 
     def apply(self, x):
         """predict"""

+ 42 - 0
paddlex/inference/utils/new_ir_blacklist.py

@@ -0,0 +1,42 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+NEWIR_BLOCKLIST = [
+    "FasterRCNN-ResNet34-FPN",
+    "FasterRCNN-ResNet50",
+    "FasterRCNN-ResNet50-FPN",
+    "FasterRCNN-ResNet50-vd-FPN",
+    "FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "FasterRCNN-ResNet101",
+    "FasterRCNN-ResNet101-FPN",
+    "FasterRCNN-ResNeXt101-vd-FPN",
+    "FasterRCNN-Swin-Tiny-FPN",
+    "Cascade-FasterRCNN-ResNet50-FPN",
+    "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "PP-YOLOE_plus_SOD-S",
+    "PP-YOLOE_plus_SOD-L",
+    "PP-YOLOE_plus_SOD-largesize-L",
+    "PP-YOLOE_seg-S",
+    "MaskRCNN-ResNet50",
+    "MaskRCNN-ResNet50-FPN",
+    "MaskRCNN-ResNet50-vd-FPN",
+    "MaskRCNN-ResNet101-FPN",
+    "MaskRCNN-ResNet101-vd-FPN",
+    "MaskRCNN-ResNeXt101-vd-FPN",
+    "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
+    "Cascade-MaskRCNN-ResNet50-FPN",
+    "DLinear_ad",
+    "PatchTST_ad",
+    "Nonstationary_ad",
+]

+ 6 - 4
paddlex/inference/utils/pp_option.py

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .device import parse_device
 from ...utils.func_register import FuncRegister
 from ...utils import logging
+from .device import parse_device
+from .new_ir_blacklist import NEWIR_BLOCKLIST
 
 
 class PaddlePredictorOption(object):
@@ -33,8 +34,9 @@ class PaddlePredictorOption(object):
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def __init__(self, **kwargs):
+    def __init__(self, model_name=None, **kwargs):
         super().__init__()
+        self.model_name = model_name
         self._cfg = {}
         self._init_option(**kwargs)
 
@@ -49,7 +51,7 @@ class PaddlePredictorOption(object):
         for k, v in self._get_default_config().items():
             self._cfg.setdefault(k, v)
 
-    def _get_default_config(cls):
+    def _get_default_config(self):
         """get default config"""
         return {
             "run_mode": "paddle",
@@ -61,7 +63,7 @@ class PaddlePredictorOption(object):
             "cpu_threads": 1,
             "trt_use_static": False,
             "delete_pass": [],
-            "enable_new_ir": True,
+            "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
         }
 
     @register("run_mode")

+ 12 - 18
paddlex/model.py

@@ -25,42 +25,36 @@ from .modules import (
 
 
 # TODO(gaotingquan): support _ModelBasedConfig
-def create_model(model=None, **kwargs):
-    return _ModelBasedInference(model, **kwargs)
+def create_model(model=None, *args, **kwargs):
+    return _ModelBasedInference(model, *args, **kwargs)
 
 
 class _BaseModel:
-    @abstractmethod
     def check_dataset(self, *args, **kwargs):
-        raise NotImplementedError
+        raise Exception("check_dataset is not supported!")
 
-    @abstractmethod
     def train(self, *args, **kwargs):
-        raise NotImplementedError
+        raise Exception("train is not supported!")
 
-    @abstractmethod
     def evaluate(self, *args, **kwargs):
-        raise NotImplementedError
+        raise Exception("evaluate is not supported!")
 
-    @abstractmethod
     def export(self, *args, **kwargs):
-        raise NotImplementedError
+        raise Exception("export is not supported!")
 
-    @abstractmethod
     def predict(self, *args, **kwargs):
-        raise NotImplementedError
+        raise Exception("predict is not supported!")
 
-    @abstractmethod
     def set_predict(self, *args, **kwargs):
-        raise NotImplementedError
+        raise Exception("set_predict is not supported!")
 
     def __call__(self, *args, **kwargs):
         yield from self.predict(*args, **kwargs)
 
 
 class _ModelBasedInference(_BaseModel):
-    def __init__(self, model, device=None, **kwargs):
-        self._predictor = create_predictor(model, device=device, **kwargs)
+    def __init__(self, *args, **kwargs):
+        self._predictor = create_predictor(*args, **kwargs)
 
     def predict(self, *args, **kwargs):
         yield from self._predictor(*args, **kwargs)
@@ -108,5 +102,5 @@ class _ModelBasedConfig(_BaseModel):
         return exportor.export()
 
     def predict(self):
-        _predict_kwargs, _predictor = self._build_predictor()
-        yield from _predictor(**_predict_kwargs)
+        predict_kwargs, predictor = self._build_predictor()
+        yield from predictor(**predict_kwargs)

+ 1 - 1
paddlex/pipelines/table_recognition.yaml

@@ -1,6 +1,6 @@
 Global:
   pipeline_name: table_recognition
-  input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png
+  input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg
   
 ######################################## Setting ########################################
 # Please select the model from bellow `Support`