@@ -19,7 +19,12 @@ from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
-from ..components import CropByPolys, SortQuadBoxes, SortPolyBoxes
+from ..components import (
+    CropByPolys,
+    SortQuadBoxes,
+    SortPolyBoxes,
+    convert_points_to_boxes,
+)
 from .result import OCRResult
 from ..doc_preprocessor.result import DocPreprocessorResult
 from ....utils import logging
@@ -54,14 +59,22 @@ class OCRPipeline(BasePipeline):
 
         self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
         if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            doc_preprocessor_config = config.get("SubPipelines", {}).get(
+                "DocPreprocessor",
+                {
+                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
+                },
+            )
             self.doc_preprocessor_pipeline = self.create_pipeline(
                 doc_preprocessor_config
             )
 
         self.use_textline_orientation = config.get("use_textline_orientation", True)
         if self.use_textline_orientation:
-            textline_orientation_config = config["SubModules"]["TextLineOrientation"]
+            textline_orientation_config = config.get("SubModules", {}).get(
+                "TextLineOrientation",
+                {"model_config_error": "config error for textline_orientation_model!"},
+            )
             # TODO: add batch_size
             # batch_size = textline_orientation_config.get("batch_size", 1)
             # self.textline_orientation_model = self.create_model(
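A minimal sketch of the fallback pattern this hunk introduces (the config shape is copied from the diff; running it standalone is illustrative only). A missing sub-config now yields a marker dict instead of raising KeyError in the constructor, so the problem can be reported when the sub-pipeline or model is actually created:

    config = {"SubPipelines": {}}  # "DocPreprocessor" entry missing
    doc_preprocessor_config = config.get("SubPipelines", {}).get(
        "DocPreprocessor",
        {"pipeline_config_error": "config error for doc_preprocessor_pipeline!"},
    )
    print(doc_preprocessor_config)
    # {'pipeline_config_error': 'config error for doc_preprocessor_pipeline!'}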
@@ -71,26 +84,42 @@ class OCRPipeline(BasePipeline):
                 textline_orientation_config
             )
 
-        text_det_config = config["SubModules"]["TextDetection"]
-        self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
-        self.text_det_limit_type = text_det_config.get("limit_type", "max")
-        self.text_det_thresh = text_det_config.get("thresh", 0.3)
-        self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
-        self.text_det_max_candidates = text_det_config.get("max_candidates", 1000)
-        self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
-        self.text_det_use_dilation = text_det_config.get("use_dilation", False)
+        text_det_config = config.get("SubModules", {}).get(
+            "TextDetection", {"model_config_error": "config error for text_det_model!"}
+        )
+        self.text_type = config["text_type"]
+        if self.text_type == "general":
+            self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
+            self.text_det_limit_type = text_det_config.get("limit_type", "max")
+            self.text_det_thresh = text_det_config.get("thresh", 0.3)
+            self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
+            self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
+            self._sort_boxes = SortQuadBoxes()
+            self._crop_by_polys = CropByPolys(det_box_type="quad")
+        elif self.text_type == "seal":
+            self.text_det_limit_side_len = text_det_config.get("limit_side_len", 736)
+            self.text_det_limit_type = text_det_config.get("limit_type", "min")
+            self.text_det_thresh = text_det_config.get("thresh", 0.2)
+            self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
+            self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 0.5)
+            self._sort_boxes = SortPolyBoxes()
+            self._crop_by_polys = CropByPolys(det_box_type="poly")
+        else:
+            raise ValueError("Unsupported text type {}".format(self.text_type))
+
         self.text_det_model = self.create_model(
             text_det_config,
             limit_side_len=self.text_det_limit_side_len,
             limit_type=self.text_det_limit_type,
             thresh=self.text_det_thresh,
             box_thresh=self.text_det_box_thresh,
-            max_candidates=self.text_det_max_candidates,
             unclip_ratio=self.text_det_unclip_ratio,
-            use_dilation=self.text_det_use_dilation,
         )
 
-        text_rec_config = config["SubModules"]["TextRecognition"]
+        text_rec_config = config.get("SubModules", {}).get(
+            "TextRecognition",
+            {"model_config_error": "config error for text_rec_model!"},
+        )
         # TODO: add batch_size
         # batch_size = text_rec_config.get("batch_size", 1)
         # self.text_rec_model = self.create_model(text_rec_config,
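The detection defaults now chosen per text_type, summarized side by side (values copied from the hunk above; the dict is only a summary, not code from the patch):

    TEXT_DET_DEFAULTS = {
        # text_type: (limit_side_len, limit_type, thresh, box_thresh, unclip_ratio)
        "general": (960, "max", 0.3, 0.6, 2.0),
        "seal": (736, "min", 0.2, 0.6, 0.5),
    }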
@@ -98,16 +127,6 @@ class OCRPipeline(BasePipeline):
         self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
         self.text_rec_model = self.create_model(text_rec_config)
 
-        self.text_type = config["text_type"]
-        if self.text_type == "general":
-            self._sort_boxes = SortQuadBoxes()
-            self._crop_by_polys = CropByPolys(det_box_type="quad")
-        elif self.text_type == "seal":
-            self._sort_boxes = SortPolyBoxes()
-            self._crop_by_polys = CropByPolys(det_box_type="poly")
-        else:
-            raise ValueError("Unsupported text type {}".format(self.text_type))
-
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
@@ -175,36 +194,6 @@ class OCRPipeline(BasePipeline):
 
         return True
 
-    def predict_doc_preprocessor_res(
-        self, image_array: np.ndarray, model_settings: dict
-    ) -> tuple[DocPreprocessorResult, np.ndarray]:
-        """
-        Preprocess the document image based on input parameters.
-
-        Args:
-            image_array (np.ndarray): The input image array.
-            model_settings (dict): Dictionary containing preprocessing parameters.
-
-        Returns:
-            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
-            result dictionary and the processed image array.
-        """
-        if model_settings["use_doc_preprocessor"]:
-            use_doc_orientation_classify = model_settings[
-                "use_doc_orientation_classify"
-            ]
-            use_doc_unwarping = model_settings["use_doc_unwarping"]
-            doc_preprocessor_res = next(
-                self.doc_preprocessor_pipeline(
-                    image_array,
-                    use_doc_orientation_classify=use_doc_orientation_classify,
-                    use_doc_unwarping=use_doc_unwarping,
-                )
-            )
-        else:
-            doc_preprocessor_res = {"output_img": image_array}
-        return doc_preprocessor_res
-
     def get_model_settings(
         self,
         use_doc_orientation_classify: Optional[bool],
@@ -222,15 +211,15 @@ class OCRPipeline(BasePipeline):
         Returns:
             dict: A dictionary containing the model settings.
         """
-        if use_doc_orientation_classify is None:
-            use_doc_orientation_classify = self.use_doc_orientation_classify
-        if use_doc_unwarping is None:
-            use_doc_unwarping = self.use_doc_unwarping
+        if use_doc_orientation_classify is None and use_doc_unwarping is None:
+            use_doc_preprocessor = self.use_doc_preprocessor
+        else:
+            use_doc_preprocessor = True
+
         if use_textline_orientation is None:
             use_textline_orientation = self.use_textline_orientation
         return dict(
-            use_doc_orientation_classify=use_doc_orientation_classify,
-            use_doc_unwarping=use_doc_unwarping,
+            use_doc_preprocessor=use_doc_preprocessor,
             use_textline_orientation=use_textline_orientation,
        )
 
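A standalone restatement of the rewritten branch (the function and argument names here are illustrative), which makes the edge case explicit: passing either flag explicitly, even as False, forces the preprocessor on:

    def derive_use_doc_preprocessor(classify, unwarping, pipeline_default):
        # Both flags None -> fall back to the pipeline-level default;
        # any flag explicitly set -> always run the doc preprocessor.
        if classify is None and unwarping is None:
            return pipeline_default
        return True

    assert derive_use_doc_preprocessor(None, None, False) is False
    assert derive_use_doc_preprocessor(False, None, False) is True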
@@ -240,9 +229,7 @@ class OCRPipeline(BasePipeline):
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
         text_det_box_thresh: Optional[float] = None,
-        text_det_max_candidates: Optional[int] = None,
         text_det_unclip_ratio: Optional[float] = None,
-        text_det_use_dilation: Optional[bool] = None,
     ) -> dict:
         """
         Get text detection parameters.
@@ -254,9 +241,7 @@ class OCRPipeline(BasePipeline):
             text_det_limit_type (Optional[str]): The type of limit to apply to the text box.
             text_det_thresh (Optional[float]): The threshold for text detection.
             text_det_box_thresh (Optional[float]): The threshold for the bounding box.
-            text_det_max_candidates (Optional[int]): The maximum number of candidate text boxes.
             text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.
-            text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
 
         Returns:
             dict: A dictionary containing the text detection parameters.
@@ -269,20 +254,14 @@ class OCRPipeline(BasePipeline):
             text_det_thresh = self.text_det_thresh
         if text_det_box_thresh is None:
             text_det_box_thresh = self.text_det_box_thresh
-        if text_det_max_candidates is None:
-            text_det_max_candidates = self.text_det_max_candidates
         if text_det_unclip_ratio is None:
             text_det_unclip_ratio = self.text_det_unclip_ratio
-        if text_det_use_dilation is None:
-            text_det_use_dilation = self.text_det_use_dilation
         return dict(
             limit_side_len=text_det_limit_side_len,
             limit_type=text_det_limit_type,
             thresh=text_det_thresh,
             box_thresh=text_det_box_thresh,
-            max_candidates=text_det_max_candidates,
             unclip_ratio=text_det_unclip_ratio,
-            use_dilation=text_det_use_dilation,
         )
 
     def predict(
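Assuming this method is exposed as get_text_det_params (its name is cut off by the hunk headers, so that is an assumption), the override precedence after this change is: per-call argument first, else the per-text_type default chosen in __init__. A hypothetical call:

    params = pipeline.get_text_det_params(text_det_thresh=0.4)
    # thresh comes from the call; limit_side_len, limit_type, box_thresh, and
    # unclip_ratio fall back to the defaults for self.text_type, and
    # max_candidates / use_dilation no longer appear in the returned dict.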
@@ -295,9 +274,7 @@ class OCRPipeline(BasePipeline):
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
         text_det_box_thresh: Optional[float] = None,
-        text_det_max_candidates: Optional[int] = None,
         text_det_unclip_ratio: Optional[float] = None,
-        text_det_use_dilation: Optional[bool] = None,
         text_rec_score_thresh: Optional[float] = None,
     ) -> OCRResult:
         """
@@ -312,9 +289,7 @@ class OCRPipeline(BasePipeline):
             text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
             text_det_thresh (Optional[float]): Threshold for text detection.
             text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
-            text_det_max_candidates (Optional[int]): Maximum number of text detection candidates.
             text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
-            text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
             text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
         Returns:
             OCRResult: Generator yielding OCR results for each input image.
@@ -323,13 +298,6 @@ class OCRPipeline(BasePipeline):
         model_settings = self.get_model_settings(
             use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
         )
-        if (
-            model_settings["use_doc_orientation_classify"]
-            or model_settings["use_doc_unwarping"]
-        ):
-            model_settings["use_doc_preprocessor"] = True
-        else:
-            model_settings["use_doc_preprocessor"] = False
 
         if not self.check_model_settings_valid(model_settings):
             yield {"error": "the input params for model settings are invalid!"}
@@ -339,9 +307,7 @@ class OCRPipeline(BasePipeline):
             text_det_limit_type,
             text_det_thresh,
             text_det_box_thresh,
-            text_det_max_candidates,
             text_det_unclip_ratio,
-            text_det_use_dilation,
         )
 
         if text_rec_score_thresh is None:
@@ -356,9 +322,16 @@ class OCRPipeline(BasePipeline):
 
             image_array = self.img_reader(batch_data)[0]
 
-            doc_preprocessor_res = self.predict_doc_preprocessor_res(
-                image_array, model_settings
-            )
+            if model_settings["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+            else:
+                doc_preprocessor_res = {"output_img": image_array}
 
             doc_preprocessor_image = doc_preprocessor_res["output_img"]
 
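A hypothetical call site for the inlined branch (the create_pipeline entry point, pipeline name, and image path are assumptions, not part of this diff):

    from paddlex import create_pipeline  # assumed entry point

    pipeline = create_pipeline(pipeline="OCR")
    for res in pipeline.predict("doc_page.png", use_doc_unwarping=True):
        # The explicit flag makes get_model_settings return
        # use_doc_preprocessor=True, so the branch above runs the
        # doc-preprocessor sub-pipeline before detection.
        print(res["rec_texts"], res["rec_scores"])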
@@ -372,17 +345,18 @@ class OCRPipeline(BasePipeline):
             dt_polys = self._sort_boxes(dt_polys)
 
             single_img_res = {
-                "input_path": batch_data[0],
+                "input_path": input_path,
                 "doc_preprocessor_res": doc_preprocessor_res,
                 "dt_polys": dt_polys,
                 "model_settings": model_settings,
                 "text_det_params": text_det_params,
                 "text_type": self.text_type,
+                "text_rec_score_thresh": text_rec_score_thresh,
             }
 
             single_img_res["rec_texts"] = []
             single_img_res["rec_scores"] = []
-            single_img_res["rec_boxes"] = []
+            single_img_res["rec_polys"] = []
             if len(dt_polys) > 0:
                 all_subs_of_img = list(
                     self._crop_by_polys(doc_preprocessor_image, dt_polys)
@@ -404,5 +378,8 @@ class OCRPipeline(BasePipeline):
                     if rec_res["rec_score"] >= text_rec_score_thresh:
                         single_img_res["rec_texts"].append(rec_res["rec_text"])
                         single_img_res["rec_scores"].append(rec_res["rec_score"])
-                        single_img_res["rec_boxes"].append(dt_polys[rno])
+                        single_img_res["rec_polys"].append(dt_polys[rno])
+
+            rec_boxes = convert_points_to_boxes(single_img_res["rec_polys"])
+            single_img_res["rec_boxes"] = rec_boxes
             yield OCRResult(single_img_res)
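With rec_polys now holding the detector's quads or polygons, rec_boxes is derived from them once per image. A minimal sketch of what convert_points_to_boxes is assumed to do here (the real implementation lives in ..components and may differ):

    import numpy as np

    def convert_points_to_boxes_sketch(polys):
        # Collapse each (num_points, 2) polygon into [x_min, y_min, x_max, y_max].
        if len(polys) == 0:
            return np.zeros((0, 4), dtype=np.int64)
        return np.array(
            [np.concatenate([np.min(p, axis=0), np.max(p, axis=0)]) for p in polys]
        )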