zhangyubo0722 3 месяца назад
Родитель
Commit
7449b8fc63

+ 31 - 3
paddlex/inference/models/text_recognition/predictor.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
+
 from ....modules.text_recognition.model_list import MODELS
 from ....utils.fonts import (
     ARABIC_FONT,
@@ -39,9 +41,10 @@ class TextRecPredictor(BasePredictor):
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def __init__(self, *args, input_shape=None, **kwargs):
+    def __init__(self, *args, input_shape=None, return_word_box=False, **kwargs):
         super().__init__(*args, **kwargs)
         self.input_shape = input_shape
+        self.return_word_box = return_word_box
         self.vis_font = self.get_vis_font()
         self.pre_tfs, self.infer, self.post_op = self._build()
 
@@ -68,12 +71,37 @@ class TextRecPredictor(BasePredictor):
         post_op = self.build_postprocess(**self.config["PostProcess"])
         return pre_tfs, infer, post_op
 
-    def process(self, batch_data):
+    def process(self, batch_data, return_word_box=False):
         batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data.instances)
+        width_list = []
+        for img in batch_raw_imgs:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        indices = np.argsort(np.array(width_list))
         batch_imgs = self.pre_tfs["ReisizeNorm"](imgs=batch_raw_imgs)
         x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
         batch_preds = self.infer(x=x)
