瀏覽代碼

support textline orientation & doc preprocessor for OCR pipeline (#2775)

* support textline orientation & doc preprocessor for OCR pipeline

* support to save doc preprocess output for OCR pipeline

* update
cuicheng01 10 月之前
父節點
當前提交
4de304004f

+ 49 - 2
api_examples/pipelines/test_ocr.py

@@ -16,9 +16,56 @@ from paddlex import create_pipeline
 
 
 pipeline = create_pipeline(pipeline="OCR")
 pipeline = create_pipeline(pipeline="OCR")
 
 
-output = pipeline.predict("./test_samples/general_ocr_002.png")
-
+output = pipeline.predict(
+    "./test_samples/general_ocr_002.png",
+    use_doc_orientation_classify=True,
+    use_doc_unwarping=True,
+    use_textline_orientation=True,
+)
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True,
+#     use_textline_orientation=False,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=True,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=False,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=True,
+#     use_textline_orientation=True,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=True,
+#     use_textline_orientation=False,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=True,
+# )
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=False,
+# )
 # output = pipeline.predict("./test_samples/财报1.pdf")
 # output = pipeline.predict("./test_samples/财报1.pdf")
 for res in output:
 for res in output:
     print(res)
     print(res)
     res.save_to_img("./output")
     res.save_to_img("./output")
+    res.save_to_json("./output/res.json")

+ 26 - 1
paddlex/configs/pipelines/OCR.yaml

@@ -3,14 +3,39 @@ pipeline_name: OCR
 
 
 text_type: general
 text_type: general
 
 
+use_doc_preprocessor: True
+use_textline_orientation: True
+
+SubPipelines:
+  DocPreprocessor:
+    pipeline_name: doc_preprocessor
+    use_doc_orientation_classify: True
+    use_doc_unwarping: True
+    SubModules:
+      DocOrientationClassify:
+        module_name: doc_text_orientation
+        model_name: PP-LCNet_x1_0_doc_ori
+        model_dir: null
+        batch_size: 1
+      DocUnwarping:
+        module_name: image_unwarping
+        model_name: UVDoc
+        model_dir: null
+        batch_size: 1
+
 SubModules:
 SubModules:
   TextDetection:
   TextDetection:
     module_name: text_detection
     module_name: text_detection
     model_name: PP-OCRv4_mobile_det
     model_name: PP-OCRv4_mobile_det
     model_dir: null
     model_dir: null
+    batch_size: 1
+  TextLineOrientation:
+    module_name: textline_orientation
+    model_name: PP-LCNet_x0_25_textline_ori 
+    model_dir: null
     batch_size: 1    
     batch_size: 1    
   TextRecognition:
   TextRecognition:
     module_name: text_recognition
     module_name: text_recognition
-    model_name: PP-OCRv4_mobile_rec
+    model_name: PP-OCRv4_mobile_rec 
     model_dir: null
     model_dir: null
     batch_size: 1
     batch_size: 1

+ 198 - 13
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 import numpy as np
 import numpy as np
+from scipy.ndimage import rotate
 from ...common.reader import ReadImage
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.batch_sampler import ImageBatchSampler
 from ...utils.pp_option import PaddlePredictorOption
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 from ..base import BasePipeline
 from ..components import CropByPolys, SortQuadBoxes, SortPolyBoxes
 from ..components import CropByPolys, SortQuadBoxes, SortPolyBoxes
 from .result import OCRResult
 from .result import OCRResult
+from ..doc_preprocessor.result import DocPreprocessorResult
+from ....utils import logging
 
 
 
 
 class OCRPipeline(BasePipeline):
 class OCRPipeline(BasePipeline):
@@ -49,11 +52,7 @@ class OCRPipeline(BasePipeline):
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
         )
 
 
-        text_det_model_config = config["SubModules"]["TextDetection"]
-        self.text_det_model = self.create_model(text_det_model_config)
-
-        text_rec_model_config = config["SubModules"]["TextRecognition"]
-        self.text_rec_model = self.create_model(text_rec_model_config)
+        self.inintial_predictor(config)
 
 
         self.text_type = config["text_type"]
         self.text_type = config["text_type"]
 
 
