# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import re
import numpy as np
from PIL import Image
import cv2
import math
import lazy_paddle as paddle
import json
import tempfile
from tokenizers import Tokenizer as TokenizerFast

from ....utils import logging
from ...base.predictor import BaseTransform
from ...base.predictor.io.writers import TextWriter
from .keys import TextRecKeys as K

__all__ = [
    "OCRReisizeNormImg",
    "LaTeXOCRReisizeNormImg",
    "CTCLabelDecode",
    "LaTeXOCRDecode",
    "SaveTextRecResults",
]


class OCRReisizeNormImg(BaseTransform):
    """Resize a text-line image to the model height, normalize and right-pad it."""

    def __init__(self, rec_image_shape=(3, 48, 320)):
        """
        Args:
            rec_image_shape: ``(channels, height, width)`` of the recognizer
                input. The default is an immutable tuple so that instances
                never share one mutable default; it is copied into a list.
        """
        super().__init__()
        self.rec_image_shape = list(rec_image_shape)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize ``img`` to the target height, normalize it and pad the width.

        Args:
            img: image array indexed as ``(height, width, channels)``; its
                channel count must equal ``rec_image_shape[0]``.
            max_wh_ratio: width/height ratio that fixes the padded output
                width as ``int(height * max_wh_ratio)``.

        Returns:
            ``float32`` array of shape ``(channels, height, padded_width)``.
            The resized image occupies the left part, the rest is zero.
        """
        imgC, imgH, _ = self.rec_image_shape
        assert imgC == img.shape[2]
        imgW = int(imgH * max_wh_ratio)

        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio, but never exceed the padded canvas width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))

        resized_image = cv2.resize(img, (resized_w, imgH)).astype("float32")
        # HWC -> CHW, then x / 255 followed by (x - 0.5) / 0.5.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5

        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def apply(self, data):
        """Replace ``data[K.IMAGE]`` with its resized and normalized version.

        The padded width follows the larger of the configured width/height
        ratio and the ratio computed from ``data[K.ORI_IM_SIZE]``, which is
        unpacked as ``(w, h)``.
        """
        _, imgH, imgW = self.rec_image_shape
        w, h = data[K.ORI_IM_SIZE]
        max_wh_ratio = max(imgW / imgH, w * 1.0 / h)
        data[K.IMAGE] = self.resize_norm_img(data[K.IMAGE], max_wh_ratio)
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.IMAGE, K.ORI_IM_SIZE]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return [K.IMAGE]


class LaTeXOCRReisizeNormImg(BaseTransform):
    """Crop, pad, resize and normalize an image for the LaTeX-OCR recognizer."""

    def __init__(self, rec_image_shape=(3, 48, 320)):
        """
        Args:
            rec_image_shape: stored on the instance as a list. None of the
                methods of this class visible here read it.
        """
        super().__init__()
        self.rec_image_shape = list(rec_image_shape)

    def pad_(self, img, divable=32):
        """Crop ``img`` to its content and pad it to multiples of ``divable``.

        Args:
            img: a PIL image.
            divable: both output dimensions are rounded up to a multiple of
                this value.

        Returns:
            A new "L"-mode PIL image with a white (255) canvas and the
            cropped content pasted at the top-left corner.
        """
        threshold = 128
        data = np.array(img.convert("LA"))
        if data[..., -1].var() == 0:
            # Constant alpha channel: use the luminance channel.
            data = (data[..., 0]).astype(np.uint8)
        else:
            # Varying alpha channel: use the inverted alpha as the image.
            data = (255 - data[..., -1]).astype(np.uint8)
        # NOTE(review): for a perfectly uniform image data.max() == data.min()
        # and this is a 0/0 division; such inputs are not handled here.
        data = (data - data.min()) / (data.max() - data.min()) * 255
        if data.mean() > threshold:
            # To invert the text to white
            gray = 255 * (data < threshold).astype(np.uint8)
        else:
            gray = 255 * (data > threshold).astype(np.uint8)
            data = 255 - data

        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
        rect = data[b : b + h, a : a + w]
        im = Image.fromarray(rect).convert("L")
        dims = []
        for x in [w, h]:
            div, mod = divmod(x, divable)
            dims.append(divable * (div + (1 if mod > 0 else 0)))
        padded = Image.new("L", dims, 255)
        padded.paste(im, (0, 0, im.size[0], im.size[1]))
        return padded

    def minmax_size_(
        self,
        img,
        max_dimensions,
        min_dimensions,
    ):
        """Shrink or pad a PIL image so it fits between two size bounds.

        Args:
            img: a PIL image.
            max_dimensions: ``(width, height)`` upper bound, or ``None`` to
                skip shrinking.
            min_dimensions: ``(width, height)`` lower bound, or ``None`` to
                skip padding.

        Returns:
            The image itself when no change is needed, otherwise a resized
            and/or white-padded copy.
        """
        if max_dimensions is not None:
            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
            if any(r > 1 for r in ratios):
                # Scale both sides by the largest overshoot ratio.
                size = np.array(img.size) // max(ratios)
                img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
        if min_dimensions is not None:
            # hypothesis: there is a dim in img smaller than min_dimensions,
            # and return a proper dim >= min_dimensions
            padded_size = [
                max(img_dim, min_dim)
                for img_dim, min_dim in zip(img.size, min_dimensions)
            ]
            if padded_size != list(img.size):  # assert hypothesis
                padded_im = Image.new("L", padded_size, 255)
                # Paste at the top-left corner. Using img.getbbox() as the
                # paste box fails when the image has all-zero borders: the
                # bbox is then smaller than the image and PIL rejects the
                # size mismatch.
                padded_im.paste(img, (0, 0))
                img = padded_im
        return img

    def norm_img_latexocr(self, img):
        """Normalize an image and pad it to multiples of 16.

        Args:
            img: image array; its first two axes are read as
                ``(height, width)``.
                NOTE(review): the in-range path stacks ``img`` three times
                along the last axis, so it presumably expects a 2-D
                gray-scale array; confirm against the caller.

        Returns:
            ``float32`` array of shape ``(1, padded_height, padded_width)``.
        """
        # CAN only predict gray scale image
        shape = (1, 1, 3)
        mean = [0.7931, 0.7931, 0.7931]
        std = [0.1738, 0.1738, 0.1738]
        scale = np.float32(1.0 / 255.0)
        min_dimensions = [32, 32]
        max_dimensions = [672, 192]
        mean = np.array(mean).reshape(shape).astype("float32")
        std = np.array(std).reshape(shape).astype("float32")

        im_h, im_w = img.shape[:2]
        in_range = (
            min_dimensions[0] <= im_w <= max_dimensions[0]
            and min_dimensions[1] <= im_h <= max_dimensions[1]
        )
        if not in_range:
            # Out-of-range inputs are cropped to content and then shrunk or
            # padded so they fall inside the supported size window.
            img = Image.fromarray(np.uint8(img))
            img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
            img = np.array(img)
            im_h, im_w = img.shape[:2]
            img = np.dstack([img, img, img])
        img = (img.astype("float32") * scale - mean) / std
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        divide_h = math.ceil(im_h / 16) * 16
        divide_w = math.ceil(im_w / 16) * 16
        # Pad bottom/right with the constant 1 up to the next multiple of 16.
        img = np.pad(
            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
        )
        img = img[:, :, np.newaxis].transpose(2, 0, 1)
        img = img.astype("float32")
        return img

    def apply(self, data):
        """Replace ``data[K.IMAGE]`` with its normalized version."""
        data[K.IMAGE] = self.norm_img_latexocr(data[K.IMAGE])
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.IMAGE, K.ORI_IM_SIZE]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return [K.IMAGE]


class BaseRecLabelDecode(BaseTransform):
    """Convert between text-label and text-index"""

    # Characters that are grouped into runs and kept in their original order
    # when a prediction is reversed (see pred_reverse).
    _KEEP_ORDER_CHAR = re.compile("[a-zA-Z0-9 :*./%+-]")

    def __init__(self, character_str=None, use_space_char=True):
        """
        Args:
            character_str: iterable of characters forming the dictionary;
                defaults to digits followed by lowercase ASCII letters.
            use_space_char: when true, a space is appended to the dictionary.
        """
        # Initialize the base transform like the other transforms in this file.
        super().__init__()
        self.reverse = False
        character_list = (
            list(character_str)
            if character_str is not None
            else list("0123456789abcdefghijklmnopqrstuvwxyz")
        )
        if use_space_char:
            character_list.append(" ")

        character_list = self.add_special_char(character_list)
        self.dict = {}
        for i, char in enumerate(character_list):
            self.dict[char] = i
        self.character = character_list

    def pred_reverse(self, pred):
        """Reverse ``pred`` while keeping runs of Latin/digit characters intact.

        Characters matching ``[a-zA-Z0-9 :*./%+-]`` are accumulated into runs;
        every other character is its own segment. The segments are then
        joined in reverse order.
        """
        pred_re = []
        c_current = ""
        for c in pred:
            if self._KEEP_ORDER_CHAR.search(c):
                c_current += c
                continue
            if c_current != "":
                pred_re.append(c_current)
            pred_re.append(c)
            c_current = ""
        if c_current != "":
            pred_re.append(c_current)

        return "".join(pred_re[::-1])

    def add_special_char(self, character_list):
        """Hook for subclasses to add special tokens; returns the list as is."""
        return character_list

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """convert text-index into text-label.

        Args:
            text_index: batch of index sequences, one row per sample.
            text_prob: optional batch of per-position probabilities aligned
                with ``text_index``.
            is_remove_duplicate: drop a position whose index equals that of
                the previous position.

        Returns:
            List of ``(text, score)`` tuples. ``score`` is the mean of the
            kept probabilities, or of ones when ``text_prob`` is ``None``,
            and 0 when no character is kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx, indices in enumerate(text_index):
            # Boolean-mask indexing below needs an array, not a plain list.
            indices = np.asarray(indices)
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token

            char_list = [self.character[text_id] for text_id in indices[selection]]
            if text_prob is not None:
                conf_list = np.asarray(text_prob[batch_idx])[selection]
            else:
                # One confidence per kept character, so that an empty result
                # falls through to the zero score below exactly like the
                # probability branch does.
                conf_list = [1] * len(char_list)
            if len(conf_list) == 0:
                conf_list = [0]

            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """get_ignored_tokens"""
        return [0]  # for ctc blank

    def apply(self, data):
        """Decode ``data[K.REC_PROBS]`` into ``K.REC_TEXT`` and ``K.REC_SCORE``.

        When the predictions are a tuple or list, only the last element is
        decoded. Argmax/max are taken over axis 2 and duplicate removal is
        enabled.
        """
        preds = data[K.REC_PROBS]
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        data[K.REC_TEXT] = []
        data[K.REC_SCORE] = []
        for t in text:
            data[K.REC_TEXT].append(t[0])
            data[K.REC_SCORE].append(t[1])
        return data

    @classmethod
    def get_input_keys(cls):
        """get_input_keys"""
        return [K.REC_PROBS]

    @classmethod
    def get_output_keys(cls):
        """get_output_keys"""
        return [K.REC_TEXT, K.REC_SCORE]