-        texts, scores = self.post_op(batch_preds)
+        batch_num = self.batch_sampler.batch_size
+        img_num = len(batch_raw_imgs)
+        rec_image_shape = next(
+            op["RecResizeImg"]["image_shape"]
+            for op in self.config["PreProcess"]["transform_ops"]
+            if "RecResizeImg" in op
+        )
+        imgC, imgH, imgW = rec_image_shape[:3]
+        max_wh_ratio = imgW / imgH
+        end_img_no = min(img_num, batch_num)
+        wh_ratio_list = []
+        for ino in range(0, end_img_no):
+            h, w = batch_raw_imgs[indices[ino]].shape[0:2]
+            wh_ratio = w * 1.0 / h
+            max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            wh_ratio_list.append(wh_ratio)
+        texts, scores = self.post_op(
+            batch_preds,
+            return_word_box=return_word_box or self.return_word_box,
+            wh_ratio_list=wh_ratio_list,
+            max_wh_ratio=max_wh_ratio,
+        )
         return {
             "input_path": batch_data.input_paths,
             "page_index": batch_data.page_indexes,

+ 101 - 5
paddlex/inference/models/text_recognition/processors.py

@@ -129,7 +129,76 @@ class BaseRecLabelDecode:
         """add_special_char"""
         return character_list
 
-    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+    def get_word_info(self, text, selection):
+        """
+        Split decoded text into runs of same-kind characters and track their columns.
+
+        Each character of ``text`` is classified as ``"cn"`` (code point in
+        U+4E00..U+9FFF), ``"en&num"`` (matches ``[a-zA-Z0-9]``) or ``"symbol"``
+        (everything else). Two overrides keep tokens such as ``1.23`` and
+        ``VGG-16`` together: while the current run is ``"en&num"``, a ``"."``
+        directly followed by a digit and any ``"-"`` are treated as ``"en&num"``.
+        A new run starts whenever the classification changes, so symbol
+        characters (spaces, brackets, ...) form runs of their own with state
+        ``"symbol"``.
+
+        Args:
+            text: the decoded text.
+            selection: array compared element-wise against ``True``; the i-th
+                matching index is recorded as the column of the i-th character
+                of ``text`` (presumably the mask of kept CTC columns - confirm
+                against the caller).
+        Returns:
+            word_list: list of runs, each a list of characters.
+            word_col_list: list of runs, each a list of int column indices
+                aligned with the characters in ``word_list``.
+            state_list: the state (``"cn"``, ``"en&num"`` or ``"symbol"``) of
+                each run.
+        """
+        # Column index of every position flagged in `selection`.
+        decoded_cols = np.where(selection == True)[0]  # noqa: E712 (element-wise)
+
+        word_list, word_col_list, state_list = [], [], []
+        run_chars, run_cols = [], []
+        run_state = None
+
+        for pos, ch in enumerate(text):
+            if "\u4e00" <= ch <= "\u9fff":
+                ch_state = "cn"
+            elif re.search("[a-zA-Z0-9]", ch):
+                ch_state = "en&num"
+            else:
+                ch_state = "symbol"
+
+            # Inside an alphanumeric run, hyphens and decimal points do not
+            # break the token.
+            if run_state == "en&num" and (
+                ch == "-"
+                or (
+                    ch == "."
+                    and pos + 1 < len(text)
+                    and re.search("[0-9]", text[pos + 1])
+                )
+            ):
+                ch_state = "en&num"
+
+            if run_state is None:
+                # The very first character opens the first run.
+                run_state = ch_state
+            elif run_state != ch_state:
+                # Classification changed: close the current run, open a new one.
+                if run_chars:
+                    word_list.append(run_chars)
+                    word_col_list.append(run_cols)
+                    state_list.append(run_state)
+                    run_chars, run_cols = [], []
+                run_state = ch_state
+
+            run_chars.append(ch)
+            run_cols.append(int(decoded_cols[pos]))
+
+        # Flush the trailing run, if any.
+        if run_chars:
+            word_list.append(run_chars)
+            word_col_list.append(run_cols)
+            state_list.append(run_state)
+
+        return word_list, word_col_list, state_list
+
+
+    def decode(
+        self,
+        text_index,
+        text_prob=None,
+        is_remove_duplicate=False,
+        return_word_box=False,
+    ):
         """convert text-index into text-label."""
         result_list = []
         ignored_tokens = self.get_ignored_tokens()
@@ -156,7 +225,24 @@ class BaseRecLabelDecode:
             if self.reverse:  # for arabic rec
                 text = self.pred_reverse(text)
 
-            result_list.append((text, np.mean(conf_list).tolist()))
+            if return_word_box:
+                word_list, word_col_list, state_list = self.get_word_info(
+                    text, selection
+                )
+                result_list.append(
+                    (
+                        text,
+                        np.mean(conf_list).tolist(),
+                        [
+                            len(text_index[batch_idx]),
+                            word_list,
+                            word_col_list,
+                            state_list,
+                        ],
+                    )
+                )
+            else:
+                result_list.append((text, np.mean(conf_list).tolist()))
         return result_list
 
     def get_ignored_tokens(self):
@@ -186,16 +272,26 @@ class CTCLabelDecode(BaseRecLabelDecode):
     def __init__(self, character_list=None, use_space_char=True):
         super().__init__(character_list, use_space_char=use_space_char)
 
-    def __call__(self, pred):
+    def __call__(self, pred, return_word_box=False, **kwargs):
         """apply"""
         preds = np.array(pred[0])
         preds_idx = preds.argmax(axis=-1)
         preds_prob = preds.max(axis=-1)
-        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        text = self.decode(
+            preds_idx,
+            preds_prob,
+            is_remove_duplicate=True,
+            return_word_box=return_word_box,
+        )
+        if return_word_box:
+            for rec_idx, rec in enumerate(text):
+                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
+                max_wh_ratio = kwargs["max_wh_ratio"]
+                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
         texts = []
         scores = []
         for t in text:
-            texts.append(t[0])
+            texts.append(t[0] if len(t) <= 2 else (t[0], t[2]))
             scores.append(t[1])
         return texts, scores
 

+ 1 - 0
paddlex/inference/pipelines/components/__init__.py

@@ -20,6 +20,7 @@ from .common import (
     CVResult,
     SortPolyBoxes,
     SortQuadBoxes,
+    cal_ocr_word_box,
     convert_points_to_boxes,
     rotate_image,
 )

+ 1 - 0
paddlex/inference/pipelines/components/common/__init__.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .base_result import BaseResult, CVResult
+from .cal_ocr_word_box import cal_ocr_word_box
 from .convert_points_and_boxes import convert_points_to_boxes
 from .crop_image_regions import CropByBoxes, CropByPolys
 from .sort_boxes import SortPolyBoxes, SortQuadBoxes

+ 104 - 0
paddlex/inference/pipelines/components/common/cal_ocr_word_box.py

@@ -0,0 +1,104 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["cal_ocr_word_box"]
+
+import numpy as np
+
+# from .convert_points_and_boxes import convert_points_to_boxes
+
+
+def cal_ocr_word_box(rec_str, box, rec_word_info):
+    """Derive per-word boxes for one text line from its detection box and recognition info.
+
+    The line's horizontal extent is divided into ``col_num`` equal cells.
+    Every non-``"cn"`` group yields one box spanning its first to last column.
+    Every character of a ``"cn"`` group yields its own box, centred on its
+    column and as wide as the average cn character width. All boxes share the
+    line's top and bottom y coordinates.
+
+    Args:
+        rec_str: recognized text of the line; only its length is used, as a
+            fallback for the average character width.
+        box: object exposing ``tolist()`` that gives the line's corner points
+            (assumed ordered top-left, top-right, bottom-right, bottom-left -
+            TODO confirm against the detector output).
+        rec_word_info: ``(col_num, word_list, word_col_list, state_list)``.
+
+    Returns:
+        tuple: ``(word_box_content_list, word_box_list)``. Contents are in
+        group order (joined string per non-cn group, single characters for cn
+        groups); boxes are 4-point tuples re-ordered by ``sort_boxes``.
+    """
+    col_num, word_list, word_col_list, state_list = rec_word_info
+    corners = box.tolist()
+    x_min, y_min = corners[0][0], corners[0][1]
+    x_max = corners[1][0]
+    y_max = corners[2][1]
+    line_width = x_max - x_min
+
+    cell_width = line_width / col_num
+
+    def _rect(left, right):
+        # Axis-aligned quad spanning the full height of the text line.
+        return ((left, y_min), (right, y_min), (right, y_max), (left, y_max))
+
+    word_box_list = []
+    word_box_content_list = []
+    cn_width_list = []
+    cn_col_list = []
+    for chars, cols, kind in zip(word_list, word_col_list, state_list):
+        if kind != "cn":
+            left = x_min + int(cols[0] * cell_width)
+            right = x_min + int((cols[-1] + 1) * cell_width)
+            word_box_list.append(_rect(left, right))
+            word_box_content_list.append("".join(chars))
+            continue
+        # cn groups: boxes are emitted per character after the loop, once the
+        # average character width is known.
+        if len(cols) != 1:
+            span = (cols[-1] - cols[0] + 1) * cell_width
+            cn_width_list.append(span / (len(cols) - 1))
+        cn_col_list += cols
+        word_box_content_list += chars
+
+    if cn_col_list:
+        # Fall back to an even split of the line when no multi-character cn
+        # group provided a width estimate.
+        avg_char_width = (
+            np.mean(cn_width_list) if cn_width_list else line_width / len(rec_str)
+        )
+        for col in cn_col_list:
+            center_x = (col + 0.5) * cell_width
+            left = max(int(center_x - avg_char_width / 2), 0) + x_min
+            right = min(int(center_x + avg_char_width / 2), line_width) + x_min
+            word_box_list.append(_rect(left, right))
+
+    word_box_list = sort_boxes(word_box_list, y_thresh=12)
+    return word_box_content_list, word_box_list
+
+
+def sort_boxes(boxes, y_thresh=10):
+    """Order boxes line by line (top to bottom), left to right within a line.
+
+    Boxes are sorted by the y coordinate of their centroid. Consecutive boxes
+    whose centroid y differs from the previous one by less than ``y_thresh``
+    are put on the same line. Each line is then sorted by centroid x.
+
+    Args:
+        boxes: iterable of boxes, each a sequence of (x, y) points.
+        y_thresh: maximum centroid-y gap between neighbours on the same line.
+
+    Returns:
+        list: the input boxes in reading order.
+    """
+    # Pair every box with its centroid and order the pairs vertically.
+    by_row = sorted(
+        ((box, np.mean(box, axis=0)) for box in boxes),
+        key=lambda pair: pair[1][1],
+    )
+
+    rows = []
+    row = []
+    prev_y = None
+    for pair in by_row:
+        cy = pair[1][1]
+        # A gap of at least y_thresh from the previous box starts a new row.
+        if prev_y is not None and not (abs(cy - prev_y) < y_thresh):
+            rows.append(row)
+            row = []
+        row.append(pair)
+        prev_y = cy
+    if row:
+        rows.append(row)
+
+    ordered = []
+    for row in rows:
+        ordered += [box for box, _ in sorted(row, key=lambda pair: pair[1][0])]
+
+    return ordered

+ 29 - 4
paddlex/inference/pipelines/ocr/pipeline.py

@@ -29,6 +29,7 @@ from ..components import (
     CropByPolys,
     SortPolyBoxes,
     SortQuadBoxes,
+    cal_ocr_word_box,
     convert_points_to_boxes,
     rotate_image,
 )
@@ -129,11 +130,11 @@ class _OCRPipeline(BasePipeline):
             {"model_config_error": "config error for text_rec_model!"},
         )
         self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