@@ -69,8 +68,162 @@ class OCRPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
         self.img_reader = ReadImage(format="BGR")
 
 
+    def set_used_models_flag(self, config: Dict) -> None:
+        """
+        Set the flags for which models to use based on the configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+
+        Returns:
+            None
+        """
+        pipeline_name = config["pipeline_name"]
+
+        self.pipeline_name = pipeline_name
+
+        self.use_doc_preprocessor = False
+
+        if "use_doc_preprocessor" in config:
+            self.use_doc_preprocessor = config["use_doc_preprocessor"]
+
+        self.use_textline_orientation = False
+
+        if "use_textline_orientation" in config:
+            self.use_textline_orientation = config["use_textline_orientation"]
+
+    def inintial_predictor(self, config: Dict) -> None:
+        """Initializes the predictor based on the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing the configuration for the predictor.
+
+        Returns:
+            None
+        """
+
+        self.set_used_models_flag(config)
+
+        text_det_model_config = config["SubModules"]["TextDetection"]
+        self.text_det_model = self.create_model(text_det_model_config)
+
+        text_rec_model_config = config["SubModules"]["TextRecognition"]
+        self.text_rec_model = self.create_model(text_rec_model_config)
+
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+        # Just for initialize the predictor
+        if self.use_textline_orientation:
+            textline_orientation_config = config["SubModules"]["TextLineOrientation"]
+            self.textline_orientation_model = self.create_model(
+                textline_orientation_config
+            )
+        return
+
+    def rotate_image(
+        self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
+    ) -> List[np.ndarray]:
+        """
+        Rotate the given image arrays by their corresponding angles.
+        0 corresponds to 0 degrees, 1 corresponds to 180 degrees.
+
+        Args:
+            image_array_list (List[np.ndarray]): A list of input image arrays to be rotated.
+            rotate_angle_list (List[int]): A list of rotation indicators (0 or 1).
+                                        0 means rotate by 0 degrees
+                                        1 means rotate by 180 degrees
+
+        Returns:
+            List[np.ndarray]: A list of rotated image arrays.
+
+        Raises:
+            AssertionError: If any rotate_angle is not 0 or 1.
+            AssertionError: If the lengths of input lists don't match.
+        """
+        assert len(image_array_list) == len(
+            rotate_angle_list
+        ), f"Length of image_array_list ({len(image_array_list)}) must match length of rotate_angle_list ({len(rotate_angle_list)})"
+
+        for angle in rotate_angle_list:
+            assert angle in [0, 1], f"rotate_angle must be 0 or 1, now it's {angle}"
+
+        rotated_images = []
+        for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
+            # Convert 0/1 indicator to actual rotation angle
+            rotate_angle = rotate_indicator * 180
+            rotated_image = rotate(image_array, rotate_angle, reshape=True)
+            rotated_images.append(rotated_image)
+
+        return rotated_images
+
+    def check_input_params_valid(self, input_params: Dict) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            input_params (Dict): A dictionary containing input parameters.
+
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
+
+        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+            logging.error(
+                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
+            )
+            return False
+
+        if (
+            input_params["use_textline_orientation"]
+            and not self.use_textline_orientation
+        ):
+            logging.error(
+                "Set use_textline_orientation, but the models for use_textline_orientation are not initialized."
+            )
+            return False
+
+        return True
+
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+        """
+        Preprocess the document image based on input parameters.
+
+        Args:
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
+
+        Returns:
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+                                              result dictionary and the processed image array.
+        """
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
     def predict(
     def predict(
-        self, input: str | list[str] | np.ndarray | list[np.ndarray], **kwargs
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        use_textline_orientation: bool = False,
+        **kwargs,
     ) -> OCRResult:
     ) -> OCRResult:
         """Predicts OCR results for the given input.
         """Predicts OCR results for the given input.
 
 
@@ -82,9 +235,29 @@ class OCRPipeline(BasePipeline):
             OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
             OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
         """
         """
 
 
