Bobholamovic, 1 year ago
parent
commit
4bfec6c0d4

+ 19 - 1
paddlex/inference/pipelines/base.py

@@ -13,7 +13,9 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
 
+from ..predictors import create_predictor
 from ...utils.subclass_register import AutoRegisterABCMetaClass
 
 
@@ -23,6 +25,8 @@ def create_pipeline(
     model_dir_list: list,
     output: str,
     device: str,
+    use_hpip: bool,
+    hpi_params: Optional[Dict[str, Any]] = None,
 ) -> "BasePipeline":
     """build model evaluater
 
@@ -32,7 +36,12 @@ def create_pipeline(
     Returns:
         BasePipeline: the pipeline, which is subclass of BasePipeline.
     """
-    pipeline = BasePipeline.get(pipeline_name)(output=output, device=device)
+    predictor_kwargs = {"use_hpip": use_hpip}
+    if hpi_params is not None:
+        predictor_kwargs["hpi_params"] = hpi_params
+    pipeline = BasePipeline.get(pipeline_name)(
+        output=output, device=device, predictor_kwargs=predictor_kwargs
+    )
     pipeline.update_model(model_list, model_dir_list)
     pipeline.load_model()
     return pipeline
@@ -43,6 +52,15 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     __is_base = True
 
+    def __init__(self, predictor_kwargs: Optional[Dict[str, Any]]) -> None:
+        super().__init__()
+        if predictor_kwargs is None:
+            predictor_kwargs = {}
+        self._predictor_kwargs = predictor_kwargs
+
     # alias the __call__() to predict()
     def __call__(self, *args, **kwargs):
         yield from self.predict(*args, **kwargs)
+
+    def _create_predictor(self, *args, **kwargs):
+        return create_predictor(*args, **kwargs, **self._predictor_kwargs)

+ 5 - 4
paddlex/inference/pipelines/general_recognition.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import create_predictor
 
 
 class ShiTuRecPipeline(BasePipeline):
@@ -21,9 +20,11 @@ class ShiTuRecPipeline(BasePipeline):
 
     entities = "general_recognition"
 
-    def __init__(self, model, batch_size=1, device="gpu"):
-        super().__init__()
-        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
+        super().__init__(predictor_kwargs)
+        self._predict = self._create_predictor(
+            model, batch_size=batch_size, device=device
+        )
 
     def predict(self, x):
         self._check_input(x)

+ 5 - 4
paddlex/inference/pipelines/image_classification.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import create_predictor
 
 
 class ClasPipeline(BasePipeline):
@@ -21,9 +20,11 @@ class ClasPipeline(BasePipeline):
 
     entities = "image_classification"
 
-    def __init__(self, model, batch_size=1, device="gpu"):
-        super().__init__()
-        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
+        super().__init__(predictor_kwargs)
+        self._predict = self._create_predictor(
+            model, batch_size=batch_size, device=device
+        )
 
     def predict(self, x):
         self._check_input(x)

+ 5 - 4
paddlex/inference/pipelines/instance_segmentation.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import create_predictor
 
 
 class InstanceSegPipeline(BasePipeline):
@@ -21,9 +20,11 @@ class InstanceSegPipeline(BasePipeline):
 
     entities = "instance_segmentation"
 
-    def __init__(self, model, batch_size=1, device="gpu"):
-        super().__init__()
-        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
+        super().__init__(predictor_kwargs)
+        self._predict = self._create_predictor(
+            model, batch_size=batch_size, device=device
+        )
 
     def predict(self, x):
         self._check_input(x)

+ 5 - 4
paddlex/inference/pipelines/object_detection.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import create_predictor
 
 
 class DetPipeline(BasePipeline):
@@ -21,9 +20,11 @@ class DetPipeline(BasePipeline):
 
     entities = "object_detection"
 
-    def __init__(self, model, batch_size=1, device="gpu"):
-        super().__init__()
-        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
+        super().__init__(predictor_kwargs)
+        self._predict = self._create_predictor(
+            model, batch_size=batch_size, device=device
+        )
 
     def predict(self, x):
         self._check_input(x)

+ 12 - 5
paddlex/inference/pipelines/ocr.py

@@ -13,8 +13,6 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import create_predictor
-from ...utils import logging
 from ..components import CropByPolys
 from ..results import OCRResult
 
@@ -24,9 +22,18 @@ class OCRPipeline(BasePipeline):
 
     entities = "ocr"
 
-    def __init__(self, det_model, rec_model, det_batch_size, rec_batch_size, **kwargs):
-        self._det_predict = create_predictor(det_model, batch_size=det_batch_size)
-        self._rec_predict = create_predictor(rec_model, batch_size=rec_batch_size)
+    def __init__(
+        self,
+        det_model,
+        rec_model,
+        det_batch_size,
+        rec_batch_size,
+        predictor_kwargs=None,
+        **kwargs
+    ):
+        super().__init__(predictor_kwargs)
+        self._det_predict = self._create_predictor(det_model, batch_size=det_batch_size)
+        self._rec_predict = self._create_predictor(rec_model, batch_size=rec_batch_size)
         # TODO: foo
         self._crop_by_polys = CropByPolys(det_box_type="foo")
 

+ 5 - 4
paddlex/inference/pipelines/semantic_segmentation.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import create_predictor
 
 
 class SegPipeline(BasePipeline):
