# pytorch_paddle.py
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import copy
  3. import os
  4. import warnings
  5. from pathlib import Path
  6. import cv2
  7. import numpy as np
  8. import yaml
  9. from loguru import logger
  10. from mineru.utils.config_reader import get_device
  11. from mineru.utils.enum_class import ModelPath
  12. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  13. from mineru.utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
  14. from mineru.model.utils.tools.infer.predict_system import TextSystem
  15. from mineru.model.utils.tools.infer import pytorchocr_utility as utility
  16. import argparse
# Language-code groups. PytorchPaddleOCR.__init__ remaps any language code
# found in one of these lists onto the shared model family that covers its
# script (latin / arabic / cyrillic / east_slavic / devanagari), so one
# det/rec model pair serves many languages.

# Languages written with the Latin alphabet.
latin_lang = [
    "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr",
    "ga", "hr", "hu", "id", "is", "it", "ku", "la", "lt", "lv",
    "mi", "ms", "mt", "nl", "no", "oc", "pi", "pl", "pt", "ro",
    "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr", "uz",
    "vi", "french", "german", "fi", "eu", "gl", "lb", "rm", "ca",
    "qu",
]
# Languages written with the Arabic script.
# NOTE(review): "ku" (Kurdish) also appears in latin_lang; since latin_lang is
# checked first in __init__, Kurdish is routed to the latin model family.
arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]
# Languages written with the Cyrillic script (includes "ru"/"be"/"uk", but
# east_slavic_lang is checked first in __init__ and takes precedence).
cyrillic_lang = [
    "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd",
    "ava", "dar", "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg",
    "mk", "tt", "cv", "ba", "mhr", "mo", "udm", "kv", "os", "bua",
    "xal", "tyv", "sah", "kaa",
]
# East Slavic languages get their own dedicated model family.
east_slavic_lang = ["ru", "be", "uk"]
# Languages written with the Devanagari script.
devanagari_lang = [
    "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah", "sck",
    "new", "gom", "sa", "bgc",
]
  120. def get_model_params(lang, config):
  121. if lang in config['lang']:
  122. params = config['lang'][lang]
  123. det = params.get('det')
  124. rec = params.get('rec')
  125. dict_file = params.get('dict')
  126. return det, rec, dict_file
  127. else:
  128. raise Exception (f'Language {lang} not supported')
  129. root_dir = os.path.join(Path(__file__).resolve().parent.parent, 'utils')
