Преглед изворни кода

update_uvdoc_warp (#2654)

* update_uvdoc_warp

* update_docstring

* update_docstring
Sunflower7788 пре 11 месеци
родитељ
комит
36c0a298a2

+ 2 - 1
paddlex/inference/models_new/__init__.py

@@ -34,7 +34,8 @@ from .semantic_segmentation import SegPredictor
 # from .ts_fc import TSFcPredictor
 # from .ts_ad import TSAdPredictor
 # from .ts_cls import TSClsPredictor
-# from .image_unwarping import WarpPredictor
+from .image_unwarping import WarpPredictor
+
 # from .multilabel_classification import MLClasPredictor
 # from .anomaly_detection import UadPredictor
 # from .formula_recognition import LaTeXOCRPredictor

+ 15 - 0
paddlex/inference/models_new/image_unwarping/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import WarpPredictor

+ 104 - 0
paddlex/inference/models_new/image_unwarping/predictor.py

@@ -0,0 +1,104 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+
+from ....modules.image_unwarping.model_list import MODELS
+from ...common.batch_sampler import ImageBatchSampler
+from ...common.reader import ReadImage
+from ..common import (
+    Normalize,
+    ToCHWImage,
+    ToBatch,
+    StaticInfer,
+)
+from ..base import BasicPredictor
+from .processors import DocTrPostProcess
+from .result import DocTrResult
+
+
+class WarpPredictor(BasicPredictor):
+    """WarpPredictor that inherits from BasicPredictor."""
+
+    entities = MODELS
+
+    def __init__(self, *args: List, **kwargs: Dict) -> None:
+        """Initializes WarpPredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> ImageBatchSampler:
+        """Builds and returns an ImageBatchSampler instance.
+
+        Returns:
+            ImageBatchSampler: An instance of ImageBatchSampler.
+        """
+        return ImageBatchSampler()
+
+    def _get_result_class(self) -> type:
+        """Returns the warpping result, DocTrResult.
+
+        Returns:
+            type: The DocTrResult.
+        """
+        return DocTrResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+        """
+        preprocessors = {"Read": ReadImage(format="RGB")}
+        preprocessors["Normalize"] = Normalize(mean=0.0, std=1.0, scale=1.0 / 255)
+        preprocessors["ToCHW"] = ToCHWImage()
+        preprocessors["ToBatch"] = ToBatch()
+
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        postprocessors = {"DocTrPostProcess": DocTrPostProcess()}
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+
+        Returns:
+            dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
+        """
+        batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
+        batch_imgs = self.preprocessors["Normalize"](imgs=batch_raw_imgs)
+        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
+        x = self.preprocessors["ToBatch"](imgs=batch_imgs)
+        batch_preds = self.infer(x=x)
+        batch_warp_preds = self.postprocessors["DocTrPostProcess"](batch_preds)
+
+        return {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "doctr_img": batch_warp_preds,
+        }

+ 88 - 0
paddlex/inference/models_new/image_unwarping/processors.py

@@ -0,0 +1,88 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from typing import List, Union, Tuple
+
+
+class DocTrPostProcess:
+    """
+    Post-processing class for cropping regions from images (though currently only performs scaling and color channel adjustments).
+
+    Attributes:
+        scale (np.float32): A scaling factor to be applied to the image pixel values.
+            Defaults to 255.0 if not provided.
+
+    Methods:
+        __call__(imgs: List[Union[np.ndarray, Tuple[np.ndarray, ...]]]) -> List[np.ndarray]:
+            Call method to process a list of images.
+        doctr(pred: Union[np.ndarray, Tuple[np.ndarray, ...]]) -> np.ndarray:
+            Method to process a single image or a tuple/list containing an image.
+    """
+
+    def __init__(self, scale: Union[str, float, None] = None):
+        """
+        Initializes the DocTrPostProcess class with a scaling factor.
+
+        Args:
+            scale (Union[str, float, None]): A scaling factor for the image pixel values.
+                If a string is provided, it will be converted to a float. Defaults to 255.0.
+        """
+        super().__init__()
+        self.scale = (
+            np.float32(scale) if isinstance(scale, (str, float)) else np.float32(255.0)
+        )
+
+    def __call__(
+        self, imgs: List[Union[np.ndarray, Tuple[np.ndarray, ...]]]
+    ) -> List[np.ndarray]:
+        """
+        Processes a list of images using the `doctr` method.
+
+        Args:
+            imgs (List[Union[np.ndarray, Tuple[np.ndarray, ...]]]): A list of images to process.
+                Each image can be a numpy array or a tuple containing a numpy array.
+
+        Returns:
+            List[np.ndarray]: A list of processed images.
+        """
+        return [self.doctr(img) for img in imgs]
+
+    def doctr(self, pred: Union[np.ndarray, Tuple[np.ndarray, ...]]) -> np.ndarray:
+        """
+        Processes a single image.
+
+        Args:
+            pred (Union[np.ndarray, Tuple[np.ndarray, ...]]): The image to process, which can be
+                a numpy array or a tuple containing a numpy array. Only the first element is used if it's a tuple.
+
+        Returns:
+            np.ndarray: The processed image.
+
+        Raises:
+            AssertionError: If the input is not a numpy array.
+        """
+        if isinstance(pred, tuple):
+            im = pred[0]
+        else:
+            im = pred
+        assert isinstance(
+            im, np.ndarray
+        ), "Invalid input 'im' in DocTrPostProcess. Expected a numpy array."
+        im = im.squeeze()
+        im = im.transpose(1, 2, 0)
+        im *= self.scale
+        im = im[:, :, ::-1]
+        im = im.astype("uint8", copy=False)
+        return im

+ 39 - 0
paddlex/inference/models_new/image_unwarping/result.py

@@ -0,0 +1,39 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import numpy as np
+from ...common.result import BaseCVResult
+
+
+class DocTrResult(BaseCVResult):
+    """
+    Result class for DocTr, encapsulating the output of a document image processing task.
+
+    Attributes:
+        (inherited from BaseCVResult): Any attributes defined in the base class.
+
+    Methods:
+        _to_img(self) -> np.ndarray:
+            Converts the stored image result to a numpy array.
+    """
+
+    def _to_img(self) -> np.ndarray:
+        result = np.array(self["doctr_img"])
+        return result
+
+    def _to_str(self, _, *args, **kwargs):
+        data = copy.deepcopy(self)
+        data["doctr_img"] = "..."
+        return super()._to_str(data, *args, **kwargs)