+        self.return_word_box = text_rec_config.get("return_word_box", False)
         self.input_shape = text_rec_config.get("input_shape", None)
         self.text_rec_model = self.create_model(
             text_rec_config, input_shape=self.input_shape
         )
-
         self.batch_sampler = ImageBatchSampler(batch_size=config.get("batch_size", 1))
         self.img_reader = ReadImage(format="BGR")
 
@@ -292,6 +293,7 @@ class _OCRPipeline(BasePipeline):
         text_det_box_thresh: Optional[float] = None,
         text_det_unclip_ratio: Optional[float] = None,
         text_rec_score_thresh: Optional[float] = None,
+        return_word_box: Optional[bool] = None,
     ) -> OCRResult:
         """
         Predict OCR results based on input images or arrays with optional preprocessing steps.
@@ -308,6 +310,7 @@ class _OCRPipeline(BasePipeline):
             text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
             text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
             text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
+            return_word_box (Optional[bool]): Whether to return word boxes along with recognized texts.
         Returns:
             OCRResult: Generator yielding OCR results for each input image.
         """
@@ -330,6 +333,8 @@ class _OCRPipeline(BasePipeline):
 
         if text_rec_score_thresh is None:
             text_rec_score_thresh = self.text_rec_score_thresh
+        if return_word_box is None:
+            return_word_box = self.return_word_box
 
         for _, batch_data in enumerate(self.batch_sampler(input)):
             image_arrays = self.img_reader(batch_data.instances)
