gaotingquan 1 рік тому
батько
коміт
2565a6abd7

+ 18 - 4
paddlex/inference/components/paddle_predictor/predictor.py

@@ -20,11 +20,10 @@ import numpy as np
 from ....utils.flags import FLAGS_json_format_model
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
-from ..utils.mixin import PPEngineMixin
 from ..base import BaseComponent
 
 
-class BasePaddlePredictor(BaseComponent, PPEngineMixin):
+class BasePaddlePredictor(BaseComponent):
     """Predictor based on Paddle Inference"""
 
     OUTPUT_KEYS = "pred"
@@ -33,10 +32,25 @@ class BasePaddlePredictor(BaseComponent, PPEngineMixin):
 
     def __init__(self, model_dir, model_prefix, option):
         super().__init__()
-        PPEngineMixin.__init__(self, option)
         self.model_dir = model_dir
         self.model_prefix = model_prefix
-        self.reset()
+        self._update_option(option)
+
+    def _update_option(self, option):
+        if option:
+            if self.option and option == self.option:
+                return
+            self._option = option
+            self._option.attach(self)
+            self.reset()
+
+    @property
+    def option(self):
+        return self._option if hasattr(self, "_option") else None
+
+    @option.setter
+    def option(self, option):
+        self._update_option(option)
 
     def reset(self):
         if not self.option:

+ 0 - 1
paddlex/inference/components/transforms/image/common.py

@@ -21,7 +21,6 @@ import cv2
 
 from .....utils.cache import CACHE_DIR, temp_file_manager
 from ....utils.io import ImageReader, ImageWriter, PDFReader
-from ...utils.mixin import BatchSizeMixin
 from ...base import BaseComponent
 from ..read_data import _BaseRead
 from . import funcs as F

+ 13 - 3
paddlex/inference/components/transforms/read_data.py

@@ -17,18 +17,28 @@ from pathlib import Path
 
 from ....utils.download import download
 from ....utils.cache import CACHE_DIR
-from ..utils.mixin import BatchSizeMixin
 from ..base import BaseComponent
 
 
-class _BaseRead(BaseComponent, BatchSizeMixin):
+class _BaseRead(BaseComponent):
     """Load image from the file."""
 
+    NAME = "ReadCmp"
     SUFFIX = []
 
     def __init__(self, batch_size=1):
         super().__init__()
-        BatchSizeMixin.__init__(self, batch_size)
+        self._batch_size = batch_size
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value):
+        if value <= 0:
+            raise ValueError("Batch size must be positive.")
+        self._batch_size = value
 
     # XXX: auto download for url
     def _download_from_url(self, in_path):

+ 0 - 13
paddlex/inference/components/utils/__init__.py

@@ -1,13 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.

+ 0 - 53
paddlex/inference/components/utils/mixin.py

@@ -1,53 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from abc import abstractmethod
-
-
-class BatchSizeMixin:
-    NAME = "ReadCmp"
-
-    def __init__(self, batch_size=1):
-        self._batch_size = batch_size
-
-    @property
-    def batch_size(self):
-        return self._batch_size
-
-    @batch_size.setter
-    def batch_size(self, value):
-        if value <= 0:
-            raise ValueError("Batch size must be positive.")
-        self._batch_size = value
-
-
-class PPEngineMixin:
-    NAME = "PPEngineCmp"
-
-    def __init__(self, option=None):
-        self._option = option
-
-    @property
-    def option(self):
-        return self._option
-
-    @option.setter
-    def option(self, value):
-        if value is not None and value != self.option:
-            self._option = value
-            self._reset()
-
-    @abstractmethod
-    def _reset(self):
-        raise NotImplementedError

+ 29 - 10
paddlex/inference/utils/pp_option.py

@@ -35,6 +35,7 @@ class PaddlePredictorOption(object):
         self.model_name = model_name
         self._cfg = {}
         self._init_option(**kwargs)
+        self._observers = []
 
     def _init_option(self, **kwargs):
         for k, v in kwargs.items():
