# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import re
import math
import json
import tempfile

import numpy as np
from PIL import Image
import cv2
import lazy_paddle as paddle
from tokenizers import Tokenizer as TokenizerFast

from ....utils import logging
from ..base import BaseComponent

__all__ = [
    "OCRReisizeNormImg",
    # "LaTeXOCRReisizeNormImg",
    "CTCLabelDecode",
    # "LaTeXOCRDecode",
]


class OCRReisizeNormImg(BaseComponent):
    """Resize and normalize images for OCR text recognition."""

    INPUT_KEYS = ["img", "img_size"]
    OUTPUT_KEYS = ["img"]
    DEAULT_INPUTS = {"img": "img", "img_size": "img_size"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, rec_image_shape=[3, 48, 320]):
        super().__init__()
        self.rec_image_shape = rec_image_shape

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize the image to the target height, normalize it to [-1, 1],
        and zero-pad it on the right up to the maximum width."""
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        imgW = int(imgH * max_wh_ratio)

        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype("float32")
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def apply(self, img, img_size):
        """apply"""
        imgC, imgH, imgW = self.rec_image_shape
        max_wh_ratio = imgW / imgH
        w, h = img_size[:2]
        wh_ratio = w * 1.0 / h
        max_wh_ratio = max(max_wh_ratio, wh_ratio)
        img = self.resize_norm_img(img, max_wh_ratio)
        return {"img": img}
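

# A minimal usage sketch for OCRReisizeNormImg. The dummy image shape and the
# (w, h) tuple passed as ``img_size`` are illustrative assumptions (as is the
# assumption that BaseComponent needs no constructor arguments); the function
# is defined but never called, so importing this module stays side-effect free.
def _demo_ocr_resize_norm():
    resizer = OCRReisizeNormImg(rec_image_shape=[3, 48, 320])
    # A dummy BGR text-line crop: height 32, width 100.
    dummy = np.zeros((32, 100, 3), dtype=np.uint8)
    out = resizer.apply(dummy, img_size=(100, 32))
    # CHW float32 tensor in [-1, 1], zero-padded on the right to width 320.
    assert out["img"].shape == (3, 48, 320)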


# class LaTeXOCRReisizeNormImg(BaseComponent):
#     """for ocr image resize and normalization"""
#
#     def __init__(self, rec_image_shape=[3, 48, 320]):
#         super().__init__()
#         self.rec_image_shape = rec_image_shape
#
#     def pad_(self, img, divable=32):
#         threshold = 128
#         data = np.array(img.convert("LA"))
#         if data[..., -1].var() == 0:
#             data = (data[..., 0]).astype(np.uint8)
#         else:
#             data = (255 - data[..., -1]).astype(np.uint8)
#         data = (data - data.min()) / (data.max() - data.min()) * 255
#         if data.mean() > threshold:
#             # To invert the text to white
#             gray = 255 * (data < threshold).astype(np.uint8)
#         else:
#             gray = 255 * (data > threshold).astype(np.uint8)
#             data = 255 - data
#         coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
#         a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
#         rect = data[b : b + h, a : a + w]
#         im = Image.fromarray(rect).convert("L")
#         dims = []
#         for x in [w, h]:
#             div, mod = divmod(x, divable)
#             dims.append(divable * (div + (1 if mod > 0 else 0)))
#         padded = Image.new("L", dims, 255)
#         padded.paste(im, (0, 0, im.size[0], im.size[1]))
#         return padded
#
#     def minmax_size_(self, img, max_dimensions, min_dimensions):
#         if max_dimensions is not None:
#             ratios = [a / b for a, b in zip(img.size, max_dimensions)]
#             if any(r > 1 for r in ratios):
#                 size = np.array(img.size) // max(ratios)
#                 img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
#         if min_dimensions is not None:
#             # hypothesis: there is a dim in img smaller than min_dimensions,
#             # and return a proper dim >= min_dimensions
#             padded_size = [
#                 max(img_dim, min_dim)
#                 for img_dim, min_dim in zip(img.size, min_dimensions)
#             ]
#             if padded_size != list(img.size):  # assert hypothesis
#                 padded_im = Image.new("L", padded_size, 255)
#                 padded_im.paste(img, img.getbbox())
#                 img = padded_im
#         return img
#
#     def norm_img_latexocr(self, img):
#         # CAN only predict gray scale image
#         shape = (1, 1, 3)
#         mean = [0.7931, 0.7931, 0.7931]
#         std = [0.1738, 0.1738, 0.1738]
#         scale = np.float32(1.0 / 255.0)
#         min_dimensions = [32, 32]
#         max_dimensions = [672, 192]
#         mean = np.array(mean).reshape(shape).astype("float32")
#         std = np.array(std).reshape(shape).astype("float32")
#
#         im_h, im_w = img.shape[:2]
#         if (
#             min_dimensions[0] <= im_w <= max_dimensions[0]
#             and min_dimensions[1] <= im_h <= max_dimensions[1]
#         ):
#             pass
#         else:
#             img = Image.fromarray(np.uint8(img))
#             img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
#             img = np.array(img)
#             im_h, im_w = img.shape[:2]
#             img = np.dstack([img, img, img])
#         img = (img.astype("float32") * scale - mean) / std
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#         divide_h = math.ceil(im_h / 16) * 16
#         divide_w = math.ceil(im_w / 16) * 16
#         img = np.pad(
#             img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
#         )
#         img = img[:, :, np.newaxis].transpose(2, 0, 1)
#         img = img.astype("float32")
#         return img
#
#     def apply(self, data):
#         """apply"""
#         data[K.IMAGE] = self.norm_img_latexocr(data[K.IMAGE])
#         return data
#
#     @classmethod
#     def get_input_keys(cls):
#         """get input keys"""
#         return [K.IMAGE, K.ORI_IM_SIZE]
#
#     @classmethod
#     def get_output_keys(cls):
#         """get output keys"""
#         return [K.IMAGE]


class BaseRecLabelDecode(BaseComponent):
    """Convert between text-label and text-index"""

    INPUT_KEYS = ["pred"]
    OUTPUT_KEYS = ["rec_text", "rec_score"]
    DEAULT_INPUTS = {"pred": "pred"}
    DEAULT_OUTPUTS = {"rec_text": "rec_text", "rec_score": "rec_score"}

    ENABLE_BATCH = True

    def __init__(self, character_str=None, use_space_char=True):
        super().__init__()
        self.reverse = False
        character_list = (
            list(character_str)
            if character_str is not None
            else list("0123456789abcdefghijklmnopqrstuvwxyz")
        )
        if use_space_char:
            character_list.append(" ")

        character_list = self.add_special_char(character_list)
        self.dict = {}
        for i, char in enumerate(character_list):
            self.dict[char] = i
        self.character = character_list

    def pred_reverse(self, pred):
        """Reverse a prediction for right-to-left scripts, keeping runs of
        Latin characters, digits, and common symbols in their original order."""
        pred_re = []
        c_current = ""
        for c in pred:
            if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
                if c_current != "":
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ""
            else:
                c_current += c
        if c_current != "":
            pred_re.append(c_current)

        return "".join(pred_re[::-1])

    def add_special_char(self, character_list):
        """add_special_char"""
        return character_list

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            char_list = [
                self.character[text_id]
                for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                # No probabilities available: default the confidence to 1.0.
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """get_ignored_tokens"""
        return [0]  # for ctc blank

    def apply(self, pred):
        """apply"""
        # Some predictors return a tuple/list of outputs; take the last one
        # before converting to an ndarray.
        if isinstance(pred, (tuple, list)):
            pred = pred[-1]
        preds = np.array(pred)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        return [{"rec_text": t[0], "rec_score": t[1]} for t in text]
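

# A hedged sketch of BaseRecLabelDecode.decode on a toy index sequence: with
# is_remove_duplicate=True, consecutive repeats are merged and the ignored
# token (index 0) is dropped. The toy vocabulary "ab" is an illustrative
# assumption; the function is defined but never called.
def _demo_base_decode():
    decoder = BaseRecLabelDecode(character_str="ab")
    # Character table here is ["a", "b", " "]; index 0 is treated as the
    # ignored (blank) token by get_ignored_tokens().
    out = decoder.decode(np.array([[1, 1, 0, 2]]), is_remove_duplicate=True)
    # The repeated 1 collapses to one "b", the 0 is dropped, 2 maps to " ".
    assert out == [("b ", 1.0)]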


class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_list=None, use_space_char=True):
        super().__init__(character_list, use_space_char=use_space_char)

    def apply(self, pred):
        """apply"""
        preds = np.array(pred[0])
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        return [{"rec_text": t[0], "rec_score": t[1]} for t in text]

    def add_special_char(self, character_list):
        """Prepend the CTC blank token so that index 0 maps to blank."""
        character_list = ["blank"] + character_list
        return character_list
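

# An end-to-end hedged sketch for CTCLabelDecode: it takes scores of shape
# [batch, time, num_classes], where class 0 is the CTC blank prepended by
# add_special_char. The toy score tensor is an illustrative assumption; the
# function is defined but never called.
def _demo_ctc_decode():
    decoder = CTCLabelDecode(character_list="ab")
    # Character table here is ["blank", "a", "b", " "].
    probs = np.zeros((1, 4, 4), dtype=np.float32)
    probs[0, 0, 1] = 0.9  # "a"
    probs[0, 1, 1] = 0.9  # repeated "a", merged by duplicate removal
    probs[0, 2, 0] = 0.9  # blank, dropped
    probs[0, 3, 2] = 0.9  # "b"
    result = decoder.apply([probs])
    assert result[0]["rec_text"] == "ab"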
{.*?})" # letter = "[a-zA-Z]" # noletter = "[\W_^\d]" # names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)] # s = re.sub(text_reg, lambda match: str(names.pop(0)), s) # news = s # while True: # s = news # news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s) # news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news) # news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news) # if news == s: # break # return s # def decode(self, tokens): # if len(tokens.shape) == 1: # tokens = tokens[None, :] # dec = [self.tokenizer.decode(tok) for tok in tokens] # dec_str_list = [ # "".join(detok.split(" ")) # .replace("Ġ", " ") # .replace("[EOS]", "") # .replace("[BOS]", "") # .replace("[PAD]", "") # .strip() # for detok in dec # ] # return [str(self.post_process(dec_str)) for dec_str in dec_str_list] # def __call__(self, data): # preds = data[K.REC_PROBS] # text = self.decode(preds) # data[K.REC_TEXT] = text[0] # return data # class SaveTextRecResults(BaseComponent): # """SaveTextRecResults""" # _TEXT_REC_RES_SUFFIX = "_text_rec" # _FILE_EXT = ".txt" # def __init__(self, save_dir): # super().__init__() # self.save_dir = save_dir # # We use python backend to save text object # self._writer = TextWriter(backend="python") # def apply(self, data): # """apply""" # ori_path = data[K.IM_PATH] # file_name = os.path.basename(ori_path) # file_name = self._replace_ext(file_name, self._FILE_EXT) # text_rec_res_save_path = os.path.join(self.save_dir, file_name) # rec_res = "" # for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE]): # line = text + "\t" + str(score) + "\n" # rec_res += line # text_rec_res_save_path = self._add_suffix( # text_rec_res_save_path, self._TEXT_REC_RES_SUFFIX # ) # self._write_txt(text_rec_res_save_path, rec_res) # return data # @classmethod # def get_input_keys(cls): # """get_input_keys""" # return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE] # @classmethod # def get_output_keys(cls): # """get_output_keys""" # return [] # def _write_txt(self, path, txt_str): # """_write_txt""" # if os.path.exists(path): # logging.warning(f"{path} already exists. Overwriting it.") # self._writer.write(path, txt_str) # @staticmethod # def _add_suffix(path, suffix): # """_add_suffix""" # stem, ext = os.path.splitext(path) # return stem + suffix + ext # @staticmethod # def _replace_ext(path, new_ext): # """_replace_ext""" # stem, _ = os.path.splitext(path) # return stem + new_ext # class PrintResult(BaseComponent): # """Print Result Transform""" # def apply(self, data): # """apply""" # logging.info("The prediction result is:") # logging.info(data[K.REC_TEXT]) # return data # @classmethod # def get_input_keys(cls): # """get input keys""" # return [K.REC_TEXT] # @classmethod # def get_output_keys(cls): # """get output keys""" # return []