Просмотр исходного кода

Support key parameters for OCR pipeline (#2810)

* Support key parameters for OCR pipeline

* update OCR.yaml
cuicheng01 10 месяцев назад
Родитель
Сommit
9b6dded599

+ 5 - 3
api_examples/pipelines/test_ocr.py

@@ -14,13 +14,15 @@
 
 from paddlex import create_pipeline
 
-pipeline = create_pipeline(pipeline="OCR")
+pipeline = create_pipeline(pipeline="OCR", limit_side_len=320)
 
 output = pipeline.predict(
     "./test_samples/general_ocr_002.png",
     use_doc_orientation_classify=True,
-    use_doc_unwarping=True,
-    use_textline_orientation=True,
+    use_doc_unwarping=False,
+    use_textline_orientation=False,
+    unclip_ratio=3.0,
+    limit_side_len=1920,
 )
 # output = pipeline.predict(
 #     "./test_samples/general_ocr_002.png",

+ 9 - 2
paddlex/configs/pipelines/OCR.yaml

@@ -3,8 +3,8 @@ pipeline_name: OCR
 
 text_type: general
 
-use_doc_preprocessor: True
-use_textline_orientation: True
+use_doc_preprocessor: False
+use_textline_orientation: False
 
 SubPipelines:
   DocPreprocessor:
@@ -29,6 +29,13 @@ SubModules:
     model_name: PP-OCRv4_mobile_det
     model_dir: null
     batch_size: 1
+    limit_side_len: 960
+    limit_type: max
+    thresh: 0.3
+    box_thresh: 0.6
+    max_candidates: 1000
+    unclip_ratio: 2.0
+    use_dilation: False
   TextLineOrientation:
     module_name: textline_orientation
     model_name: PP-LCNet_x0_25_textline_ori 

+ 114 - 82
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -33,8 +33,18 @@ class OCRPipeline(BasePipeline):
     def __init__(
         self,
         config: Dict,
-        device: str = None,
-        pp_option: PaddlePredictorOption = None,
+        device: Optional[str] = None,
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_textline_orientation: Optional[bool] = None,
+        limit_side_len: Optional[int] = None,
+        limit_type: Optional[str] = None,
+        thresh: Optional[float] = None,
+        box_thresh: Optional[float] = None,
+        max_candidates: Optional[int] = None,
+        unclip_ratio: Optional[float] = None,
+        use_dilation: Optional[bool] = None,
+        pp_option: Optional[PaddlePredictorOption] = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -43,16 +53,65 @@ class OCRPipeline(BasePipeline):
 
         Args:
             config (Dict): Configuration dictionary containing model and other parameters.
-            device (str): The device to run the prediction on. Default is None.
-            pp_option (PaddlePredictorOption): Options for PaddlePaddle predictor. Default is None.
-            use_hpip (bool): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
-            hpi_params (Optional[Dict[str, Any]]): HPIP specific parameters. Default is None.
+            device (Union[str, None]): The device to run the prediction on.
+            use_textline_orientation (Union[bool, None]): Whether to use textline orientation.
+            use_doc_orientation_classify (Union[bool, None]): Whether to use document orientation classification.
+            use_doc_unwarping (Union[bool, None]): Whether to use document unwarping.
+            limit_side_len (Union[int, None]): Limit of side length.
+            limit_type (Union[str, None]): Type of limit.
+            thresh (Union[float, None]): Threshold value.
+            box_thresh (Union[float, None]): Box threshold value.
+            max_candidates (Union[int, None]): Maximum number of candidates.
+            unclip_ratio (Union[float, None]): Unclip ratio.
+            use_dilation (Union[bool, None]): Whether to use dilation.
+            pp_option (Union[PaddlePredictorOption, None]): Options for PaddlePaddle predictor.
+            use_hpip (Union[bool, None]): Whether to use high-performance inference.
+            hpi_params (Union[Dict[str, Any], None]): HPIP specific parameters.
         """
         super().__init__(
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
-        self.inintial_predictor(config)
+        self.use_textline_orientation = (
+            use_textline_orientation
+            if use_textline_orientation is not None
+            else config.get("use_textline_orientation", False)
+        )
+        self.use_doc_preprocessor = self.get_preprocessor_value(
+            use_doc_orientation_classify, use_doc_unwarping, config, False
+        )
+
+        text_det_default_params = {
+            "limit_side_len": 960,
+            "limit_type": "max",
+            "thresh": 0.3,
+            "box_thresh": 0.6,
+            "max_candidates": 1000,
+            "unclip_ratio": 2.0,
+            "use_dilation": False,
+        }
+
+        text_det_config = config["SubModules"]["TextDetection"]
+        for key, default_params in text_det_default_params.items():
+            text_det_config[key] = locals().get(
+                key, text_det_config.get(key, default_params)
+            )
+        self.text_det_model = self.create_model(text_det_config)
+
+        text_rec_config = config["SubModules"]["TextRecognition"]
+        self.text_rec_model = self.create_model(text_rec_config)
+
+        if self.use_textline_orientation:
+            textline_orientation_config = config["SubModules"]["TextLineOrientation"]
+            self.textline_orientation_model = self.create_model(
+                textline_orientation_config
+            )
+
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
 
         self.text_type = config["text_type"]
 
@@ -68,60 +127,15 @@ class OCRPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
-    def set_used_models_flag(self, config: Dict) -> None:
-        """
-        Set the flags for which models to use based on the configuration.
-
-        Args:
-            config (Dict): A dictionary containing configuration settings.
-
-        Returns:
-            None
-        """
-        pipeline_name = config["pipeline_name"]
-
-        self.pipeline_name = pipeline_name
-
-        self.use_doc_preprocessor = False
-
-        if "use_doc_preprocessor" in config:
-            self.use_doc_preprocessor = config["use_doc_preprocessor"]
-
-        self.use_textline_orientation = False
-
-        if "use_textline_orientation" in config:
-            self.use_textline_orientation = config["use_textline_orientation"]
-
-    def inintial_predictor(self, config: Dict) -> None:
-        """Initializes the predictor based on the provided configuration.
-
-        Args:
-            config (Dict): A dictionary containing the configuration for the predictor.
-
-        Returns:
-            None
-        """
-
-        self.set_used_models_flag(config)
-
-        text_det_model_config = config["SubModules"]["TextDetection"]
-        self.text_det_model = self.create_model(text_det_model_config)
-
-        text_rec_model_config = config["SubModules"]["TextRecognition"]
-        self.text_rec_model = self.create_model(text_rec_model_config)
-
-        if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
-            self.doc_preprocessor_pipeline = self.create_pipeline(
-                doc_preprocessor_config
-            )
-        # Just for initialize the predictor
-        if self.use_textline_orientation:
-            textline_orientation_config = config["SubModules"]["TextLineOrientation"]
-            self.textline_orientation_model = self.create_model(
-                textline_orientation_config
-            )
-        return
+    @staticmethod
+    def get_preprocessor_value(orientation, unwarping, config, default):
+        if orientation is None and unwarping is None:
+            return config.get("use_doc_preprocessor", default)
+        else:
+            if orientation is False and unwarping is False:
+                return False
+            else:
+                return True
 
     def rotate_image(
         self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
@@ -159,25 +173,25 @@ class OCRPipeline(BasePipeline):
 
         return rotated_images
 
-    def check_input_params_valid(self, input_params: Dict) -> bool:
+    def check_model_settings_valid(self, model_settings: Dict) -> bool:
         """
         Check if the input parameters are valid based on the initialized models.
 
         Args:
-            input_params (Dict): A dictionary containing input parameters.
+            model_info_params(Dict): A dictionary containing input parameters.
 
         Returns:
             bool: True if all required models are initialized according to input parameters, False otherwise.
         """
 
-        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
             logging.error(
                 "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
             )
             return False
 
         if (
-            input_params["use_textline_orientation"]
+            model_settings["use_textline_orientation"]
             and not self.use_textline_orientation
         ):
             logging.error(
@@ -211,11 +225,9 @@ class OCRPipeline(BasePipeline):
                     use_doc_unwarping=use_doc_unwarping,
                 )
             )
-            doc_preprocessor_image = doc_preprocessor_res["output_img"]
         else:
-            doc_preprocessor_res = {}
-            doc_preprocessor_image = image_array
-        return doc_preprocessor_res, doc_preprocessor_image
+            doc_preprocessor_res = {"output_img": image_array}
+        return doc_preprocessor_res
 
     def predict(
         self,
@@ -223,6 +235,13 @@ class OCRPipeline(BasePipeline):
         use_doc_orientation_classify: bool = False,
         use_doc_unwarping: bool = False,
         use_textline_orientation: bool = False,
+        limit_side_len: int = 960,
+        limit_type: str = "max",
+        thresh: float = 0.3,
+        box_thresh: float = 0.6,
+        max_candidates: int = 1000,
+        unclip_ratio: float = 2.0,
+        use_dilation: bool = False,
         **kwargs,
     ) -> OCRResult:
         """Predicts OCR results for the given input.
@@ -235,44 +254,56 @@ class OCRPipeline(BasePipeline):
             OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
         """
 
-        input_params = {
-            "use_doc_preprocessor": self.use_doc_preprocessor,
+        model_settings = {
             "use_doc_orientation_classify": use_doc_orientation_classify,
             "use_doc_unwarping": use_doc_unwarping,
-            "use_textline_orientation": self.use_textline_orientation,
+            "use_textline_orientation": use_textline_orientation,
         }
         if use_doc_orientation_classify or use_doc_unwarping:
-            input_params["use_doc_preprocessor"] = True
+            model_settings["use_doc_preprocessor"] = True
         else:
-            input_params["use_doc_preprocessor"] = False
+            model_settings["use_doc_preprocessor"] = False
 
-        if not self.check_input_params_valid(input_params):
+        if not self.check_model_settings_valid(model_settings):
             yield None
 
+        text_det_params = {
+            "limit_side_len": limit_side_len,
+            "limit_type": limit_type,
+            "thresh": thresh,
+            "box_thresh": box_thresh,
+            "max_candidates": max_candidates,
+            "unclip_ratio": unclip_ratio,
+            "use_dilation": use_dilation,
+        }
+
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
             image_array = self.img_reader(batch_data)[0]
             img_id += 1
 
-            doc_preprocessor_res, doc_preprocessor_image = (
-                self.predict_doc_preprocessor_res(image_array, input_params)
+            doc_preprocessor_res = self.predict_doc_preprocessor_res(
+                image_array, model_settings
             )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
 
-            det_res = next(self.text_det_model(doc_preprocessor_image))
+            det_res = next(
+                self.text_det_model(doc_preprocessor_image, **text_det_params)
+            )
 
             dt_polys = det_res["dt_polys"]
             dt_scores = det_res["dt_scores"]
 
-            ########## [TODO] Need to confirm filtering thresholds for detection and recognition modules
-
             dt_polys = self._sort_boxes(dt_polys)
 
             single_img_res = {
                 "input_path": input,
+                # TODO: `doc_preprocessor_image` parameter does not need to be retained here, it requires further confirmation.
                 "doc_preprocessor_image": doc_preprocessor_image,
                 "doc_preprocessor_res": doc_preprocessor_res,
                 "dt_polys": dt_polys,
                 "img_id": img_id,
-                "input_params": input_params,
+                "input_params": model_settings,
+                "text_det_params": text_det_params,
                 "text_type": self.text_type,
             }
 
@@ -283,13 +314,14 @@ class OCRPipeline(BasePipeline):
                     self._crop_by_polys(doc_preprocessor_image, dt_polys)
                 )
                 # use textline orientation model
-                if input_params["use_textline_orientation"]:
+                if model_settings["use_textline_orientation"]:
                     angles = [
                         textline_angle_info["class_ids"][0]
                         for textline_angle_info in self.textline_orientation_model(
                             all_subs_of_img
                         )
                     ]
+                    single_img_res["textline_orientation_angle"] = angles
                     all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
 
                 for rec_res in self.text_rec_model(all_subs_of_img):