# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import re
from typing import List

import numpy as np

from ....utils.deps import class_requires_deps, is_dep_available
from ...utils.benchmark import benchmark

if is_dep_available("opencv-contrib-python"):
    import cv2

# Patterns are compiled once here because they are applied character by
# character inside decoding loops.
_EN_NUM_PATTERN = re.compile("[a-zA-Z0-9]")
_DIGIT_PATTERN = re.compile("[0-9]")
_KEEP_ORDER_PATTERN = re.compile("[a-zA-Z0-9 :*./%+-]")


@benchmark.timeit
@class_requires_deps("opencv-contrib-python")
class OCRReisizeNormImg:
    """Resize and normalize images for OCR text recognition.

    When ``input_shape`` is None each image is resized to the height given by
    ``rec_image_shape`` with a width derived from its aspect ratio and then
    right-padded with zeros; otherwise each image is resized directly to
    ``input_shape``.
    """

    def __init__(self, rec_image_shape=[3, 48, 320], input_shape=None):
        """Store the target shapes.

        Args:
            rec_image_shape: ``(channels, height, width)`` used by the
                aspect-ratio-preserving path (``resize``).
            input_shape: optional ``(channels, height, width)``; when given,
                ``__call__`` uses ``staticResize`` instead of ``resize``.
        """
        super().__init__()
        # Copy so the shared default list (or a caller's list) is never
        # aliased by this instance.
        self.rec_image_shape = list(rec_image_shape)
        self.input_shape = input_shape
        # Upper bound on the output width produced by resize_norm_img.
        self.max_imgW = 3200

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize ``img`` to the configured height, normalize and pad it.

        Args:
            img: channel-last image array; ``img.shape[2]`` must equal the
                channel count in ``rec_image_shape``.
            max_wh_ratio: width/height ratio that determines the padded
                output width (``int(height * max_wh_ratio)``, capped at
                ``self.max_imgW``).

        Returns:
            A float32 array of shape ``(channels, height, out_width)`` holding
            ``(pixel / 255 - 0.5) / 0.5`` for the resized image, zero-padded
            on the right.
        """
        imgC, imgH, _ = self.rec_image_shape
        assert imgC == img.shape[2], "image channel count does not match rec_image_shape"
        imgW = int(imgH * max_wh_ratio)
        if imgW > self.max_imgW:
            # Too wide: squash the image to the maximum width, no padding.
            imgW = self.max_imgW
            resized_w = self.max_imgW
        else:
            h, w = img.shape[:2]
            # Keep the aspect ratio unless it would exceed the padded width.
            resized_w = min(imgW, int(math.ceil(imgH * (w / float(h)))))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        # HWC -> CHW, then scale to 0..1 and shift/scale by 0.5.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, imgs):
        """Apply the resize to every image in ``imgs`` and return a list."""
        if self.input_shape is None:
            return [self.resize(img) for img in imgs]
        return [self.staticResize(img) for img in imgs]

    def resize(self, img):
        """Resize ``img`` using the larger of its own and the default ratio.

        The default ratio is ``width / height`` of ``rec_image_shape``; an
        image wider than that gets a correspondingly wider output.
        """
        _, imgH, imgW = self.rec_image_shape
        h, w = img.shape[:2]
        max_wh_ratio = max(imgW / imgH, w * 1.0 / h)
        return self.resize_norm_img(img, max_wh_ratio)

    def staticResize(self, img):
        """Resize ``img`` to the fixed ``input_shape`` and normalize it.

        Returns:
            A channel-first array holding ``(pixel / 255 - 0.5) / 0.5``.
        """
        _, imgH, imgW = self.input_shape
        resized_image = cv2.resize(img, (int(imgW), int(imgH)))
        # HWC -> CHW, then the same normalization as resize_norm_img.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image


@benchmark.timeit
class BaseRecLabelDecode:
    """Convert between text-label and text-index"""

    def __init__(self, character_str=None, use_space_char=True):
        """Build the character table.

        Args:
            character_str: iterable of characters; when None, the digits
                followed by the lowercase ASCII letters are used.
            use_space_char: append a space character to the table.
        """
        super().__init__()
        self.reverse = False
        character_list = (
            list(character_str)
            if character_str is not None
            else list("0123456789abcdefghijklmnopqrstuvwxyz")
        )
        if use_space_char:
            character_list.append(" ")

        character_list = self.add_special_char(character_list)
        # char -> index; for duplicated characters the last index wins.
        self.dict = {char: i for i, char in enumerate(character_list)}
        self.character = character_list

    def pred_reverse(self, pred):
        """Reverse ``pred`` while keeping runs of ASCII text in reading order.

        Consecutive characters matching ``[a-zA-Z0-9 :*./%+-]`` are kept
        together as one unit; every other character is its own unit. The
        units are then emitted in reverse order.
        """
        pred_re = []
        c_current = ""
        for c in pred:
            if _KEEP_ORDER_PATTERN.search(c):
                c_current += c
                continue
            # A non-matching character closes the pending run.
            if c_current != "":
                pred_re.append(c_current)
            pred_re.append(c)
            c_current = ""
        if c_current != "":
            pred_re.append(c_current)

        return "".join(pred_re[::-1])

    def add_special_char(self, character_list):
        """Hook for subclasses to add special tokens; returns the list as is."""
        return character_list

    def get_word_info(self, text, selection):
        """
        Group the decoded characters and record the corresponding decoded positions.

        Args:
            text: the decoded text
            selection: the bool array that identifies which columns of features are decoded as non-separated characters
        Returns:
            word_list: list of the grouped words (each a list of characters)
            word_col_list: list of decoding positions corresponding to each character in the grouped word
            state_list: list of marker to identify the type of each group:
                        - 'cn': continuous chinese characters (e.g., 你好啊)
                        - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
                        - 'symbol': continuous remaining characters (e.g., space, '(', ')', etc.)
        """
        state = None
        word_content = []
        word_col_content = []
        word_list = []
        word_col_list = []
        state_list = []
        # Indices of the selected (True) columns; the i-th character of
        # ``text`` is looked up at valid_col[i].
        valid_col = np.flatnonzero(selection)

        for c_i, char in enumerate(text):
            if "\u4e00" <= char <= "\u9fff":
                c_state = "cn"
            elif _EN_NUM_PATTERN.search(char):
                c_state = "en&num"
            else:
                c_state = "symbol"

            # A '.' between an en&num run and a digit stays in the run
            # (floating-point numbers).
            if (
                char == "."
                and state == "en&num"
                and c_i + 1 < len(text)
                and _DIGIT_PATTERN.search(text[c_i + 1])
            ):
                c_state = "en&num"
            # A '-' directly after an en&num run stays in the run.
            if char == "-" and state == "en&num":
                c_state = "en&num"

            if state is None:
                state = c_state

            if state != c_state:
                # The group type changed: flush the finished group.
                if word_content:
                    word_list.append(word_content)
                    word_col_list.append(word_col_content)
                    state_list.append(state)
                    word_content = []
                    word_col_content = []
                state = c_state

            word_content.append(char)
            word_col_content.append(int(valid_col[c_i]))

        if word_content:
            word_list.append(word_content)
            word_col_list.append(word_col_content)
            state_list.append(state)

        return word_list, word_col_list, state_list

    def decode(
        self,
        text_index,
        text_prob=None,
        is_remove_duplicate=False,
        return_word_box=False,
    ):
        """convert text-index into text-label.

        Args:
            text_index: per-sample arrays of character indices.
            text_prob: optional per-sample arrays of probabilities aligned
                with ``text_index``.
            is_remove_duplicate: drop an index equal to its predecessor.
            return_word_box: also return word grouping information.

        Returns:
            A list with one tuple per sample: ``(text, mean_confidence)``, or
            ``(text, mean_confidence, [sequence_length, word_list,
            word_col_list, state_list])`` when ``return_word_box`` is set.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            indices = text_index[batch_idx]
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token

            char_list = [self.character[text_id] for text_id in indices[selection]]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                # Nothing decoded: report zero confidence instead of NaN.
                conf_list = [0]

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            score = np.mean(conf_list).tolist()
            if return_word_box:
                word_list, word_col_list, state_list = self.get_word_info(
                    text, selection
                )
                result_list.append(
                    (
                        text,
                        score,
                        [
                            len(indices),
                            word_list,
                            word_col_list,
                            state_list,
                        ],
                    )
                )
            else:
                result_list.append((text, score))
        return result_list

    def get_ignored_tokens(self):
        """Return the indices that ``decode`` skips."""
        return [0]  # for ctc blank

    def __call__(self, pred):
        """Decode model output into texts and scores.

        Args:
            pred: array of per-position class scores (classes on the last
                axis), or a list/tuple whose last element is such an array.

        Returns:
            ``(texts, scores)``: two lists with one entry per sample.
        """
        # The list/tuple check must run on the raw input; after np.array()
        # the value is always an ndarray and the branch could never fire.
        if isinstance(pred, (tuple, list)):
            pred = pred[-1]
        preds = np.array(pred)
        preds_idx = preds.argmax(axis=-1)
        preds_prob = preds.max(axis=-1)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        texts = []
        scores = []
        for t in text:
            texts.append(t[0])
            scores.append(t[1])
        return texts, scores


