|
|
@@ -19,18 +19,16 @@ from typing import (
|
|
|
Any,
|
|
|
Dict,
|
|
|
Final,
|
|
|
- Generator,
|
|
|
- List,
|
|
|
+ Iterator,
|
|
|
Optional,
|
|
|
- Protocol,
|
|
|
TypedDict,
|
|
|
Union,
|
|
|
)
|
|
|
|
|
|
-import ultrainfer as ui
|
|
|
-from ultrainfer.model import BaseUltraInferModel
|
|
|
-from paddlex.inference.components import ReadImage, ReadTS
|
|
|
-from paddlex.inference.models import BasePredictor
|
|
|
+import ultra_infer as ui
|
|
|
+from ultra_infer.model import BaseUltraInferModel
|
|
|
+from paddlex.inference.common.reader import ReadImage
|
|
|
+from paddlex.inference.models_new import BasePredictor
|
|
|
from paddlex.inference.utils.new_ir_blacklist import NEWIR_BLOCKLIST
|
|
|
from paddlex.utils import device as device_helper
|
|
|
from paddlex.utils import logging
|
|
|
@@ -38,7 +36,7 @@ from paddlex.utils.subclass_register import AutoRegisterABCMetaClass
|
|
|
from typing_extensions import assert_never
|
|
|
|
|
|
from paddlex_hpi._config import HPIConfig
|
|
|
-from paddlex_hpi._utils.typing import Backend, BatchData
|
|
|
+from paddlex_hpi._utils.typing import Backend
|
|
|
|
|
|
HPI_CONFIG_KEY: Final[str] = "Hpi"
|
|
|
|
|
|
@@ -64,6 +62,11 @@ class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
|
|
|
self._hpi_params = hpi_params or {}
|
|
|
self._hpi_config = self._get_hpi_config()
|
|
|
self._ui_model = self.build_ui_model()
|
|
|
+ self._data_reader = self._build_data_reader()
|
|
|
+
|
|
|
+ def __call__(self, input: Any, **kwargs: dict[str, Any]) -> Iterator[Any]:
|
|
|
+ self.set_predictor(**kwargs)
|
|
|
+ yield from self.apply(input)
|
|
|
|
|
|
@property
|
|
|
def model_path(self) -> Path:
|
|
|
@@ -79,6 +82,8 @@ class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
|
|
|
if device is not None:
|
|
|
if device != self._device:
|
|
|
raise RuntimeError("Currently, changing devices is not supported.")
|
|
|
+ if "batch_size" in kwargs:
|
|
|
+ self.batch_sampler.batch_size = kwargs.pop("batch_size")
|
|
|
if kwargs:
|
|
|
raise TypeError(f"Unexpected arguments: {kwargs}")
|
|
|
|
|
|
@@ -86,10 +91,6 @@ class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
|
|
|
option = self._create_ui_option()
|
|
|
return self._build_ui_model(option)
|
|
|
|
|
|
- @abc.abstractmethod
|
|
|
- def _build_ui_model(self, option: ui.RuntimeOption) -> BaseUltraInferModel:
|
|
|
- raise NotImplementedError
|
|
|
-
|
|
|
def _get_hpi_config(self) -> HPIConfig:
|
|
|
if HPI_CONFIG_KEY not in self.config:
|
|
|
logging.debug("Key %r not found in the config", HPI_CONFIG_KEY)
|
|
|
@@ -134,56 +135,20 @@ class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
|
|
|
backend_config.update_ui_option(option, self.model_dir)
|
|
|
return option
|
|
|
|
|
|
-
|
|
|
-class _DataReaderLike(Protocol):
|
|
|
- batch_size: int
|
|
|
-
|
|
|
- def __call__(self, input_list: Any) -> Generator[BatchData, None, None]: ...
|
|
|
-
|
|
|
-
|
|
|
-class HPPredictorWithDataReader(HPPredictor):
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- model_dir: Union[str, PathLike],
|
|
|
- config: Optional[Dict[str, Any]] = None,
|
|
|
- device: Optional[str] = None,
|
|
|
- hpi_params: Optional[HPIParams] = None,
|
|
|
- ) -> None:
|
|
|
- super().__init__(
|
|
|
- model_dir=model_dir,
|
|
|
- config=config,
|
|
|
- device=device,
|
|
|
- hpi_params=hpi_params,
|
|
|
- )
|
|
|
- self._batch_size = 1
|
|
|
- self._data_reader = self._build_data_reader()
|
|
|
-
|
|
|
- def set_predictor(self, **kwargs: Any) -> None:
|
|
|
- batch_size = kwargs.pop("batch_size", None)
|
|
|
- super().set_predictor(**kwargs)
|
|
|
- if batch_size is not None:
|
|
|
- self._batch_size = batch_size
|
|
|
- self._data_reader.batch_size = batch_size
|
|
|
- logging.info("Batch size updated to %d", self._batch_size)
|
|
|
-
|
|
|
- def apply(self, input: Any) -> Generator[BatchData, None, None]:
|
|
|
- for batch_data in self._data_reader(input):
|
|
|
- yield self._predict(batch_data)
|
|
|
-
|
|
|
@abc.abstractmethod
|
|
|
- def _build_data_reader(self) -> _DataReaderLike:
|
|
|
+ def _build_ui_model(self, option: ui.RuntimeOption) -> BaseUltraInferModel:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
@abc.abstractmethod
|
|
|
- def _predict(self, batch_data: BatchData) -> BatchData:
|
|
|
+ def _build_data_reader(self):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
-class CVPredictor(HPPredictorWithDataReader):
|
|
|
- def _build_data_reader(self) -> _DataReaderLike:
|
|
|
- return ReadImage(batch_size=self._batch_size, format="BGR")
|
|
|
+class CVPredictor(HPPredictor):
|
|
|
+ def _build_data_reader(self):
|
|
|
+ return ReadImage(format="BGR")
|
|
|
|
|
|
|
|
|
-class TSPredictor(HPPredictorWithDataReader):
|
|
|
- def _build_data_reader(self) -> _DataReaderLike:
|
|
|
- return ReadTS(batch_size=self._batch_size)
|
|
|
+class TSPredictor(HPPredictor):
|
|
|
+ def _build_data_reader(self):
|
|
|
+ return None
|