Selaa lähdekoodia

support textline orientation & doc preprocessor for OCR pipeline (#2775)

* support textline orientation & doc preprocessor for OCR pipeline

* support to save doc preprocess output for OCR pipeline

* update
cuicheng01 10 kuukautta sitten
vanhempi
commit
4de304004f

+ 49 - 2
api_examples/pipelines/test_ocr.py

@@ -16,9 +16,56 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="OCR")
 
-output = pipeline.predict("./test_samples/general_ocr_002.png")
-
+output = pipeline.predict(
+    "./test_samples/general_ocr_002.png",
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=True,
+    use_textline_orientation=True,
+)
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True,
+#     use_textline_orientation=False,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=True,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=False,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=True,
+#     use_textline_orientation=True,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=True,
+#     use_textline_orientation=False,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=True,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=False,
+# )
 # output = pipeline.predict("./test_samples/财报1.pdf")
 for res in output:
     print(res)
     res.save_to_img("./output")
+    res.save_to_json("./output/res.json")

+ 26 - 1
paddlex/configs/pipelines/OCR.yaml

@@ -3,14 +3,39 @@ pipeline_name: OCR
 
 text_type: general
 
+use_doc_preprocessor: True
+use_textline_orientation: True
+
+SubPipelines:
+  DocPreprocessor:
+    pipeline_name: doc_preprocessor
+    use_doc_orientation_classify: True
+    use_doc_unwarping: True
+    SubModules:
+      DocOrientationClassify:
+        module_name: doc_text_orientation
+        model_name: PP-LCNet_x1_0_doc_ori
+        model_dir: null
+        batch_size: 1
+      DocUnwarping:
+        module_name: image_unwarping
+        model_name: UVDoc
+        model_dir: null
+        batch_size: 1
+
 SubModules:
   TextDetection:
     module_name: text_detection
     model_name: PP-OCRv4_mobile_det
     model_dir: null
+    batch_size: 1
+  TextLineOrientation:
+    module_name: textline_orientation
+    model_name: PP-LCNet_x0_25_textline_ori 
+    model_dir: null
     batch_size: 1    
   TextRecognition:
     module_name: text_recognition
-    model_name: PP-OCRv4_mobile_rec
+    model_name: PP-OCRv4_mobile_rec 
     model_dir: null
     batch_size: 1

+ 198 - 13
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 import numpy as np
+from scipy.ndimage import rotate
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 from ..components import CropByPolys, SortQuadBoxes, SortPolyBoxes
 from .result import OCRResult
+from ..doc_preprocessor.result import DocPreprocessorResult
+from ....utils import logging
 
 
 class OCRPipeline(BasePipeline):
@@ -49,11 +52,7 @@ class OCRPipeline(BasePipeline):
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
-        text_det_model_config = config["SubModules"]["TextDetection"]
-        self.text_det_model = self.create_model(text_det_model_config)
-
-        text_rec_model_config = config["SubModules"]["TextRecognition"]
-        self.text_rec_model = self.create_model(text_rec_model_config)
+        self.inintial_predictor(config)
 
         self.text_type = config["text_type"]
 
@@ -69,8 +68,162 @@ class OCRPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
+    def set_used_models_flag(self, config: Dict) -> None:
+        """
+        Set the flags for which models to use based on the configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+
+        Returns:
+            None
+        """
+        pipeline_name = config["pipeline_name"]
+
+        self.pipeline_name = pipeline_name
+
+        self.use_doc_preprocessor = False
+
+        if "use_doc_preprocessor" in config:
+            self.use_doc_preprocessor = config["use_doc_preprocessor"]
+
+        self.use_textline_orientation = False
+
+        if "use_textline_orientation" in config:
+            self.use_textline_orientation = config["use_textline_orientation"]
+
+    def inintial_predictor(self, config: Dict) -> None:
+        """Initializes the predictor based on the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing the configuration for the predictor.
+
+        Returns:
+            None
+        """
+
+        self.set_used_models_flag(config)
+
+        text_det_model_config = config["SubModules"]["TextDetection"]
+        self.text_det_model = self.create_model(text_det_model_config)
+
+        text_rec_model_config = config["SubModules"]["TextRecognition"]
+        self.text_rec_model = self.create_model(text_rec_model_config)
+
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+        # Just for initialize the predictor
+        if self.use_textline_orientation:
+            textline_orientation_config = config["SubModules"]["TextLineOrientation"]
+            self.textline_orientation_model = self.create_model(
+                textline_orientation_config
+            )
+        return
+
+    def rotate_image(
+        self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
+    ) -> List[np.ndarray]:
+        """
+        Rotate the given image arrays by their corresponding angles.
+        0 corresponds to 0 degrees, 1 corresponds to 180 degrees.
+
+        Args:
+            image_array_list (List[np.ndarray]): A list of input image arrays to be rotated.
+            rotate_angle_list (List[int]): A list of rotation indicators (0 or 1).
+                                        0 means rotate by 0 degrees
+                                        1 means rotate by 180 degrees
+
+        Returns:
+            List[np.ndarray]: A list of rotated image arrays.
+
+        Raises:
+            AssertionError: If any rotate_angle is not 0 or 1.
+            AssertionError: If the lengths of input lists don't match.
+        """
+        assert len(image_array_list) == len(
+            rotate_angle_list
+        ), f"Length of image_array_list ({len(image_array_list)}) must match length of rotate_angle_list ({len(rotate_angle_list)})"
+
+        for angle in rotate_angle_list:
+            assert angle in [0, 1], f"rotate_angle must be 0 or 1, now it's {angle}"
+
+        rotated_images = []
+        for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
+            # Convert 0/1 indicator to actual rotation angle
+            rotate_angle = rotate_indicator * 180
+            rotated_image = rotate(image_array, rotate_angle, reshape=True)
+            rotated_images.append(rotated_image)
+
+        return rotated_images
+
+    def check_input_params_valid(self, input_params: Dict) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            input_params (Dict): A dictionary containing input parameters.
+
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
+
+        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+            logging.error(
+                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
+            )
+            return False
+
+        if (
+            input_params["use_textline_orientation"]
+            and not self.use_textline_orientation
+        ):
+            logging.error(
+                "Set use_textline_orientation, but the models for use_textline_orientation are not initialized."
+            )
+            return False
+
+        return True
+
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+        """
+        Preprocess the document image based on input parameters.
+
+        Args:
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
+
+        Returns:
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+                                              result dictionary and the processed image array.
+        """
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
     def predict(
-        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        use_textline_orientation: bool = False,
+        **kwargs,
     ) -> OCRResult:
         """Predicts OCR results for the given input.
 
@@ -82,9 +235,29 @@ class OCRPipeline(BasePipeline):
             OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
         """
 