@@ -367,6 +372,7 @@ class _OCRPipeline(BasePipeline):
                     "text_det_params": text_det_params,
                     "text_type": self.text_type,
                     "text_rec_score_thresh": text_rec_score_thresh,
+                    "return_word_box": return_word_box,
                     "rec_texts": [],
                     "rec_scores": [],
                     "rec_polys": [],
@@ -433,22 +439,41 @@ class _OCRPipeline(BasePipeline):
                         all_subs_of_img[x["sub_img_id"]] for x in sorted_subs_info
                     ]
                     for i, rec_res in enumerate(
-                        self.text_rec_model(sorted_subs_of_img)
+                        self.text_rec_model(
+                            sorted_subs_of_img, return_word_box=return_word_box
+                        )
                     ):
                         sub_img_id = sorted_subs_info[i]["sub_img_id"]
                         sub_img_info_list[sub_img_id]["rec_res"] = rec_res
+                    if return_word_box:
+                        res["text_word"] = []
+                        res["text_word_region"] = []
                     for sno in range(len(sub_img_info_list)):
                         rec_res = sub_img_info_list[sno]["rec_res"]
                         if rec_res["rec_score"] >= text_rec_score_thresh:
-                            res["rec_texts"].append(rec_res["rec_text"])
+                            if return_word_box:
+                                word_box_content_list, word_box_list = cal_ocr_word_box(
+                                    rec_res["rec_text"][0],
+                                    dt_polys[sno],
+                                    rec_res["rec_text"][1],
+                                )
+                                res["text_word"].append(word_box_content_list)
+                                res["text_word_region"].append(word_box_list)
+                                res["rec_texts"].append(rec_res["rec_text"][0])
+                            else:
+                                res["rec_texts"].append(rec_res["rec_text"])
                             res["rec_scores"].append(rec_res["rec_score"])
                             res["vis_fonts"].append(rec_res["vis_font"])
                             res["rec_polys"].append(dt_polys[sno])
-
             for res in results:
                 if self.text_type == "general":
                     rec_boxes = convert_points_to_boxes(res["rec_polys"])
                     res["rec_boxes"] = rec_boxes
