| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import re
- from typing import List
- import numpy as np
- from ....utils.deps import class_requires_deps, is_dep_available
- from ...utils.benchmark import benchmark
- if is_dep_available("opencv-contrib-python"):
- import cv2
@benchmark.timeit
@class_requires_deps("opencv-contrib-python")
class OCRReisizeNormImg:
    """Resize and normalize images for OCR text recognition.

    Every output image is converted from HWC to CHW layout, scaled from
    ``[0, 255]`` to ``[-1, 1]`` and returned as a float32 array.

    Args:
        rec_image_shape: ``(C, H, W)`` used by the aspect-ratio-preserving
            path. Defaults to ``[3, 48, 320]``.
        input_shape: Optional fixed ``(C, H, W)``. When given, every image is
            resized to exactly ``(H, W)`` instead (aspect ratio not kept).
    """

    def __init__(self, rec_image_shape=None, input_shape=None):
        super().__init__()
        # Build a fresh list per instance instead of sharing a mutable default.
        if rec_image_shape is None:
            rec_image_shape = [3, 48, 320]
        self.rec_image_shape = rec_image_shape
        self.input_shape = input_shape
        # Upper bound on the width produced by ``resize_norm_img``.
        self.max_imgW = 3200

    @staticmethod
    def _to_normalized_chw(image):
        """Return HWC ``image`` as a float32 CHW array scaled to ``[-1, 1]``."""
        # Cast first so the division happens in float32 (a uint8 input divided
        # by 255 would otherwise silently produce float64).
        chw = image.astype("float32").transpose((2, 0, 1)) / 255
        chw -= 0.5
        chw /= 0.5
        return chw

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize ``img`` to the target height and right-pad it with zeros.

        Args:
            img: HWC image whose channel count must equal
                ``rec_image_shape[0]``.
            max_wh_ratio: Width/height ratio that determines the padded width
                (``int(H * max_wh_ratio)``, capped at ``self.max_imgW``).

        Returns:
            np.ndarray: float32 array of shape ``(C, H, padded_width)``; the
            columns to the right of the resized image are zero.
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        imgW = int(imgH * max_wh_ratio)
        if imgW > self.max_imgW:
            # Extremely wide inputs are squeezed into the maximum width.
            resized_image = cv2.resize(img, (self.max_imgW, imgH))
            resized_w = self.max_imgW
            imgW = self.max_imgW
        else:
            h, w = img.shape[:2]
            ratio = w / float(h)
            # Keep the image's own aspect ratio, but never exceed the padded width.
            resized_w = min(imgW, int(math.ceil(imgH * ratio)))
            resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = self._to_normalized_chw(resized_image)
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, imgs):
        """Resize and normalize every image in ``imgs``; returns a list."""
        if self.input_shape is None:
            return [self.resize(img) for img in imgs]
        return [self.staticResize(img) for img in imgs]

    def resize(self, img):
        """Aspect-preserving resize; width ratio is at least ``W / H`` of the target."""
        _, imgH, imgW = self.rec_image_shape
        h, w = img.shape[:2]
        max_wh_ratio = max(imgW / imgH, w * 1.0 / h)
        return self.resize_norm_img(img, max_wh_ratio)

    def staticResize(self, img):
        """Resize ``img`` to exactly ``input_shape`` and normalize it (float32 CHW)."""
        _, imgH, imgW = self.input_shape
        resized_image = cv2.resize(img, (int(imgW), int(imgH)))
        return self._to_normalized_chw(resized_image)