@@ -63,6 +64,10 @@ class PaddlePredictorOption(object):
             "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
         }
 
+    def _update(self, k, v):
+        self._cfg[k] = v
+        self.notify()
+
     @property
     def run_mode(self):
         return self._cfg["run_mode"]
@@ -75,7 +80,7 @@ class PaddlePredictorOption(object):
             raise ValueError(
                 f"`run_mode` must be {support_run_mode_str}, but received {repr(run_mode)}."
             )
-        self._cfg["run_mode"] = run_mode
+        self._update("run_mode", run_mode)
 
     @property
     def device_type(self):
@@ -100,9 +105,9 @@ class PaddlePredictorOption(object):
             raise ValueError(
                 f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
             )
-        self._cfg["device"] = device_type
+        self._update("device", device_type)
         device_id = device_ids[0] if device_ids is not None else 0
-        self._cfg["device_id"] = device_id
+        self._update("device_id", device_id)
         set_env_for_device(device)
         if device_type not in ("cpu"):
             if device_ids is None or len(device_ids) > 1:
@@ -117,7 +122,7 @@ class PaddlePredictorOption(object):
         """set min subgraph size"""
         if not isinstance(min_subgraph_size, int):
             raise Exception()
-        self._cfg["min_subgraph_size"] = min_subgraph_size
+        self._update("min_subgraph_size", min_subgraph_size)
 
     @property
     def shape_info_filename(self):
@@ -126,7 +131,7 @@ class PaddlePredictorOption(object):
     @shape_info_filename.setter
     def shape_info_filename(self, shape_info_filename: str):
         """set shape info filename"""
-        self._cfg["shape_info_filename"] = shape_info_filename
+        self._update("shape_info_filename", shape_info_filename)
 
     @property
     def trt_calib_mode(self):
@@ -135,7 +140,7 @@ class PaddlePredictorOption(object):
     @trt_calib_mode.setter
     def trt_calib_mode(self, trt_calib_mode):
         """set trt calib mode"""
-        self._cfg["trt_calib_mode"] = trt_calib_mode
+        self._update("trt_calib_mode", trt_calib_mode)
 
     @property
     def cpu_threads(self):
@@ -146,7 +151,7 @@ class PaddlePredictorOption(object):
         """set cpu threads"""
         if not isinstance(cpu_threads, int) or cpu_threads < 1:
             raise Exception()
-        self._cfg["cpu_threads"] = cpu_threads
+        self._update("cpu_threads", cpu_threads)
 
     @property
     def trt_use_static(self):
@@ -155,7 +160,7 @@ class PaddlePredictorOption(object):
     @trt_use_static.setter
     def trt_use_static(self, trt_use_static):
         """set trt use static"""
-        self._cfg["trt_use_static"] = trt_use_static
+        self._update("trt_use_static", trt_use_static)
 
     @property
     def delete_pass(self):
@@ -163,7 +168,7 @@ class PaddlePredictorOption(object):
 
     @delete_pass.setter
     def delete_pass(self, delete_pass):
-        self._cfg["delete_pass"] = delete_pass
+        self._update("delete_pass", delete_pass)
 
     @property
     def enable_new_ir(self):
@@ -172,7 +177,7 @@ class PaddlePredictorOption(object):
     @enable_new_ir.setter
     def enable_new_ir(self, enable_new_ir: bool):
         """set run mode"""
-        self._cfg["enable_new_ir"] = enable_new_ir
+        self._update("enable_new_ir", enable_new_ir)
 
     def get_support_run_mode(self):
         """get supported run mode"""
@@ -205,3 +210,17 @@ class PaddlePredictorOption(object):
             for name, prop in vars(self.__class__).items()
             if isinstance(prop, property) and prop.fset is not None
         ]
+
+    def attach(self, observer):
+        if observer not in self._observers:
+            self._observers.append(observer)
+
+    def detach(self, observer):
+        try:
+            self._observers.remove(observer)
+        except ValueError:
+            pass
+
+    def notify(self):
+        for observer in self._observers:
+            observer.reset()