# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import math
import json
import tempfile

import cv2
import numpy as np
from PIL import Image
from tokenizers import Tokenizer as TokenizerFast

from ....utils import logging
from ...base.predictor import BaseTransform
from ...base.predictor.io.writers import TextWriter
from .keys import TextRecKeys as K

__all__ = [
    'OCRReisizeNormImg',
    'LaTeXOCRReisizeNormImg',
    'CTCLabelDecode',
    'LaTeXOCRDecode',
    'SaveTextRecResults',
    'PrintResult',
]


class OCRReisizeNormImg(BaseTransform):
    """ for ocr image resize and normalization """

    def __init__(self, rec_image_shape=[3, 48, 320]):
        super().__init__()
        self.rec_image_shape = rec_image_shape

    def resize_norm_img(self, img, max_wh_ratio):
        """ resize and normalize the img """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        imgW = int(imgH * max_wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im
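
    # A worked example of the resize above: with rec_image_shape [3, 48, 320],
    # a 24x96 crop is resized to height 48 and width ceil(48 * 96 / 24) = 192
    # (aspect ratio preserved), each pixel is mapped from [0, 255] to [-1, 1]
    # via (x / 255 - 0.5) / 0.5, and the columns from 192 up to imgW are left
    # zero-padded.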

    def apply(self, data):
        """ apply """
        imgC, imgH, imgW = self.rec_image_shape
        max_wh_ratio = imgW / imgH
        w, h = data[K.ORI_IM_SIZE]
        wh_ratio = w * 1.0 / h
        max_wh_ratio = max(max_wh_ratio, wh_ratio)
        data[K.IMAGE] = self.resize_norm_img(data[K.IMAGE], max_wh_ratio)
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.IMAGE, K.ORI_IM_SIZE]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return [K.IMAGE]


class LaTeXOCRReisizeNormImg(BaseTransform):
    """ for LaTeX OCR image resize and normalization """

    def __init__(self, rec_image_shape=[3, 48, 320]):
        super().__init__()
        self.rec_image_shape = rec_image_shape

    def pad_(self, img, divable=32):
        threshold = 128
        data = np.array(img.convert("LA"))
        if data[..., -1].var() == 0:
            data = (data[..., 0]).astype(np.uint8)
        else:
            data = (255 - data[..., -1]).astype(np.uint8)
        data = (data - data.min()) / (data.max() - data.min()) * 255
        if data.mean() > threshold:
            # To invert the text to white
            gray = 255 * (data < threshold).astype(np.uint8)
        else:
            gray = 255 * (data > threshold).astype(np.uint8)
            data = 255 - data
        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
        rect = data[b:b + h, a:a + w]
        im = Image.fromarray(rect).convert("L")
        dims = []
        for x in [w, h]:
            div, mod = divmod(x, divable)
            dims.append(divable * (div + (1 if mod > 0 else 0)))
        padded = Image.new("L", dims, 255)
        padded.paste(im, (0, 0, im.size[0], im.size[1]))
        return padded
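
    # pad_() binarizes the image (flipping polarity so the text is the
    # non-zero foreground), crops to the tight bounding box of the text, and
    # pastes the crop onto a white canvas whose sides are rounded up to the
    # next multiple of `divable`; e.g. a 70x30 crop becomes a 96x32 canvas
    # with divable=32.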

    def minmax_size_(self, img, max_dimensions, min_dimensions):
        if max_dimensions is not None:
            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
            if any([r > 1 for r in ratios]):
                size = np.array(img.size) // max(ratios)
                img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
        if min_dimensions is not None:
            # hypothesis: there is a dim in img smaller than min_dimensions,
            # and return a proper dim >= min_dimensions
            padded_size = [
                max(img_dim, min_dim)
                for img_dim, min_dim in zip(img.size, min_dimensions)
            ]
            if padded_size != list(img.size):  # assert hypothesis
                padded_im = Image.new("L", padded_size, 255)
                padded_im.paste(img, img.getbbox())
                img = padded_im
        return img
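
    # minmax_size_() enforces both bounds: if either side exceeds
    # max_dimensions the image is shrunk uniformly by the larger overshoot
    # ratio, and if either side is below min_dimensions the image is pasted
    # onto a white canvas grown to the minimum, so the output always
    # satisfies the minimum size and fits the maximum up to integer rounding.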

    def norm_img_latexocr(self, img):
        # the recognizer consumes a single-channel (grayscale) image
        shape = (1, 1, 3)
        mean = [0.7931, 0.7931, 0.7931]
        std = [0.1738, 0.1738, 0.1738]
        scale = np.float32(1.0 / 255.0)
        min_dimensions = [32, 32]
        max_dimensions = [672, 192]
        mean = np.array(mean).reshape(shape).astype("float32")
        std = np.array(std).reshape(shape).astype("float32")

        im_h, im_w = img.shape[:2]
        if not (min_dimensions[0] <= im_w <= max_dimensions[0]
                and min_dimensions[1] <= im_h <= max_dimensions[1]):
            img = Image.fromarray(np.uint8(img))
            img = self.minmax_size_(self.pad_(img), max_dimensions,
                                    min_dimensions)
            img = np.array(img)
            im_h, im_w = img.shape[:2]
            img = np.dstack([img, img, img])
        img = (img.astype("float32") * scale - mean) / std
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        divide_h = math.ceil(im_h / 16) * 16
        divide_w = math.ceil(im_w / 16) * 16
        img = np.pad(
            img, ((0, divide_h - im_h), (0, divide_w - im_w)),
            constant_values=(1, 1))
        img = img[:, :, np.newaxis].transpose(2, 0, 1)
        img = img.astype("float32")
        return img
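
    # Padding arithmetic: after normalization, (divide_h - im_h) bottom rows
    # and (divide_w - im_w) right columns are filled with 1s so both sides
    # become multiples of 16, e.g. a 100x250 image is padded to 112x256
    # before being transposed to the 1xHxW layout.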

    def apply(self, data):
        """ apply """
        data[K.IMAGE] = self.norm_img_latexocr(data[K.IMAGE])
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.IMAGE, K.ORI_IM_SIZE]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return [K.IMAGE]


class BaseRecLabelDecode(BaseTransform):
    """ Convert between text-label and text-index """

    def __init__(self, character_str=None, use_space_char=True):
        super().__init__()
        self.reverse = False
        character_list = (list(character_str) if character_str is not None
                          else list("0123456789abcdefghijklmnopqrstuvwxyz"))
        if use_space_char:
            character_list.append(" ")
        character_list = self.add_special_char(character_list)
        self.dict = {}
        for i, char in enumerate(character_list):
            self.dict[char] = i
        self.character = character_list

    def pred_reverse(self, pred):
        """ pred_reverse """
        pred_re = []
        c_current = ''
        for c in pred:
            if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
                if c_current != '':
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ''
            else:
                c_current += c
        if c_current != '':
            pred_re.append(c_current)
        return ''.join(pred_re[::-1])
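
    # pred_reverse() keeps left-to-right runs (Latin letters, digits, and a
    # few symbols) intact as single segments while reversing the overall
    # segment order, which restores right-to-left reading order for Arabic
    # output: e.g. the segments ['abc', 'X', 'Y'] join back as 'YXabc'.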

    def add_special_char(self, character_list):
        """ add_special_char """
        return character_list

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[
                    batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token
            char_list = [
                self.character[text_id]
                for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]
            text = ''.join(char_list)
            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """ get_ignored_tokens """
        return [0]  # for ctc blank
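
    # Example of the CTC greedy collapse performed by decode(): with blank
    # index 0, the per-step argmax sequence [5, 5, 0, 3, 3] first drops
    # repeats ([5, 0, 3]), then drops blanks ([5, 3]); the surviving indices
    # are looked up in self.character and the score is the mean probability
    # at the surviving positions.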

    def apply(self, data):
        """ apply """
        preds = data[K.REC_PROBS]
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        data[K.REC_TEXT] = []
        data[K.REC_SCORE] = []
        for t in text:
            data[K.REC_TEXT].append(t[0])
            data[K.REC_SCORE].append(t[1])
        return data

    @classmethod
    def get_input_keys(cls):
        """ get_input_keys """
        return [K.REC_PROBS]

    @classmethod
    def get_output_keys(cls):
        """ get_output_keys """
        return [K.REC_TEXT, K.REC_SCORE]


class CTCLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, post_process_cfg=None, use_space_char=True):
        assert post_process_cfg['name'] == 'CTCLabelDecode'
        character_list = post_process_cfg['character_dict']
        super().__init__(character_list, use_space_char=use_space_char)

    def apply(self, data):
        """ apply """
        preds = data[K.REC_PROBS]
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        data[K.REC_TEXT] = []
        data[K.REC_SCORE] = []
        for t in text:
            data[K.REC_TEXT].append(t[0])
            data[K.REC_SCORE].append(t[1])
        return data

    def add_special_char(self, character_list):
        """ add_special_char """
        character_list = ['blank'] + character_list
        return character_list
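
    # 'blank' is prepended at index 0, which is exactly the index reported by
    # get_ignored_tokens(), so the CTC blank is stripped during decoding and
    # the real characters keep a stable one-based offset in the vocabulary.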

    @classmethod
    def get_input_keys(cls):
        """ get_input_keys """
        return [K.REC_PROBS]

    @classmethod
    def get_output_keys(cls):
        """ get_output_keys """
        return [K.REC_TEXT, K.REC_SCORE]


class LaTeXOCRDecode(object):
    """ Convert between latex-symbol and symbol-index """

    def __init__(self, post_process_cfg=None, **kwargs):
        assert post_process_cfg['name'] == 'LaTeXOCRDecode'
        super(LaTeXOCRDecode, self).__init__()
        character_list = post_process_cfg['character_dict']
        temp_path = tempfile.gettempdir()
        rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
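        # Tokenizer.from_file() expects a path on disk, so the tokenizer spec
        # carried in the config is round-tripped through a JSON file in the
        # temp directory. (Tokenizer.from_str() could consume the same JSON
        # serialized as a string, but the file keeps the spec inspectable.)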
        try:
            with open(rec_char_dict_path, "w") as f:
                json.dump(character_list, f)
        except Exception as e:
            logging.warning(
                f"Failed to create latexocr_tokenizer.json, reason: {e}")
        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)

    def post_process(self, s):
        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
        letter = "[a-zA-Z]"
        noletter = r"[\W_^\d]"
        names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
        news = s
        while True:
            s = news
            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
            if news == s:
                break
        return s
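
    # post_process() collapses the layout whitespace the detokenizer leaves
    # behind: "\frac { a } { b }" is normalized step by step to "\frac{a}{b}",
    # while a group like "\mathrm { x }" is first rewritten to "\mathrm{x}"
    # by the text_reg substitution.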

    def decode(self, tokens):
        if len(tokens.shape) == 1:
            tokens = tokens[None, :]
        dec = [self.tokenizer.decode(tok) for tok in tokens]
        dec_str_list = [
            "".join(detok.split(" "))
            .replace("Ġ", " ")
            .replace("[EOS]", "")
            .replace("[BOS]", "")
            .replace("[PAD]", "")
            .strip()
            for detok in dec
        ]
        return [str(self.post_process(dec_str)) for dec_str in dec_str_list]
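
    # "Ġ" is the byte-level BPE marker for a leading space (as in GPT-2 style
    # vocabularies), so stripping the inter-token spaces and mapping "Ġ" back
    # to " " reconstructs the original spacing before the special
    # [BOS]/[EOS]/[PAD] markers are removed.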

    def __call__(self, data):
        preds = data[K.REC_PROBS]
        text = self.decode(preds)
        data[K.REC_TEXT] = text[0]
        return data


class SaveTextRecResults(BaseTransform):
    """ SaveTextRecResults """
    _TEXT_REC_RES_SUFFIX = '_text_rec'
    _FILE_EXT = '.txt'

    def __init__(self, save_dir):
        super().__init__()
        self.save_dir = save_dir
        # We use python backend to save text object
        self._writer = TextWriter(backend='python')

    def apply(self, data):
        """ apply """
        ori_path = data[K.IM_PATH]
        file_name = os.path.basename(ori_path)
        file_name = self._replace_ext(file_name, self._FILE_EXT)
        text_rec_res_save_path = os.path.join(self.save_dir, file_name)
        rec_res = ''
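        # Serialize one result per line as "<text>\t<score>"; e.g. for an
        # input image "word_1.png" the output file is "word_1_text_rec.txt".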
        for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE]):
            line = text + '\t' + str(score) + '\n'
            rec_res += line
        text_rec_res_save_path = self._add_suffix(text_rec_res_save_path,
                                                  self._TEXT_REC_RES_SUFFIX)
        self._write_txt(text_rec_res_save_path, rec_res)
        return data

    @classmethod
    def get_input_keys(cls):
        """ get_input_keys """
        return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE]

    @classmethod
    def get_output_keys(cls):
        """ get_output_keys """
        return []

    def _write_txt(self, path, txt_str):
        """ _write_txt """
        if os.path.exists(path):
            logging.warning(f"{path} already exists. Overwriting it.")
        self._writer.write(path, txt_str)

    @staticmethod
    def _add_suffix(path, suffix):
        """ _add_suffix """
        stem, ext = os.path.splitext(path)
        return stem + suffix + ext

    @staticmethod
    def _replace_ext(path, new_ext):
        """ _replace_ext """
        stem, _ = os.path.splitext(path)
        return stem + new_ext


class PrintResult(BaseTransform):
    """ Print Result Transform """

    def apply(self, data):
        """ apply """
        logging.info("The prediction result is:")
        logging.info(data[K.REC_TEXT])
        return data

    @classmethod
    def get_input_keys(cls):
        """ get input keys """
        return [K.REC_TEXT]

    @classmethod
    def get_output_keys(cls):
        """ get output keys """
        return []
|