|
|
@@ -34,16 +34,6 @@ class OCRPipeline(BasePipeline):
|
|
|
self,
|
|
|
config: Dict,
|
|
|
device: Optional[str] = None,
|
|
|
- use_doc_orientation_classify: Optional[bool] = None,
|
|
|
- use_doc_unwarping: Optional[bool] = None,
|
|
|
- use_textline_orientation: Optional[bool] = None,
|
|
|
- limit_side_len: Optional[int] = None,
|
|
|
- limit_type: Optional[str] = None,
|
|
|
- thresh: Optional[float] = None,
|
|
|
- box_thresh: Optional[float] = None,
|
|
|
- max_candidates: Optional[int] = None,
|
|
|
- unclip_ratio: Optional[float] = None,
|
|
|
- use_dilation: Optional[bool] = None,
|
|
|
pp_option: Optional[PaddlePredictorOption] = None,
|
|
|
use_hpip: bool = False,
|
|
|
hpi_params: Optional[Dict[str, Any]] = None,
|
|
|
@@ -52,69 +42,63 @@ class OCRPipeline(BasePipeline):
|
|
|
Initializes the class with given configurations and options.
|
|
|
|
|
|
Args:
|
|
|
- config (Dict): Configuration dictionary containing model and other parameters.
|
|
|
- device (Union[str, None]): The device to run the prediction on.
|
|
|
- use_textline_orientation (Union[bool, None]): Whether to use textline orientation.
|
|
|
- use_doc_orientation_classify (Union[bool, None]): Whether to use document orientation classification.
|
|
|
- use_doc_unwarping (Union[bool, None]): Whether to use document unwarping.
|
|
|
- limit_side_len (Union[int, None]): Limit of side length.
|
|
|
- limit_type (Union[str, None]): Type of limit.
|
|
|
- thresh (Union[float, None]): Threshold value.
|
|
|
- box_thresh (Union[float, None]): Box threshold value.
|
|
|
- max_candidates (Union[int, None]): Maximum number of candidates.
|
|
|
- unclip_ratio (Union[float, None]): Unclip ratio.
|
|
|
- use_dilation (Union[bool, None]): Whether to use dilation.
|
|
|
- pp_option (Union[PaddlePredictorOption, None]): Options for PaddlePaddle predictor.
|
|
|
- use_hpip (Union[bool, None]): Whether to use high-performance inference.
|
|
|
- hpi_params (Union[Dict[str, Any], None]): HPIP specific parameters.
|
|
|
+ config (Dict): Configuration dictionary containing various settings.
|
|
|
+ device (str, optional): Device to run the predictions on. Defaults to None.
|
|
|
+ pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
|
|
|
+ use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
|
|
|
+ hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
|
|
|
"""
|
|
|
super().__init__(
|
|
|
device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
|
|
|
)
|
|
|
|
|
|
- self.use_textline_orientation = (
|
|
|
- use_textline_orientation
|
|
|
- if use_textline_orientation is not None
|
|
|
- else config.get("use_textline_orientation", False)
|
|
|
- )
|
|
|
- self.use_doc_preprocessor = self.get_preprocessor_value(
|
|
|
- use_doc_orientation_classify, use_doc_unwarping, config, False
|
|
|
- )
|
|
|
-
|
|
|
- text_det_default_params = {
|
|
|
- "limit_side_len": 960,
|
|
|
- "limit_type": "max",
|
|
|
- "thresh": 0.3,
|
|
|
- "box_thresh": 0.6,
|
|
|
- "max_candidates": 1000,
|
|
|
- "unclip_ratio": 2.0,
|
|
|
- "use_dilation": False,
|
|
|
- }
|
|
|
-
|
|
|
- text_det_config = config["SubModules"]["TextDetection"]
|
|
|
- for key, default_params in text_det_default_params.items():
|
|
|
- text_det_config[key] = locals().get(
|
|
|
- key, text_det_config.get(key, default_params)
|
|
|
+ self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
|
|
|
+ if self.use_doc_preprocessor:
|
|
|
+ doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
|
|
|
+ self.doc_preprocessor_pipeline = self.create_pipeline(
|
|
|
+ doc_preprocessor_config
|
|
|
)
|
|
|
- self.text_det_model = self.create_model(text_det_config)
|
|
|
-
|
|
|
- text_rec_config = config["SubModules"]["TextRecognition"]
|
|
|
- self.text_rec_model = self.create_model(text_rec_config)
|
|
|
|
|
|
+ self.use_textline_orientation = config.get("use_textline_orientation", True)
|
|
|
if self.use_textline_orientation:
|
|
|
textline_orientation_config = config["SubModules"]["TextLineOrientation"]
|
|
|
+ # TODO: add batch_size
|
|
|
+ # batch_size = textline_orientation_config.get("batch_size", 1)
|
|
|
+ # self.textline_orientation_model = self.create_model(
|
|
|
+ # textline_orientation_config, batch_size=batch_size
|
|
|
+ # )
|
|
|
self.textline_orientation_model = self.create_model(
|
|
|
textline_orientation_config
|
|
|
)
|
|
|
|
|
|
- if self.use_doc_preprocessor:
|
|
|
- doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
|
|
|
- self.doc_preprocessor_pipeline = self.create_pipeline(
|
|
|
- doc_preprocessor_config
|
|
|
- )
|
|
|
+ text_det_config = config["SubModules"]["TextDetection"]
|
|
|
+ self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
|
|
|
+ self.text_det_limit_type = text_det_config.get("limit_type", "max")
|
|
|
+ self.text_det_thresh = text_det_config.get("thresh", 0.3)
|
|
|
+ self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
|
|
|
+ self.text_det_max_candidates = text_det_config.get("max_candidates", 1000)
|
|
|
+ self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
|
|
|
+ self.text_det_use_dilation = text_det_config.get("use_dilation", False)
|
|
|
+ self.text_det_model = self.create_model(
|
|
|
+ text_det_config,
|
|
|
+ limit_side_len=self.text_det_limit_side_len,
|
|
|
+ limit_type=self.text_det_limit_type,
|
|
|
+ thresh=self.text_det_thresh,
|
|
|
+ box_thresh=self.text_det_box_thresh,
|
|
|
+ max_candidates=self.text_det_max_candidates,
|
|
|
+ unclip_ratio=self.text_det_unclip_ratio,
|
|
|
+ use_dilation=self.text_det_use_dilation,
|
|
|
+ )
|
|
|
|
|
|
- self.text_type = config["text_type"]
|
|
|
+ text_rec_config = config["SubModules"]["TextRecognition"]
|
|
|
+ # TODO: add batch_size
|
|
|
+ # batch_size = text_rec_config.get("batch_size", 1)
|
|
|
+ # self.text_rec_model = self.create_model(text_rec_config,
|
|
|
+ # batch_size=batch_size)
|
|
|
+ self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
|
|
|
+ self.text_rec_model = self.create_model(text_rec_config)
|
|
|
|
|
|
+ self.text_type = config["text_type"]
|
|
|
if self.text_type == "general":
|
|
|
self._sort_boxes = SortQuadBoxes()
|
|
|
self._crop_by_polys = CropByPolys(det_box_type="quad")
|
|
|
@@ -127,16 +111,6 @@ class OCRPipeline(BasePipeline):
|
|
|
self.batch_sampler = ImageBatchSampler(batch_size=1)
|
|
|
self.img_reader = ReadImage(format="BGR")
|
|
|
|
|
|
- @staticmethod
|
|
|
- def get_preprocessor_value(orientation, unwarping, config, default):
|
|
|
- if orientation is None and unwarping is None:
|
|
|
- return config.get("use_doc_preprocessor", default)
|
|
|
- else:
|
|
|
- if orientation is False and unwarping is False:
|
|
|
- return False
|
|
|
- else:
|
|
|
- return True
|
|
|
-
|
|
|
def rotate_image(
|
|
|
self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
|
|
|
) -> List[np.ndarray]:
|
|
|
@@ -202,22 +176,24 @@ class OCRPipeline(BasePipeline):
|
|
|
return True
|
|
|
|
|
|
def predict_doc_preprocessor_res(
|
|
|
- self, image_array: np.ndarray, input_params: dict
|
|
|
+ self, image_array: np.ndarray, model_settings: dict
|
|
|
) -> tuple[DocPreprocessorResult, np.ndarray]:
|
|
|
"""
|
|
|
Preprocess the document image based on input parameters.
|
|
|
|
|
|
Args:
|
|
|
image_array (np.ndarray): The input image array.
|
|
|
- input_params (dict): Dictionary containing preprocessing parameters.
|
|
|
+ model_settings (dict): Dictionary containing preprocessing parameters.
|
|
|
|
|
|
Returns:
|
|
|
tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
|
|
|
result dictionary and the processed image array.
|
|
|
"""
|
|
|
- if input_params["use_doc_preprocessor"]:
|
|
|
- use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
|
|
|
- use_doc_unwarping = input_params["use_doc_unwarping"]
|
|
|
+ if model_settings["use_doc_preprocessor"]:
|
|
|
+ use_doc_orientation_classify = model_settings[
|
|
|
+ "use_doc_orientation_classify"
|
|
|
+ ]
|
|
|
+ use_doc_unwarping = model_settings["use_doc_unwarping"]
|
|
|
doc_preprocessor_res = next(
|
|
|
self.doc_preprocessor_pipeline(
|
|
|
image_array,
|
|
|
@@ -229,61 +205,161 @@ class OCRPipeline(BasePipeline):
|
|
|
doc_preprocessor_res = {"output_img": image_array}
|
|
|
return doc_preprocessor_res
|
|
|
|
|
|
+ def get_model_settings(
|
|
|
+ self,
|
|
|
+ use_doc_orientation_classify: Optional[bool],
|
|
|
+ use_doc_unwarping: Optional[bool],
|
|
|
+ use_textline_orientation: Optional[bool],
|
|
|
+ ) -> dict:
|
|
|
+ """
|
|
|
+ Get the model settings based on the provided parameters or default values.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
|
|
|
+ use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
|
|
|
+ use_textline_orientation (Optional[bool]): Whether to use textline orientation.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: A dictionary containing the model settings.
|
|
|
+ """
|
|
|
+ if use_doc_orientation_classify is None:
|
|
|
+ use_doc_orientation_classify = self.use_doc_orientation_classify
|
|
|
+ if use_doc_unwarping is None:
|
|
|
+ use_doc_unwarping = self.use_doc_unwarping
|
|
|
+ if use_textline_orientation is None:
|
|
|
+ use_textline_orientation = self.use_textline_orientation
|
|
|
+ return dict(
|
|
|
+ use_doc_orientation_classify=use_doc_orientation_classify,
|
|
|
+ use_doc_unwarping=use_doc_unwarping,
|
|
|
+ use_textline_orientation=use_textline_orientation,
|
|
|
+ )
|
|
|
+
|
|
|
+ def get_text_det_params(
|
|
|
+ self,
|
|
|
+ text_det_limit_side_len: Optional[int] = None,
|
|
|
+ text_det_limit_type: Optional[str] = None,
|
|
|
+ text_det_thresh: Optional[float] = None,
|
|
|
+ text_det_box_thresh: Optional[float] = None,
|
|
|
+ text_det_max_candidates: Optional[int] = None,
|
|
|
+ text_det_unclip_ratio: Optional[float] = None,
|
|
|
+ text_det_use_dilation: Optional[bool] = None,
|
|
|
+ ) -> dict:
|
|
|
+ """
|
|
|
+ Get text detection parameters.
|
|
|
+
|
|
|
+ If a parameter is None, its default value from the instance will be used.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ text_det_limit_side_len (Optional[int]): The maximum side length of the text box.
|
|
|
+ text_det_limit_type (Optional[str]): The type of limit to apply to the text box.
|
|
|
+ text_det_thresh (Optional[float]): The threshold for text detection.
|
|
|
+ text_det_box_thresh (Optional[float]): The threshold for the bounding box.
|
|
|
+ text_det_max_candidates (Optional[int]): The maximum number of candidate text boxes.
|
|
|
+ text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.
|
|
|
+ text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: A dictionary containing the text detection parameters.
|
|
|
+ """
|
|
|
+ if text_det_limit_side_len is None:
|
|
|
+ text_det_limit_side_len = self.text_det_limit_side_len
|
|
|
+ if text_det_limit_type is None:
|
|
|
+ text_det_limit_type = self.text_det_limit_type
|
|
|
+ if text_det_thresh is None:
|
|
|
+ text_det_thresh = self.text_det_thresh
|
|
|
+ if text_det_box_thresh is None:
|
|
|
+ text_det_box_thresh = self.text_det_box_thresh
|
|
|
+ if text_det_max_candidates is None:
|
|
|
+ text_det_max_candidates = self.text_det_max_candidates
|
|
|
+ if text_det_unclip_ratio is None:
|
|
|
+ text_det_unclip_ratio = self.text_det_unclip_ratio
|
|
|
+ if text_det_use_dilation is None:
|
|
|
+ text_det_use_dilation = self.text_det_use_dilation
|
|
|
+ return dict(
|
|
|
+ limit_side_len=text_det_limit_side_len,
|
|
|
+ limit_type=text_det_limit_type,
|
|
|
+ thresh=text_det_thresh,
|
|
|
+ box_thresh=text_det_box_thresh,
|
|
|
+ max_candidates=text_det_max_candidates,
|
|
|
+ unclip_ratio=text_det_unclip_ratio,
|
|
|
+ use_dilation=text_det_use_dilation,
|
|
|
+ )
|
|
|
+
|
|
|
def predict(
|
|
|
self,
|
|
|
input: str | list[str] | np.ndarray | list[np.ndarray],
|
|
|
- use_doc_orientation_classify: bool = False,
|
|
|
- use_doc_unwarping: bool = False,
|
|
|
- use_textline_orientation: bool = False,
|
|
|
- limit_side_len: int = 960,
|
|
|
- limit_type: str = "max",
|
|
|
- thresh: float = 0.3,
|
|
|
- box_thresh: float = 0.6,
|
|
|
- max_candidates: int = 1000,
|
|
|
- unclip_ratio: float = 2.0,
|
|
|
- use_dilation: bool = False,
|
|
|
- **kwargs,
|
|
|
+ use_doc_orientation_classify: Optional[bool] = None,
|
|
|
+ use_doc_unwarping: Optional[bool] = None,
|
|
|
+ use_textline_orientation: Optional[bool] = None,
|
|
|
+ text_det_limit_side_len: Optional[int] = None,
|
|
|
+ text_det_limit_type: Optional[str] = None,
|
|
|
+ text_det_thresh: Optional[float] = None,
|
|
|
+ text_det_box_thresh: Optional[float] = None,
|
|
|
+ text_det_max_candidates: Optional[int] = None,
|
|
|
+ text_det_unclip_ratio: Optional[float] = None,
|
|
|
+ text_det_use_dilation: Optional[bool] = None,
|
|
|
+ text_rec_score_thresh: Optional[float] = None,
|
|
|
) -> OCRResult:
|
|
|
- """Predicts OCR results for the given input.
|
|
|
+ """
|
|
|
+ Predict OCR results based on input images or arrays with optional preprocessing steps.
|
|
|
|
|
|
Args:
|
|
|
- input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images or pdf(s).
|
|
|
- **kwargs: Additional keyword arguments that can be passed to the function.
|
|
|
-
|
|
|
+ input (str | list[str] | np.ndarray | list[np.ndarray]): Input image of pdf path(s) or numpy array(s).
|
|
|
+ use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
|
|
|
+ use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
|
|
|
+ use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
|
|
|
+ text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
|
|
|
+ text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
|
|
|
+ text_det_thresh (Optional[float]): Threshold for text detection.
|
|
|
+ text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
|
|
|
+ text_det_max_candidates (Optional[int]): Maximum number of text detection candidates.
|
|
|
+ text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
|
|
|
+ text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
|
|
|
+ text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
|
|
|
Returns:
|
|
|
- OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
|
|
|
+ OCRResult: Generator yielding OCR results for each input image.
|
|
|
"""
|
|
|
|
|
|
- model_settings = {
|
|
|
- "use_doc_orientation_classify": use_doc_orientation_classify,
|
|
|
- "use_doc_unwarping": use_doc_unwarping,
|
|
|
- "use_textline_orientation": use_textline_orientation,
|
|
|
- }
|
|
|
- if use_doc_orientation_classify or use_doc_unwarping:
|
|
|
+ model_settings = self.get_model_settings(
|
|
|
+ use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
|
|
|
+ )
|
|
|
+ if (
|
|
|
+ model_settings["use_doc_orientation_classify"]
|
|
|
+ or model_settings["use_doc_unwarping"]
|
|
|
+ ):
|
|
|
model_settings["use_doc_preprocessor"] = True
|
|
|
else:
|
|
|
model_settings["use_doc_preprocessor"] = False
|
|
|
|
|
|
if not self.check_model_settings_valid(model_settings):
|
|
|
- yield None
|
|
|
-
|
|
|
- text_det_params = {
|
|
|
- "limit_side_len": limit_side_len,
|
|
|
- "limit_type": limit_type,
|
|
|
- "thresh": thresh,
|
|
|
- "box_thresh": box_thresh,
|
|
|
- "max_candidates": max_candidates,
|
|
|
- "unclip_ratio": unclip_ratio,
|
|
|
- "use_dilation": use_dilation,
|
|
|
- }
|
|
|
+ yield {"error": "the input params for model settings are invalid!"}
|
|
|
+
|
|
|
+ text_det_params = self.get_text_det_params(
|
|
|
+ text_det_limit_side_len,
|
|
|
+ text_det_limit_type,
|
|
|
+ text_det_thresh,
|
|
|
+ text_det_box_thresh,
|
|
|
+ text_det_max_candidates,
|
|
|
+ text_det_unclip_ratio,
|
|
|
+ text_det_use_dilation,
|
|
|
+ )
|
|
|
+
|
|
|
+ if text_rec_score_thresh is None:
|
|
|
+ text_rec_score_thresh = self.text_rec_score_thresh
|
|
|
|
|
|
for img_id, batch_data in enumerate(self.batch_sampler(input)):
|
|
|
+ if not isinstance(batch_data[0], str):
|
|
|
+ # TODO: add support input_pth for ndarray and pdf
|
|
|
+ input_path = f"{img_id}"
|
|
|
+ else:
|
|
|
+ input_path = batch_data[0]
|
|
|
+
|
|
|
image_array = self.img_reader(batch_data)[0]
|
|
|
- img_id += 1
|
|
|
|
|
|
doc_preprocessor_res = self.predict_doc_preprocessor_res(
|
|
|
image_array, model_settings
|
|
|
)
|
|
|
+
|
|
|
doc_preprocessor_image = doc_preprocessor_res["output_img"]
|
|
|
|
|
|
det_res = next(
|
|
|
@@ -296,19 +372,17 @@ class OCRPipeline(BasePipeline):
|
|
|
dt_polys = self._sort_boxes(dt_polys)
|
|
|
|
|
|
single_img_res = {
|
|
|
- "input_path": input,
|
|
|
- # TODO: `doc_preprocessor_image` parameter does not need to be retained here, it requires further confirmation.
|
|
|
- "doc_preprocessor_image": doc_preprocessor_image,
|
|
|
+ "input_path": batch_data[0],
|
|
|
"doc_preprocessor_res": doc_preprocessor_res,
|
|
|
"dt_polys": dt_polys,
|
|
|
- "img_id": img_id,
|
|
|
- "input_params": model_settings,
|
|
|
+ "model_settings": model_settings,
|
|
|
"text_det_params": text_det_params,
|
|
|
"text_type": self.text_type,
|
|
|
}
|
|
|
|
|
|
- single_img_res["rec_text"] = []
|
|
|
- single_img_res["rec_score"] = []
|
|
|
+ single_img_res["rec_texts"] = []
|
|
|
+ single_img_res["rec_scores"] = []
|
|
|
+ single_img_res["rec_boxes"] = []
|
|
|
if len(dt_polys) > 0:
|
|
|
all_subs_of_img = list(
|
|
|
self._crop_by_polys(doc_preprocessor_image, dt_polys)
|
|
|
@@ -324,8 +398,11 @@ class OCRPipeline(BasePipeline):
|
|
|
single_img_res["textline_orientation_angle"] = angles
|
|
|
all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
|
|
|
|
|
|
+ rno = -1
|
|
|
for rec_res in self.text_rec_model(all_subs_of_img):
|
|
|
- single_img_res["rec_text"].append(rec_res["rec_text"])
|
|
|
- single_img_res["rec_score"].append(rec_res["rec_score"])
|
|
|
-
|
|
|
+ rno += 1
|
|
|
+ if rec_res["rec_score"] >= text_rec_score_thresh:
|
|
|
+ single_img_res["rec_texts"].append(rec_res["rec_text"])
|
|
|
+ single_img_res["rec_scores"].append(rec_res["rec_score"])
|
|
|
+ single_img_res["rec_boxes"].append(dt_polys[rno])
|
|
|
yield OCRResult(single_img_res)
|