Ver Fonte

add uvdoc (#1991)

* add uvdoc

* add uvdoc
Sunflower7788 há 1 ano atrás
pai
commit
03dbc08e20

+ 12 - 0
paddlex/configs/image_unwarping/UVDoc.yaml

@@ -0,0 +1,12 @@
+Global:
+  model: UVDoc
+  mode: predict # check_dataset/train/evaluate/predict
+  device: gpu:0
+  output: "output"
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/doc_test.jpg"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 2 - 0
paddlex/inference/predictors/official_models.py

@@ -169,6 +169,8 @@ openatom_rec_svtrv2_ch_infer.tar",
     "PicoDet_layout_1x": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet-L_layout_infer.tar",
     "SLANet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SLANet_infer.tar",
     "LaTeX_OCR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/LaTeX_OCR_rec_infer.tar",
+    "UVDoc": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/UVDoc_infer.tar",
+
 }
 
 

+ 5 - 7
paddlex/modules/__init__.py

@@ -85,13 +85,10 @@ from .ts_classification import (
     TSCLSExportor,
     TSCLSPredictor,
 )
-from .ts_forecast import (
-    TSFCDatasetChecker,
-    TSFCTrainer,
-    TSFCEvaluator,
-    TSFCExportor,
-    TSFCPredictor,
-)
+
+from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator, TSFCPredictor
+from .image_unwarping import WarpPredictor
+
 
 from .base.predictor.transforms import image_common
 from .image_classification import transforms as cls_transforms
@@ -101,3 +98,4 @@ from .text_recognition import transforms as text_rec_transforms
 from .table_recognition import transforms as table_rec_transforms
 from .semantic_segmentation import transforms as seg_transforms
 from .instance_segmentation import transforms as instance_seg_transforms
+from .image_unwarping import transforms as image_unwarping_transforms

+ 1 - 0
paddlex/modules/base/predictor/utils/official_models.py

@@ -168,6 +168,7 @@ openatom_rec_svtrv2_ch_infer.tar",
     "PicoDet_layout_1x": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet-L_layout_infer.tar",
     "SLANet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SLANet_infer.tar",
     "LaTeX_OCR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/LaTeX_OCR_rec_infer.tar",
+    "UVDoc": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/UVDoc_infer.tar",
     "DLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/DLinear_infer.tar",
     "NLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/NLinear_infer.tar",
     "RLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/RLinear_infer.tar",

+ 15 - 0
paddlex/modules/image_unwarping/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import WarpPredictor, transforms

+ 18 - 0
paddlex/modules/image_unwarping/model_list.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Model names handled by this module; WarpPredictor uses this list as its
+# `entities` attribute.
+MODELS = [
+    "UVDoc",
+]
+

+ 16 - 0
paddlex/modules/image_unwarping/predictor/__init__.py

@@ -0,0 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import WarpPredictor
+from . import transforms

+ 28 - 0
paddlex/modules/image_unwarping/predictor/keys.py

@@ -0,0 +1,28 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class WarpKeys(object):
+    """
+    This class defines the keys used for communication between the image
+    unwarping predictor and its transforms. Both predictors and transforms
+    accept a dict or a list of dicts as input, and they get the objects of
+    their interest from the dict, or put the generated objects into the dict,
+    all based on these keys.
+    """
+
+    # Common keys
+    # Input image array (stacked into the model batch by the predictor).
+    IMAGE = "image"
+    # Path of the input image; its base name is reused for the saved result.
+    IM_PATH = "input_path"
+    # Suite-specific keys
+    # Model output for one sample; post-processed in place and then saved.
+    DOCTR_IMG = "doc_img"

+ 74 - 0
paddlex/modules/image_unwarping/predictor/predictor.py

@@ -0,0 +1,74 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import numpy as np
+from pathlib import Path
+
+from ...base import BasePredictor
+from ...base.predictor.transforms import image_common
+from .keys import WarpKeys as K
+from . import transforms as T
+from ..model_list import MODELS
+
+
+class WarpPredictor(BasePredictor):
+    """Predictor for the image unwarping models listed in ``MODELS``."""
+
+    entities = MODELS
+
+    @classmethod
+    def get_input_keys(cls):
+        """Return the input key groups: ``[K.IMAGE]`` and ``[K.IM_PATH]``."""
+        image_group = [K.IMAGE]
+        path_group = [K.IM_PATH]
+        return [image_group, path_group]
+
+    @classmethod
+    def get_output_keys(cls):
+        """Return the keys this predictor adds to every sample."""
+        return [K.DOCTR_IMG]
+
+    def _run(self, batch_input):
+        """Run the model on a batch and attach one output to each sample.
+
+        Args:
+            batch_input: list of dicts, each holding an array under
+                ``K.IMAGE``.
+
+        Returns:
+            The same ``batch_input`` list; each dict is updated in place with
+            its model output stored under ``K.DOCTR_IMG``.
+        """
+        images = [sample[K.IMAGE] for sample in batch_input]
+        batch = np.stack(images, axis=0).astype(dtype=np.float32, copy=False)
+        # Only the first output of the underlying predictor is used.
+        warp_outs = self._predictor.predict([batch])[0]
+        for sample, warp_out in zip(batch_input, warp_outs):
+            sample[K.DOCTR_IMG] = warp_out
+        return batch_input
+
+    def _get_pre_transforms_from_config(self):
+        """Build the preprocessing list: ReadImage, Normalize, ToCHWImage."""
+        return [
+            image_common.ReadImage(format='RGB'),
+            image_common.Normalize(scale=1./255, mean=0.0, std=1.0),
+            image_common.ToCHWImage(),
+        ]
+
+    def _get_post_transforms_from_config(self):
+        """Build the postprocessing list: rescale the output, then save it."""
+        rescale = T.DocTrPostProcess(scale=255.)
+        saver = T.SaveDocTrResults(self.output)
+        return [rescale, saver]