+        input_params = {
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+            "use_textline_orientation": self.use_textline_orientation,
+        }
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params["use_doc_preprocessor"] = True
+        else:
+            input_params["use_doc_preprocessor"] = False
+
+        if not self.check_input_params_valid(input_params):
+            yield None
+
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
-            raw_img = self.img_reader(batch_data)[0]
-            det_res = next(self.text_det_model(raw_img))
+            image_array = self.img_reader(batch_data)[0]
+            img_id += 1
+
+            doc_preprocessor_res, doc_preprocessor_image = (
+                self.predict_doc_preprocessor_res(image_array, input_params)
+            )
+
+            det_res = next(self.text_det_model(doc_preprocessor_image))
 
 
             dt_polys = det_res["dt_polys"]
             dt_polys = det_res["dt_polys"]
             dt_scores = det_res["dt_scores"]
             dt_scores = det_res["dt_scores"]
@@ -93,19 +266,31 @@ class OCRPipeline(BasePipeline):
 
 
             dt_polys = self._sort_boxes(dt_polys)
             dt_polys = self._sort_boxes(dt_polys)
 
 
-            img_id += 1
-
             single_img_res = {
             single_img_res = {
-                "input_img": raw_img,
+                "input_img": image_array,
+                "doc_preprocessor_image": doc_preprocessor_image,
+                "doc_preprocessor_res": doc_preprocessor_res,
                 "dt_polys": dt_polys,
                 "dt_polys": dt_polys,
                 "img_id": img_id,
                 "img_id": img_id,
+                "input_params": input_params,
                 "text_type": self.text_type,
                 "text_type": self.text_type,
             }
             }
 
 
             single_img_res["rec_text"] = []
             single_img_res["rec_text"] = []
             single_img_res["rec_score"] = []
             single_img_res["rec_score"] = []
             if len(dt_polys) > 0:
             if len(dt_polys) > 0:
-                all_subs_of_img = list(self._crop_by_polys(raw_img, dt_polys))
+                all_subs_of_img = list(
+                    self._crop_by_polys(doc_preprocessor_image, dt_polys)
+                )
+                # use textline orientation model
+                if input_params["use_textline_orientation"]:
+                    angles = [
+                        textline_angle_info["class_ids"][0]
+                        for textline_angle_info in self.textline_orientation_model(
+                            all_subs_of_img
+                        )
+                    ]
+                    all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
 
 
                 for rec_res in self.text_rec_model(all_subs_of_img):
                 for rec_res in self.text_rec_model(all_subs_of_img):
                     single_img_res["rec_text"].append(rec_res["rec_text"])
                     single_img_res["rec_text"].append(rec_res["rec_text"])

+ 8 - 2
paddlex/inference/pipelines_new/ocr/result.py

@@ -38,9 +38,15 @@ class OCRResult(BaseCVResult):
             *args: Additional positional arguments.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
             **kwargs: Additional keyword arguments.
         """
         """
+        input_params = self["input_params"]
+        img_id = self["img_id"]
+        if input_params["use_doc_preprocessor"]:
+            save_img_path = Path(save_path) / f"doc_preprocessor_result_img_{img_id}.jpg"
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+
         if not str(save_path).lower().endswith((".jpg", ".png")):
         if not str(save_path).lower().endswith((".jpg", ".png")):
-            img_id = self["img_id"]
             save_path = Path(save_path) / f"res_ocr_{img_id}.jpg"
             save_path = Path(save_path) / f"res_ocr_{img_id}.jpg"
+
         super().save_to_img(save_path, *args, **kwargs)
         super().save_to_img(save_path, *args, **kwargs)
 
 
     def get_minarea_rect(self, points: np.ndarray) -> np.ndarray:
     def get_minarea_rect(self, points: np.ndarray) -> np.ndarray:
@@ -91,7 +97,7 @@ class OCRResult(BaseCVResult):
         boxes = self["dt_polys"]
         boxes = self["dt_polys"]
         txts = self["rec_text"]
         txts = self["rec_text"]
         scores = self["rec_score"]
         scores = self["rec_score"]
-        image = self["input_img"]
+        image = self["doc_preprocessor_image"]
         h, w = image.shape[0:2]
         h, w = image.shape[0:2]
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         img_left = Image.fromarray(image_rgb)
         img_left = Image.fromarray(image_rgb)