@benchmark.timeit
class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_list=None, use_space_char=True):
        """Build the character table with a leading ``"blank"`` entry."""
        super().__init__(character_list, use_space_char=use_space_char)

    def __call__(self, pred, return_word_box=False, **kwargs):
        """Decode the first model output into texts and scores.

        Args:
            pred: sequence whose first element holds per-position class
                scores (classes on the last axis).
            return_word_box: also return word grouping information; requires
                ``wh_ratio_list`` and ``max_wh_ratio`` in ``kwargs``.
            **kwargs: ``wh_ratio_list`` (one ratio per sample) and
                ``max_wh_ratio``, read only when ``return_word_box`` is set.

        Returns:
            ``(texts, scores)``; with ``return_word_box`` each text entry is
            ``(text, word_info)`` instead of a plain string.
        """
        preds = np.array(pred[0])
        preds_idx = preds.argmax(axis=-1)
        preds_prob = preds.max(axis=-1)
        text = self.decode(
            preds_idx,
            preds_prob,
            is_remove_duplicate=True,
            return_word_box=return_word_box,
        )
        if return_word_box:
            for rec_idx, rec in enumerate(text):
                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                max_wh_ratio = kwargs["max_wh_ratio"]
                # Scale the sequence length by this sample's share of the
                # batch's maximum width/height ratio.
                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
        texts = []
        scores = []
        for t in text:
            texts.append(t[0] if len(t) <= 2 else (t[0], t[2]))
            scores.append(t[1])
        return texts, scores

    def add_special_char(self, character_list):
        """Prepend the ``"blank"`` entry so it takes index 0."""
        character_list = ["blank"] + character_list
        return character_list


@benchmark.timeit
class ToBatch:
    """A class for batching and padding images to a uniform width."""

    def __pad_imgs(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """Pad images to the maximum width in the batch.

        Args:
            imgs (list of np.ndarrays): List of 3-D images whose last axis
                is the width.

        Returns:
            list of np.ndarrays: List of images zero-padded at the end of the
                last axis.
        """
        max_width = max(img.shape[2] for img in imgs)
        return [
            np.pad(
                img,
                ((0, 0), (0, 0), (0, max_width - img.shape[2])),
                mode="constant",
                constant_values=0,
            )
            for img in imgs
        ]

    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """Call method to pad images and stack them into a batch.

        Args:
            imgs (list of np.ndarrays): List of images to process.

        Returns:
            list of np.ndarrays: List containing a stacked float32 array of
                the padded images.
        """
        imgs = self.__pad_imgs(imgs)
        return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]