- # Copyright (c) Opendatalab. All rights reserved.
- import copy
- import os
- import warnings
- from pathlib import Path
- import cv2
- import numpy as np
- import yaml
- from loguru import logger
- from mineru.utils.config_reader import get_device
- from mineru.utils.enum_class import ModelPath
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
- from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
- from .tools.infer.predict_system import TextSystem
- from .tools.infer import pytorchocr_utility as utility
- import argparse
# Language-code groups that share one OCR model family. A concrete language
# code supplied by the caller is mapped to its family name (e.g. 'fr' ->
# 'latin') before model selection in PytorchPaddleOCR.__init__.
latin_lang = [
    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',  # noqa: E126
    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
    'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
# Arabic-script languages.
arabic_lang = ['ar', 'fa', 'ug', 'ur']
# Cyrillic-script languages (the East Slavic group below is kept separate).
cyrillic_lang = [
    'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
    'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
# East Slavic languages use their own dedicated model.
east_slavic_lang = ["ru", "be", "uk"]
# Devanagari-script languages.
devanagari_lang = [
    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
    'sa', 'bgc'
]
def get_model_params(lang, config):
    """Look up the detection/recognition model names and dict file for a language.

    Args:
        lang: Language (or language-family) key, e.g. 'ch', 'latin'.
        config: Parsed models_config.yml; expected to hold a top-level
            'lang' mapping of language -> {'det': ..., 'rec': ..., 'dict': ...}.

    Returns:
        Tuple ``(det, rec, dict_file)``; entries missing from the config
        entry come back as None.

    Raises:
        ValueError: If *lang* has no entry in the config. (ValueError is a
            subclass of Exception, so callers that caught the previous
            generic Exception still work.)
    """
    try:
        params = config['lang'][lang]
    except KeyError:
        raise ValueError(f'Language {lang} not supported') from None
    return params.get('det'), params.get('rec'), params.get('dict')
# Directory containing this module; used to locate bundled resource files
# (models_config.yml and the recognition character dictionaries).
root_dir = Path(__file__).resolve().parent
class PytorchPaddleOCR(TextSystem):
    """PaddleOCR-compatible text detection + recognition pipeline on PyTorch weights.

    Wraps ``TextSystem``: resolves the det/rec model paths and the
    recognition character dictionary for the requested language, then
    delegates to the parent constructor.
    """

    def __init__(self, *args, **kwargs):
        """Build the OCR system for ``kwargs['lang']`` (default 'ch').

        Extra positional args are forwarded to the argparse parser from
        ``utility.init_args``; kwargs override the resulting defaults.
        """
        parser = utility.init_args()
        args = parser.parse_args(args)
        self.lang = kwargs.get('lang', 'ch')

        device = get_device()
        # On CPU, the heavyweight CJK models are swapped for the lite
        # variant to keep parsing speed acceptable.
        if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
            self.lang = 'ch_lite'

        # Map a concrete language code to its shared model family.
        if self.lang in latin_lang:
            self.lang = 'latin'
        elif self.lang in arabic_lang:
            self.lang = 'arabic'
        elif self.lang in cyrillic_lang:
            self.lang = 'cyrillic'
        elif self.lang in devanagari_lang:
            self.lang = 'devanagari'
        elif self.lang in east_slavic_lang:
            self.lang = 'east_slavic'

        models_config_path = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'models_config.yml')
        with open(models_config_path, encoding='utf-8') as file:
            config = yaml.safe_load(file)
        det, rec, dict_file = get_model_params(self.lang, config)

        # Download (if needed) and resolve absolute paths to the weights.
        ocr_models_dir = ModelPath.pytorch_paddle
        det_model_path = f"{ocr_models_dir}/{det}"
        det_model_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
        rec_model_path = f"{ocr_models_dir}/{rec}"
        rec_model_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)

        kwargs['det_model_path'] = det_model_path
        kwargs['rec_model_path'] = rec_model_path
        kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
        kwargs['rec_batch_num'] = 16
        kwargs['device'] = device

        # kwargs take precedence over the argparse defaults.
        default_args = vars(args)
        default_args.update(kwargs)
        args = argparse.Namespace(**default_args)
        super().__init__(args)

    def ocr(self,
            img,
            det=True,
            rec=True,
            mfd_res=None,
            tqdm_enable=False,
            ):
        """Run OCR on *img*.

        Args:
            img: np.ndarray / str path / bytes, or a list of images
                (a list is only valid with ``det=False``).
            det: Run text detection.
            rec: Run text recognition.
            mfd_res: Optional formula-detection results used to adjust
                detected text boxes.
            tqdm_enable: Show a progress bar during recognition-only runs.

        Returns:
            One entry per input image:
            det+rec -> ``[[box, (text, score)], ...]`` or None;
            det only -> ``[box, ...]`` or None;
            rec only -> ``[(text, score), ...]``.

        Raises:
            ValueError: If a list of images is passed with ``det=True``.
                (Previously this path called ``exit(0)``, killing the host
                process while reporting a *success* exit status.)
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det:
            logger.error('When input a list of images, det must be false')
            raise ValueError('When input a list of images, det must be false')
        img = check_img(img)
        imgs = [img]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            if det and rec:
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
                    if not dt_boxes and not rec_res:
                        ocr_res.append(None)
                        continue
                    tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif det and not rec:
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, elapse = self.text_detector(img)
                    if dt_boxes is None:
                        ocr_res.append(None)
                        continue
                    dt_boxes = sorted_boxes(dt_boxes)
                    # merge_det_boxes and update_det_boxes both convert polys
                    # to bboxes and back, so heavily tilted text boxes must be
                    # filtered out.
                    dt_boxes = merge_det_boxes(dt_boxes)
                    if mfd_res:
                        dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                    tmp_res = [box.tolist() for box in dt_boxes]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif not det and rec:
                ocr_res = []
                for img in imgs:
                    if not isinstance(img, list):
                        img = preprocess_image(img)
                        img = [img]
                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable)
                    ocr_res.append(rec_res)
                return ocr_res

    def __call__(self, img, mfd_res=None):
        """Detect then recognize text in a single image.

        Returns:
            ``(filter_boxes, filter_rec_res)`` — boxes and ``(text, score)``
            pairs whose score is >= ``self.drop_score`` — or ``(None, None)``
            when no image or no boxes are available.
        """
        if img is None:
            logger.debug("no valid image provided")
            return None, None
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            return None, None
        img_crop_list = []
        dt_boxes = sorted_boxes(dt_boxes)
        # merge_det_boxes and update_det_boxes both convert polys to bboxes
        # and back, so heavily tilted text boxes must be filtered out.
        dt_boxes = merge_det_boxes(dt_boxes)
        if mfd_res:
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        rec_res, elapse = self.text_recognizer(img_crop_list)
        # Keep only results that clear the confidence threshold.
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
if __name__ == '__main__':
    # Ad-hoc manual smoke test: run detection + recognition on a local
    # screenshot and dump the raw result.
    engine = PytorchPaddleOCR()
    image = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
    boxes, texts = engine(image)
    results = []
    if boxes or texts:
        results.append([[b.tolist(), t] for b, t in zip(boxes, texts)])
    else:
        results.append(None)
    print(results)