class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index

    ``apply``, ``get_input_keys`` and ``get_output_keys`` are inherited from
    ``BaseRecLabelDecode``.
    """

    def __init__(self, post_process_cfg=None, use_space_char=True):
        """
        Args:
            post_process_cfg: mapping with ``"name"`` equal to
                ``"CTCLabelDecode"`` and a ``"character_dict"`` entry holding
                the characters. Required despite the ``None`` default.
            use_space_char: forwarded to ``BaseRecLabelDecode``.
        """
        assert post_process_cfg["name"] == "CTCLabelDecode"
        character_list = post_process_cfg["character_dict"]
        super().__init__(character_list, use_space_char=use_space_char)

    def add_special_char(self, character_list):
        """Prepend the ``"blank"`` token so that it takes index 0."""
        character_list = ["blank"] + character_list
        return character_list


class LaTeXOCRDecode(object):
    """Convert between latex-symbol and symbol-index"""

    # A \operatorname / \mathrm / \text / \mathbf command with its braced body.
    _TEXT_CMD_RE = re.compile(r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})")
    _LETTER = "[a-zA-Z]"
    _NOLETTER = r"[\W_^\d]"
    # Whitespace-removal passes, applied in this order until nothing changes.
    _SPACE_RES = (
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (_NOLETTER, _NOLETTER)),
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (_NOLETTER, _LETTER)),
        re.compile(r"(%s)\s+?(%s)" % (_LETTER, _NOLETTER)),
    )

    def __init__(self, post_process_cfg=None, **kwargs):
        """
        Args:
            post_process_cfg: mapping with ``"name"`` equal to
                ``"LaTeXOCRDecode"`` and a ``"character_dict"`` entry holding
                the JSON-serializable tokenizer definition. Required despite
                the ``None`` default.
            **kwargs: accepted and ignored.
        """
        assert post_process_cfg["name"] == "LaTeXOCRDecode"
        super().__init__()
        character_list = post_process_cfg["character_dict"]
        # Build the tokenizer straight from the serialized JSON rather than
        # going through a fixed-name file in the shared temp directory. That
        # avoids clashes between processes and avoids loading a stale file
        # when writing fails.
        self.tokenizer = TokenizerFast.from_str(json.dumps(character_list))

    def post_process(self, s):
        """Remove redundant whitespace from a decoded LaTeX string.

        Spaces inside ``\\operatorname``/``\\mathrm``/``\\text``/``\\mathbf``
        commands are dropped first. Whitespace between token pairs is then
        removed repeatedly until the string stops changing.
        """
        s = self._TEXT_CMD_RE.sub(lambda m: m.group(1).replace(" ", ""), s)
        news = s
        while True:
            s = news
            for pattern in self._SPACE_RES:
                news = pattern.sub(r"\1\2", news)
            if news == s:
                break
        return s

    def decode(self, tokens):
        """Decode a batch of token-id sequences into LaTeX strings.

        Args:
            tokens: array of token ids; a 1-D array is treated as a batch of
                one sequence.

        Returns:
            List with one post-processed string per sequence.
        """
        if len(tokens.shape) == 1:
            tokens = tokens[None, :]

        dec = [self.tokenizer.decode(tok) for tok in tokens]
        dec_str_list = [
            "".join(detok.split(" "))
            .replace("Ġ", " ")
            .replace("[EOS]", "")
            .replace("[BOS]", "")
            .replace("[PAD]", "")
            .strip()
            for detok in dec
        ]
        return [str(self.post_process(dec_str)) for dec_str in dec_str_list]

    def __call__(self, data):
        """Decode ``data[K.REC_PROBS]`` and store the first result in ``K.REC_TEXT``."""
        preds = data[K.REC_PROBS]
        text = self.decode(preds)
        data[K.REC_TEXT] = text[0]
        return data


class SaveTextRecResults(BaseTransform):
    """Write recognized texts and scores to a tab-separated text file."""

    _TEXT_REC_RES_SUFFIX = "_text_rec"
    _FILE_EXT = ".txt"

    def __init__(self, save_dir):
        """
        Args:
            save_dir: directory the result file path is built under.
        """
        super().__init__()
        self.save_dir = save_dir
        # We use python backend to save text object
        self._writer = TextWriter(backend="python")

    def apply(self, data):
        """Write one ``text<TAB>score`` line per recognized text.

        The output file name is the base name of ``data[K.IM_PATH]`` with the
        ``_text_rec`` suffix and a ``.txt`` extension, under ``save_dir``.
        """
        ori_path = data[K.IM_PATH]
        file_name = os.path.basename(ori_path)
        file_name = self._replace_ext(file_name, self._FILE_EXT)
        text_rec_res_save_path = os.path.join(self.save_dir, file_name)
        text_rec_res_save_path = self._add_suffix(
            text_rec_res_save_path, self._TEXT_REC_RES_SUFFIX
        )
        rec_res = "".join(
            text + "\t" + str(score) + "\n"
            for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE])
        )
        self._write_txt(text_rec_res_save_path, rec_res)
        return data

    @classmethod
    def get_input_keys(cls):
        """get_input_keys"""
        return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE]

    @classmethod
    def get_output_keys(cls):
        """get_output_keys"""
        return []

    def _write_txt(self, path, txt_str):
        """Write ``txt_str`` to ``path``, warning first if the file exists."""
        if os.path.exists(path):
            logging.warning(f"{path} already exists. Overwriting it.")
        self._writer.write(path, txt_str)

    @staticmethod
    def _add_suffix(path, suffix):
        """Insert ``suffix`` between the stem and the extension of ``path``."""
        stem, ext = os.path.splitext(path)
        return stem + suffix + ext

    @staticmethod
    def _replace_ext(path, new_ext):
        """Replace the extension of ``path`` with ``new_ext``."""
        stem, _ = os.path.splitext(path)
        return stem + new_ext


class PrintResult(BaseTransform):
    """Print Result Transform"""

    def apply(self, data):
        """Log ``data[K.REC_TEXT]`` at info level and return ``data``."""
        logging.info("The prediction result is:")
        logging.info(data[K.REC_TEXT])
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.REC_TEXT]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return []