+                    if return_word_box:
+                        res["text_word_boxes"] = [
+                            convert_points_to_boxes(line)
+                            for line in res["text_word_region"]
+                        ]
                 else:
                     res["rec_boxes"] = np.array([])
 

+ 39 - 3
paddlex/inference/pipelines/ocr/result.py

@@ -72,8 +72,35 @@ class OCRResult(BaseCVResult):
         Returns:
             Dict[Image.Image]: A dictionary containing two images: 'doc_preprocessor_res' and 'ocr_res_img'.
         """
-        boxes = self["rec_polys"]
-        txts = self["rec_texts"]
+
+        if "text_word_region" in self:
+            boxes = []
+            txts = []
+            text_word_region = [
+                item for sublist in self["text_word_region"] for item in sublist
+            ]
+            text_word = [item for sublist in self["text_word"] for item in sublist]
+            for idx, word_region in enumerate(text_word_region):
+                char_box = word_region
+                box_height = int(
+                    math.sqrt(
+                        (char_box[0][0] - char_box[3][0]) ** 2
+                        + (char_box[0][1] - char_box[3][1]) ** 2
+                    )
+                )
+                box_width = int(
+                    math.sqrt(
+                        (char_box[0][0] - char_box[1][0]) ** 2
+                        + (char_box[0][1] - char_box[1][1]) ** 2
+                    )
+                )
+                if box_height == 0 or box_width == 0:
+                    continue
+                boxes.append(word_region)
+                txts.append(text_word[idx])
+        else:
+            boxes = self["rec_polys"]
+            txts = self["rec_texts"]
         image = self["doc_preprocessor_res"]["output_img"]
         h, w = image.shape[0:2]
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -102,7 +129,8 @@ class OCRResult(BaseCVResult):
                 else:
                     box_pts = [(int(x), int(y)) for x, y in box.tolist()]
                     draw_left.polygon(box_pts, fill=color)
-
+                if isinstance(txt, tuple):
+                    txt = txt[0]
                 img_right_text = draw_box_txt_fine((w, h), box, txt, vis_font.path)
                 pts = np.array(box, np.int32).reshape((-1, 1, 2))
                 cv2.polylines(img_right_text, [pts], True, color, 1)
@@ -149,6 +177,7 @@ class OCRResult(BaseCVResult):
                 self["textline_orientation_angles"]
             )
         data["text_rec_score_thresh"] = self["text_rec_score_thresh"]
+        data["return_word_box"] = self["return_word_box"]
         data["rec_texts"] = self["rec_texts"]
         data["rec_scores"] = np.array(self["rec_scores"])
         data["rec_polys"] = (
@@ -157,6 +186,9 @@ class OCRResult(BaseCVResult):
             else np.array(self["rec_polys"])
         )
         data["rec_boxes"] = np.array(self["rec_boxes"])
+        if "text_word_boxes" in self:
+            data["text_word_boxes"] = self["text_word_boxes"]
+            data["text_word"] = self["text_word"]
 
         return JsonMixin._to_str(data, *args, **kwargs)
 
@@ -183,10 +215,14 @@ class OCRResult(BaseCVResult):
         if "textline_orientation_angles" in self:
             data["textline_orientation_angles"] = self["textline_orientation_angles"]
         data["text_rec_score_thresh"] = self["text_rec_score_thresh"]
+        data["return_word_box"] = self["return_word_box"]
         data["rec_texts"] = self["rec_texts"]
         data["rec_scores"] = self["rec_scores"]
         data["rec_polys"] = self["rec_polys"]
         data["rec_boxes"] = self["rec_boxes"]
+        if "text_word_boxes" in self:
+            data["text_word_boxes"] = self["text_word_boxes"]
+            data["text_word"] = self["text_word"]
         return JsonMixin._to_json(data, *args, **kwargs)
 
 

+ 5 - 0
paddlex/utils/pipeline_arguments.py

@@ -89,6 +89,11 @@ PIPELINE_ARGUMENTS = {
             "type": float,
             "help": "Sets the score threshold for text recognition.",
         },
+        {
+            "name": "--return_word_box",
+            "type": bool,
+            "help": "Determines whether to return word box",
+        },
     ],
     "object_detection": [
         {