+        input_params = {
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+            "use_textline_orientation": self.use_textline_orientation,
+        }
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params["use_doc_preprocessor"] = True
+        else:
+            input_params["use_doc_preprocessor"] = False
+
+        if not self.check_input_params_valid(input_params):
+            yield None
+
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
-            raw_img = self.img_reader(batch_data)[0]
-            det_res = next(self.text_det_model(raw_img))
+            image_array = self.img_reader(batch_data)[0]
+            img_id += 1
+
+            doc_preprocessor_res, doc_preprocessor_image = (
+                self.predict_doc_preprocessor_res(image_array, input_params)
+            )
+
+            det_res = next(self.text_det_model(doc_preprocessor_image))
 
             dt_polys = det_res["dt_polys"]
             dt_scores = det_res["dt_scores"]
@@ -93,19 +266,31 @@ class OCRPipeline(BasePipeline):
 
             dt_polys = self._sort_boxes(dt_polys)
 
-            img_id += 1
-
             single_img_res = {
-                "input_img": raw_img,
+                "input_img": image_array,
+                "doc_preprocessor_image": doc_preprocessor_image,
+                "doc_preprocessor_res": doc_preprocessor_res,
                 "dt_polys": dt_polys,
                 "img_id": img_id,
+                "input_params": input_params,
                 "text_type": self.text_type,
             }
 
             single_img_res["rec_text"] = []
             single_img_res["rec_score"] = []
             if len(dt_polys) > 0:
-                all_subs_of_img = list(self._crop_by_polys(raw_img, dt_polys))
+                all_subs_of_img = list(
+                    self._crop_by_polys(doc_preprocessor_image, dt_polys)
+                )
+                # use textline orientation model
+                if input_params["use_textline_orientation"]:
+                    angles = [
+                        textline_angle_info["class_ids"][0]
+                        for textline_angle_info in self.textline_orientation_model(
+                            all_subs_of_img
+                        )
+                    ]
+                    all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
 
                 for rec_res in self.text_rec_model(all_subs_of_img):
                     single_img_res["rec_text"].append(rec_res["rec_text"])

+ 8 - 2
paddlex/inference/pipelines_new/ocr/result.py

@@ -38,9 +38,15 @@ class OCRResult(BaseCVResult):
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
         """
+        input_params = self["input_params"]
+        img_id = self["img_id"]
+        if input_params["use_doc_preprocessor"]:
+            save_img_path = Path(save_path) / f"doc_preprocessor_result_img_{img_id}.jpg"
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+
         if not str(save_path).lower().endswith((".jpg", ".png")):
-            img_id = self["img_id"]
             save_path = Path(save_path) / f"res_ocr_{img_id}.jpg"
+
         super().save_to_img(save_path, *args, **kwargs)
 
     def get_minarea_rect(self, points: np.ndarray) -> np.ndarray:
@@ -91,7 +97,7 @@ class OCRResult(BaseCVResult):
         boxes = self["dt_polys"]
         txts = self["rec_text"]
         scores = self["rec_score"]
-        image = self["input_img"]
+        image = self["doc_preprocessor_image"]
         h, w = image.shape[0:2]
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         img_left = Image.fromarray(image_rgb)