+ 89 - 0
paddlex/modules/image_unwarping/predictor/transforms.py

@@ -0,0 +1,89 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pathlib import Path
+import numpy as np
+
+from .keys import WarpKeys as K
+from ...base import BaseTransform
+from ...base.predictor.io import ImageWriter, ImageReader
+from ....utils import logging
+
+
+# Public API of this module (the original spelling `_all__` was a typo and
+# left `__all__` undefined).
+__all__ = ['DocTrPostProcess', 'SaveDocTrResults']
+
+
+class DocTrPostProcess(BaseTransform):
+    """Turn the raw model output into a ``uint8`` image.
+
+    The array stored under ``K.DOCTR_IMG`` is squeezed, transposed with
+    ``(1, 2, 0)``, multiplied by ``scale``, reversed along its last axis and
+    cast to ``uint8``; the result replaces the original entry.
+    """
+
+    def __init__(self, scale=None, **kwargs):
+        """
+        Args:
+            scale: multiplier applied to the model output. Either a number, a
+                string containing a numeric expression, or ``None`` to use
+                255.0.
+            **kwargs: accepted and ignored.
+        """
+        # Initialize the base transform, as SaveDocTrResults does.
+        super().__init__()
+        if isinstance(scale, str):
+            # NOTE(review): eval() executes arbitrary code; only pass strings
+            # that come from trusted configuration.
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 255.0)
+
+    def apply(self, data):
+        """Post-process ``data[K.DOCTR_IMG]`` and return the updated ``data``."""
+        im = data[K.DOCTR_IMG]
+        assert isinstance(im,
+                          np.ndarray), "invalid input 'im' in DocTrPostProcess"
+
+        im = im.squeeze()
+        im = im.transpose(1, 2, 0)
+        # In-place multiply: this also modifies the array the predictor
+        # stored, since squeeze/transpose return views.
+        im *= self.scale
+        # Reverse the last axis; presumably RGB -> BGR for the OpenCV-backed
+        # writer -- TODO confirm.
+        im = im[:, :, ::-1]
+        # NOTE(review): values are cast without clipping; assumes the scaled
+        # output already lies in [0, 255] -- TODO confirm.
+        im = im.astype("uint8", copy=False)
+        data[K.DOCTR_IMG] = im
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """Return the keys ``apply`` reads."""
+        return [K.DOCTR_IMG]
+
+    @classmethod
+    def get_output_keys(cls):
+        """Return the keys ``apply`` writes."""
+        return [K.DOCTR_IMG]
+
+
+class SaveDocTrResults(BaseTransform):
+    """Write the image stored under ``K.DOCTR_IMG`` into ``save_dir`` as PNG."""
+
+    _FILE_EXT = '.png'
+
+    def __init__(self, save_dir, file_name=None):
+        """
+        Args:
+            save_dir: directory the output images are written into.
+            file_name: currently unused; the output name is derived from the
+                input path of each sample.
+        """
+        super().__init__()
+        self.save_dir = save_dir
+        self._writer = ImageWriter(backend='opencv')
+
+    @staticmethod
+    def _replace_ext(path, new_ext):
+        """Return ``path`` with its extension replaced by ``new_ext``."""
+        stem, _ = os.path.splitext(path)
+        return stem + new_ext
+
+    def apply(self, data):
+        """Save the sample's image under the input's base name; return ``data``."""
+        ori_path = data[K.IM_PATH]
+        file_name = os.path.basename(ori_path)
+        file_name = self._replace_ext(file_name, self._FILE_EXT)
+        save_path = os.path.join(self.save_dir, file_name)
+        doctr_img = data[K.DOCTR_IMG]
+        self._writer.write(save_path, doctr_img)
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """Return the keys ``apply`` reads."""
+        # apply() needs the input path as well as the image, so declare both.
+        return [K.IM_PATH, K.DOCTR_IMG]
+
+    @classmethod
+    def get_output_keys(cls):
+        """Return the keys ``apply`` writes (none)."""
+        return []