@@ -21,9 +20,11 @@ class SegPipeline(BasePipeline):
 
     entities = "semantic_segmentation"
 
-    def __init__(self, model, batch_size=1, device="gpu"):
-        super().__init__()
-        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+    def __init__(self, model, batch_size=1, device="gpu", predictor_kwargs=None):
+        super().__init__(predictor_kwargs)
+        self._predict = self._create_predictor(
+            model, batch_size=batch_size, device=device
+        )
 
     def predict(self, x):
         self._check_input(x)

+ 11 - 5
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -14,7 +14,6 @@
 
 import numpy as np
 from ..base import BasePipeline
-from ...predictors import create_predictor
 from ..ocr import OCRPipeline
 from ...components import CropByBoxes
 from ...results import OCRResult, TableResult, StructureTableResult
@@ -24,6 +23,8 @@ from .utils import *
 class TableRecPipeline(BasePipeline):
     """Table Recognition Pipeline"""
 
+    entities = "table_recognition"
+
     def __init__(
         self,
         layout_model,
@@ -33,21 +34,26 @@ class TableRecPipeline(BasePipeline):
         batch_size=1,
         device="gpu",
         chat_ocr=False,
+        predictor_kwargs=None,
     ):
+        super().__init__(predictor_kwargs)
 
-        self.layout_predictor = create_predictor(
+        self.layout_predictor = self._create_predictor(
             model=layout_model, device=device, batch_size=batch_size
         )
         self.ocr_pipeline = OCRPipeline(
-            text_det_model, text_rec_model, batch_size, device
+            text_det_model,
+            text_rec_model,
+            batch_size,
+            device,
+            predictor_kwargs=predictor_kwargs,
         )
-        self.table_predictor = create_predictor(
+        self.table_predictor = self._create_predictor(
             model=table_model, device=device, batch_size=batch_size
         )
         self._crop_by_boxes = CropByBoxes()
         self._match = TableMatch(filter_ocr_result=False)
         self.chat_ocr = chat_ocr
-        super().__init__()
 
     def predict(self, x):
         batch_structure_res = []

+ 44 - 5
paddlex/inference/predictors/__init__.py

@@ -29,19 +29,58 @@ from .ts_fc import TSFcPredictor
 from .ts_cls import TSClsPredictor
 
 
-def create_predictor(model: str, device: str = None, *args, **kwargs) -> BasePredictor:
-    model_dir = check_model(model)
-    config = BasePredictor.load_config(model_dir)
-    model_name = config["Global"]["model_name"]
-    return BasicPredictor.get(model_name)(
+def _create_hp_predictor(
+    model_name, model_dir, device, config, hpi_params, *args, **kwargs
+):
+    try:
+        from paddlex_hpi.predictors import HPPredictor
+    except ModuleNotFoundError as e:
+        raise RuntimeError(
+            "The PaddleX HPI plugin is not properly installed, and the high-performance model inference features are not available."
+        )
+    if hpi_params is None:
+        raise ValueError("No HPI params given")
+    if "serial_number" not in hpi_params:
+        raise ValueError("The serial number is required but was not provided.")
+    serial_number = hpi_params["serial_number"]
+    update_license = hpi_params.get("update_license", False)
+    return HPPredictor.get(model_name)(
         model_dir=model_dir,
         config=config,
         device=device,
+        serial_number=serial_number,
+        update_license=update_license,
         *args,
         **kwargs,
     )
 
 
+def create_predictor(
+    model: str, device: str = None, *args, use_hpip=False, hpi_params=None, **kwargs
+) -> BasePredictor:
+    model_dir = check_model(model)
+    config = BasePredictor.load_config(model_dir)
+    model_name = config["Global"]["model_name"]
+    if use_hpip:
+        return _create_hp_predictor(
+            model_name=model_name,
+            model_dir=model_dir,
+            device=device,
+            config=config,
+            hpi_params=hpi_params,
+            *args,
+            **kwargs,
+        )
+    else:
+        return BasicPredictor.get(model_name)(
+            model_dir=model_dir,
+            config=config,
+            device=device,
+            *args,
+            **kwargs,
+        )
+
+
 def check_model(model):
     if Path(model).exists():
         return Path(model)

+ 16 - 8
paddlex/utils/lazy_loader.py

@@ -23,21 +23,29 @@ class LazyLoader(types.ModuleType):
     def __init__(self, local_name, parent_module_globals, name):
         self._local_name = local_name
         self._parent_module_globals = parent_module_globals
+        self._module = None
 
         super(LazyLoader, self).__init__(name)
 
+    @property
+    def loaded(self):
+        return self._module is not None
+
     def _load(self):
         module = importlib.import_module(self.__name__)
         self._parent_module_globals[self._local_name] = module
-
-        self.__dict__.update(module.__dict__)
-
-        return module
+        self._module = module
 
     def __getattr__(self, item):
-        module = self._load()
-        return getattr(module, item)
+        if not self.loaded:
+            # HACK: For circumventing shared library symbol conflicts when
+            # importing paddlex_hpi
+            if item in ("__file__",):
+                raise AttributeError
+            self._load()
+        return getattr(self._module, item)
 
     def __dir__(self):
-        module = self._load()
-        return dir(module)
+        if not self.loaded:
+            self._load()
+        return dir(self._module)