|
|
@@ -29,6 +29,7 @@ from ..components import (
|
|
|
CropByPolys,
|
|
|
SortPolyBoxes,
|
|
|
SortQuadBoxes,
|
|
|
+ cal_ocr_word_box,
|
|
|
convert_points_to_boxes,
|
|
|
rotate_image,
|
|
|
)
|
|
|
@@ -129,11 +130,11 @@ class _OCRPipeline(BasePipeline):
|
|
|
{"model_config_error": "config error for text_rec_model!"},
|
|
|
)
|
|
|
self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
|
|
|
+ self.return_word_box = text_rec_config.get("return_word_box", False)
|
|
|
self.input_shape = text_rec_config.get("input_shape", None)
|
|
|
self.text_rec_model = self.create_model(
|
|
|
text_rec_config, input_shape=self.input_shape
|
|
|
)
|
|
|
-
|
|
|
self.batch_sampler = ImageBatchSampler(batch_size=config.get("batch_size", 1))
|
|
|
self.img_reader = ReadImage(format="BGR")
|
|
|
|
|
|
@@ -292,6 +293,7 @@ class _OCRPipeline(BasePipeline):
|
|
|
text_det_box_thresh: Optional[float] = None,
|
|
|
text_det_unclip_ratio: Optional[float] = None,
|
|
|
text_rec_score_thresh: Optional[float] = None,
|
|
|
+ return_word_box: Optional[bool] = None,
|
|
|
) -> OCRResult:
|
|
|
"""
|
|
|
Predict OCR results based on input images or arrays with optional preprocessing steps.
|
|
|
@@ -308,6 +310,7 @@ class _OCRPipeline(BasePipeline):
|
|
|
text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
|
|
|
text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
|
|
|
text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
|
|
|
+ return_word_box (Optional[bool]): Whether to return word boxes along with recognized texts.
|
|
|
Returns:
|
|
|
OCRResult: Generator yielding OCR results for each input image.
|
|
|
"""
|
|
|
@@ -330,6 +333,8 @@ class _OCRPipeline(BasePipeline):
|
|
|
|
|
|
if text_rec_score_thresh is None:
|
|
|
text_rec_score_thresh = self.text_rec_score_thresh
|
|
|
+ if return_word_box is None:
|
|
|
+ return_word_box = self.return_word_box
|
|
|
|
|
|
for _, batch_data in enumerate(self.batch_sampler(input)):
|
|
|
image_arrays = self.img_reader(batch_data.instances)
|
|
|
@@ -367,6 +372,7 @@ class _OCRPipeline(BasePipeline):
|
|
|
"text_det_params": text_det_params,
|
|
|
"text_type": self.text_type,
|
|
|
"text_rec_score_thresh": text_rec_score_thresh,
|
|
|
+ "return_word_box": return_word_box,
|
|
|
"rec_texts": [],
|
|
|
"rec_scores": [],
|
|
|
"rec_polys": [],
|
|
|
@@ -433,22 +439,41 @@ class _OCRPipeline(BasePipeline):
|
|
|
all_subs_of_img[x["sub_img_id"]] for x in sorted_subs_info
|
|
|
]
|
|
|
for i, rec_res in enumerate(
|
|
|
- self.text_rec_model(sorted_subs_of_img)
|
|
|
+ self.text_rec_model(
|
|
|
+ sorted_subs_of_img, return_word_box=return_word_box
|
|
|
+ )
|
|
|
):
|
|
|
sub_img_id = sorted_subs_info[i]["sub_img_id"]
|
|
|
sub_img_info_list[sub_img_id]["rec_res"] = rec_res
|
|
|
+ if return_word_box:
|
|
|
+ res["text_word"] = []
|
|
|
+ res["text_word_region"] = []
|
|
|
for sno in range(len(sub_img_info_list)):
|
|
|
rec_res = sub_img_info_list[sno]["rec_res"]
|
|
|
if rec_res["rec_score"] >= text_rec_score_thresh:
|
|
|
- res["rec_texts"].append(rec_res["rec_text"])
|
|
|
+ if return_word_box:
|
|
|
+ word_box_content_list, word_box_list = cal_ocr_word_box(
|
|
|
+ rec_res["rec_text"][0],
|
|
|
+ dt_polys[sno],
|
|
|
+ rec_res["rec_text"][1],
|
|
|
+ )
|
|
|
+ res["text_word"].append(word_box_content_list)
|
|
|
+ res["text_word_region"].append(word_box_list)
|
|
|
+ res["rec_texts"].append(rec_res["rec_text"][0])
|
|
|
+ else:
|
|
|
+ res["rec_texts"].append(rec_res["rec_text"])
|
|
|
res["rec_scores"].append(rec_res["rec_score"])
|
|
|
res["vis_fonts"].append(rec_res["vis_font"])
|
|
|
res["rec_polys"].append(dt_polys[sno])
|
|
|
-
|
|
|
for res in results:
|
|
|
if self.text_type == "general":
|
|
|
rec_boxes = convert_points_to_boxes(res["rec_polys"])
|
|
|
res["rec_boxes"] = rec_boxes
|
|
|
+ if return_word_box:
|
|
|
+ res["text_word_boxes"] = [
|
|
|
+ convert_points_to_boxes(line)
|
|
|
+ for line in res["text_word_region"]
|
|
|
+ ]
|
|
|
else:
|
|
|
res["rec_boxes"] = np.array([])
|
|
|
|