class PytorchPaddleOCR(TextSystem):
    """OCR pipeline (text detection + recognition) using PaddleOCR models
    ported to PyTorch, built on top of :class:`TextSystem`."""

    def __init__(self, *args, **kwargs):
        # Positional *args* are parsed as predict-system CLI-style options;
        # **kwargs later override the parsed defaults.
        parser = utility.init_args()
        args = parser.parse_args(args)
        self.lang = kwargs.get('lang', 'ch')
        # When True, neighbouring detection boxes are merged in det-only mode.
        self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
        device = get_device()
        # On CPU, switch heavy CJK models to the lightweight Chinese model
        # for parsing speed.
        if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
            # logger.warning("The current device in use is CPU. To ensure the speed of parsing, the language is automatically switched to ch_lite.")
            self.lang = 'ch_lite'
        # Remap individual language codes onto the shared model family that
        # covers their script; first match wins, unknown codes pass through.
        if self.lang in latin_lang:
            self.lang = 'latin'
        elif self.lang in east_slavic_lang:
            self.lang = 'east_slavic'
        elif self.lang in arabic_lang:
            self.lang = 'arabic'
        elif self.lang in cyrillic_lang:
            self.lang = 'cyrillic'
        elif self.lang in devanagari_lang:
            self.lang = 'devanagari'
        else:
            pass
        # Resolve det/rec model files and the recognition dictionary for the
        # (possibly remapped) language from the bundled models config.
        models_config_path = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'models_config.yml')
        with open(models_config_path) as file:
            config = yaml.safe_load(file)
        det, rec, dict_file = get_model_params(self.lang, config)
        ocr_models_dir = ModelPath.pytorch_paddle
        det_model_path = f"{ocr_models_dir}/{det}"
        # Downloads the model if missing and resolves to a local absolute path.
        det_model_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
        rec_model_path = f"{ocr_models_dir}/{rec}"
        rec_model_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
        kwargs['det_model_path'] = det_model_path
        kwargs['rec_model_path'] = rec_model_path
        kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
        kwargs['rec_batch_num'] = 6
        kwargs['device'] = device
        # kwargs take precedence over parsed CLI defaults; hand everything to
        # the TextSystem base class as a single namespace.
        default_args = vars(args)
        default_args.update(kwargs)
        args = argparse.Namespace(**default_args)
        super().__init__(args)

    def ocr(self,
            img,
            det=True,
            rec=True,
            mfd_res=None,
            tqdm_enable=False,
            tqdm_desc="OCR-rec Predict",
            ):
        """Run OCR on *img* in one of three modes.

        Args:
            img: ndarray / path str / bytes for a single image, or a list of
                pre-cropped images (list input requires det=False).
            det: Run text detection.
            rec: Run text recognition.
            mfd_res: Optional formula-detection results used to adjust the
                detected text boxes (see update_det_boxes).
            tqdm_enable / tqdm_desc: Progress-bar options for rec-only mode.

        Returns:
            A list with one entry per input image: [box, (text, score)] pairs
            for det+rec, box polygons for det-only, raw recognition results
            for rec-only; None when nothing was detected.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error('When input a list of images, det must be false')
            exit(0)
        img = check_img(img)
        imgs = [img]
        with warnings.catch_warnings():
            # Suppress RuntimeWarnings from the numeric pre/post-processing.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            if det and rec:
                # Full pipeline: detect boxes, then recognize each crop.
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
                    if not dt_boxes and not rec_res:
                        ocr_res.append(None)
                        continue
                    tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif det and not rec:
                # Detection only: return the sorted (and optionally merged /
                # formula-adjusted) box polygons.
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, elapse = self.text_detector(img)
                    # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
                    if dt_boxes is None:
                        ocr_res.append(None)
                        continue
                    dt_boxes = sorted_boxes(dt_boxes)
                    # merge_det_boxes and update_det_boxes both convert polys to
                    # bboxes and back, so strongly skewed text boxes must be
                    # filtered out beforehand.
                    if self.enable_merge_det_boxes:
                        dt_boxes = merge_det_boxes(dt_boxes)
                    if mfd_res:
                        dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                    tmp_res = [box.tolist() for box in dt_boxes]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif not det and rec:
                # Recognition only: each input is treated as a pre-cropped
                # text-line image (or a list of them).
                ocr_res = []
                for img in imgs:
                    if not isinstance(img, list):
                        img = preprocess_image(img)
                        img = [img]
                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
                    # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
                    ocr_res.append(rec_res)
                return ocr_res

    def __call__(self, img, mfd_res=None):
        """Detect text boxes in *img* and recognize each cropped region.

        Returns:
            (boxes, results) where results are (text, score) pairs with
            score >= self.drop_score; (None, None) when *img* is None or no
            boxes are detected.
        """
        if img is None:
            logger.debug("no valid image provided")
            return None, None
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            return None, None
        else:
            pass
            # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
        img_crop_list = []
        dt_boxes = sorted_boxes(dt_boxes)
        # merge_det_boxes and update_det_boxes both convert polys to bboxes and
        # back, so strongly skewed text boxes must be filtered out beforehand.
        if self.enable_merge_det_boxes:
            dt_boxes = merge_det_boxes(dt_boxes)
        if mfd_res:
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
        # Crop (with rotation correction) each detected box from the original.
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        rec_res, elapse = self.text_recognizer(img_crop_list)
        # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
        # Drop recognition results below the confidence threshold.
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
  257. if __name__ == '__main__':
  258. pytorch_paddle_ocr = PytorchPaddleOCR()
  259. img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
  260. dt_boxes, rec_res = pytorch_paddle_ocr(img)
  261. ocr_res = []
  262. if not dt_boxes and not rec_res:
  263. ocr_res.append(None)
  264. else:
  265. tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
  266. ocr_res.append(tmp_res)
  267. print(ocr_res)