@benchmark.timeit
class BaseRecLabelDecode:
    """Convert between text labels and text indices.

    Args:
        character_str: Iterable whose elements form the character dictionary
            (index order = element order). Defaults to ``0-9a-z``.
        use_space_char: If true, ``" "`` is appended to the dictionary.
    """

    # Characters that stay in left-to-right runs when reversing a prediction.
    _LTR_CHAR_PATTERN = re.compile("[a-zA-Z0-9 :*./%+-]")
    # Characters grouped as "en&num" words by ``get_word_info``.
    _EN_NUM_PATTERN = re.compile("[a-zA-Z0-9]")
    _DIGIT_PATTERN = re.compile("[0-9]")

    def __init__(self, character_str=None, use_space_char=True):
        super().__init__()
        self.reverse = False
        if character_str is None:
            character_list = list("0123456789abcdefghijklmnopqrstuvwxyz")
        else:
            character_list = list(character_str)
        if use_space_char:
            character_list.append(" ")
        # Subclasses may prepend/append special tokens (e.g. the CTC blank).
        character_list = self.add_special_char(character_list)
        # For duplicate characters the last index wins.
        self.dict = {char: i for i, char in enumerate(character_list)}
        self.character = character_list

    def pred_reverse(self, pred):
        """Reverse ``pred`` while keeping runs of Latin/digit/punctuation characters in order."""
        pred_re = []
        c_current = ""
        for c in pred:
            if self._LTR_CHAR_PATTERN.search(c) is None:
                if c_current != "":
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ""
            else:
                c_current += c
        if c_current != "":
            pred_re.append(c_current)
        return "".join(pred_re[::-1])

    def add_special_char(self, character_list):
        """Hook for subclasses to add special tokens; the base adds none."""
        return character_list

    def get_word_info(self, text, selection):
        """
        Group the decoded characters and record the corresponding decoded positions.

        Args:
            text: the decoded text
            selection: the bool array that identifies which columns of features are decoded as non-separated characters
        Returns:
            word_list: list of the grouped words
            word_col_list: list of decoding positions corresponding to each character in the grouped word
            state_list: list of marker to identify the type of grouping words, including two types of grouping words:
                - 'cn': continuous chinese characters (e.g., 你好啊)
                - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
                The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
        """
        state = None
        word_content = []
        word_col_content = []
        word_list = []
        word_col_list = []
        state_list = []
        # Column index of every selected (decoded) position.
        # NOTE(review): assumes one character per selected column and that
        # ``text`` was not reordered by ``pred_reverse`` — confirm for
        # multi-character dictionary entries / reverse=True.
        valid_col = np.flatnonzero(selection)
        for c_i, char in enumerate(text):
            if "\u4e00" <= char <= "\u9fff":
                c_state = "cn"
            elif self._EN_NUM_PATTERN.search(char) is not None:
                c_state = "en&num"
            else:
                c_state = "symbol"
            # A '.' followed by a digit continues a number (e.g. 1.123).
            if (
                char == "."
                and state == "en&num"
                and c_i + 1 < len(text)
                and self._DIGIT_PATTERN.search(text[c_i + 1]) is not None
            ):
                c_state = "en&num"
            # A '-' continues an en&num word (e.g. VGG-16).
            if char == "-" and state == "en&num":
                c_state = "en&num"
            if state is None:
                state = c_state
            if state != c_state:
                if len(word_content) != 0:
                    word_list.append(word_content)
                    word_col_list.append(word_col_content)
                    state_list.append(state)
                    word_content = []
                    word_col_content = []
                state = c_state
            word_content.append(char)
            word_col_content.append(int(valid_col[c_i]))
        if len(word_content) != 0:
            word_list.append(word_content)
            word_col_list.append(word_col_content)
            state_list.append(state)
        return word_list, word_col_list, state_list

    def decode(
        self,
        text_index,
        text_prob=None,
        is_remove_duplicate=False,
        return_word_box=False,
    ):
        """Convert text indices into text labels.

        Args:
            text_index: Per-sample sequences of class ids; each row must be a
                1-D numpy array (boolean masking is applied to it).
            text_prob: Optional per-sample confidences aligned with
                ``text_index``.
            is_remove_duplicate: Collapse consecutive repeated ids (CTC).
            return_word_box: Also return word-grouping information.

        Returns:
            list: ``(text, mean_conf)`` per sample, or
            ``(text, mean_conf, [num_columns, word_list, word_col_list,
            state_list])`` when ``return_word_box`` is true. ``mean_conf`` is
            0.0 when nothing was decoded.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            token_ids = text_index[batch_idx]
            selection = np.ones(len(token_ids), dtype=bool)
            if is_remove_duplicate:
                # Keep only the first id of every run of repeated ids.
                selection[1:] = token_ids[1:] != token_ids[:-1]
            for ignored_token in ignored_tokens:
                selection &= token_ids != ignored_token
            char_list = [self.character[text_id] for text_id in token_ids[selection]]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                # One unit confidence per decoded character (previously one
                # per column, which scored an empty decode as 1.0).
                conf_list = [1] * len(char_list)
            if len(conf_list) == 0:
                conf_list = [0]
            text = "".join(char_list)
            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)
            if return_word_box:
                word_list, word_col_list, state_list = self.get_word_info(
                    text, selection
                )
                result_list.append(
                    (
                        text,
                        np.mean(conf_list).tolist(),
                        [
                            len(token_ids),
                            word_list,
                            word_col_list,
                            state_list,
                        ],
                    )
                )
            else:
                result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """Return the class ids dropped during decoding (index 0, the CTC blank)."""
        return [0]  # for ctc blank

    def __call__(self, pred):
        """Greedy-decode ``pred`` and return ``(texts, scores)``.

        ``pred`` is either a ``(batch, steps, classes)`` array-like or a
        list/tuple of model outputs whose last element is such an array.
        """
        # Unwrap a container of model outputs BEFORE converting: after
        # ``np.array`` the container check can never succeed. Only unwrap
        # when the last element is already a 3-D output so that a list of
        # per-sample 2-D arrays is still stacked into a batch as before.
        if isinstance(pred, (tuple, list)) and len(pred) > 0 and np.ndim(pred[-1]) == 3:
            pred = pred[-1]
        preds = np.array(pred)
        preds_idx = preds.argmax(axis=-1)
        preds_prob = preds.max(axis=-1)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        texts = [t[0] for t in text]
        scores = [t[1] for t in text]
        return texts, scores
@benchmark.timeit
class CTCLabelDecode(BaseRecLabelDecode):
    """CTC decoder: greedy argmax per step, collapse repeats, drop the blank."""

    def __init__(self, character_list=None, use_space_char=True):
        super().__init__(character_list, use_space_char=use_space_char)

    def __call__(self, pred, return_word_box=False, **kwargs):
        """Decode the first model output in ``pred``.

        Args:
            pred: Sequence of model outputs; ``pred[0]`` holds the per-step
                class scores.
            return_word_box: When true, also return word-grouping info. The
                column count in that info is rescaled by
                ``kwargs["wh_ratio_list"][i] / kwargs["max_wh_ratio"]``.

        Returns:
            tuple: ``(texts, scores)``; each text is a string, or a
            ``(text, word_info)`` pair when ``return_word_box`` is true.
        """
        logits = np.array(pred[0])
        best_ids = logits.argmax(axis=-1)
        best_probs = logits.max(axis=-1)
        decoded = self.decode(
            best_ids,
            best_probs,
            is_remove_duplicate=True,
            return_word_box=return_word_box,
        )
        if return_word_box:
            for sample_idx, entry in enumerate(decoded):
                sample_ratio = kwargs["wh_ratio_list"][sample_idx]
                batch_ratio = kwargs["max_wh_ratio"]
                # Scale the column count from the padded batch width back to
                # this sample's own width.
                entry[2][0] = entry[2][0] * (sample_ratio / batch_ratio)
        texts = [entry[0] if len(entry) <= 2 else (entry[0], entry[2]) for entry in decoded]
        scores = [entry[1] for entry in decoded]
        return texts, scores

    def add_special_char(self, character_list):
        """Prepend the ``"blank"`` token so it occupies class index 0."""
        return ["blank"] + character_list
@benchmark.timeit
class ToBatch:
    """Right-pad CHW images to a common width and stack them into one batch."""

    def __pad_imgs(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """Zero-pad each ``(C, H, W)`` image on the right to the widest ``W``.

        Args:
            imgs (list of np.ndarrays): Images to pad.

        Returns:
            list of np.ndarrays: Padded images, all sharing the same width.
        """
        target_width = max(img.shape[2] for img in imgs)

        def pad_right(img):
            _, _, own_width = img.shape
            pad_spec = ((0, 0), (0, 0), (0, target_width - own_width))
            return np.pad(img, pad_spec, mode="constant", constant_values=0)

        return [pad_right(img) for img in imgs]

    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """Pad ``imgs`` and stack them along a new leading batch axis.

        Args:
            imgs (list of np.ndarrays): Images to batch.

        Returns:
            list of np.ndarrays: A single-element list holding the float32
            ``(N, C, H, W)`` batch.
        """
        batch = np.stack(self.__pad_imgs(imgs), axis=0)
        return [batch.astype(dtype=np.float32, copy=False)]
|