# ppocr_273_mod.py

import copy
import platform
import time

import cv2
import numpy as np
import torch
from paddleocr import PaddleOCR
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import alpha_to_color, binarize_img
from tools.infer.predict_system import sorted_boxes
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop

from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
    ONNXModelSingleton

logger = get_logger()


class ModifiedPaddleOCR(PaddleOCR):
    """PaddleOCR subclass with formula-aware text-box splitting (via mfd_res) and an
    ONNX fallback for ARM CPUs without CUDA."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lang = kwargs.get('lang', 'ch')
        # Fall back to the ONNX models when the CPU architecture is ARM and CUDA is unavailable.
        if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
            self.use_onnx = True
            onnx_model_manager = ONNXModelSingleton()
            self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
        else:
            self.use_onnx = False
    def ocr(self,
            img,
            det=True,
            rec=True,
            cls=True,
            bin=False,
            inv=False,
            alpha_color=(255, 255, 255),
            mfd_res=None,
            ):
        """
        OCR with PaddleOCR

        args:
            img: image for OCR; supports an ndarray, an image path, or a list of ndarrays
            det: use text detection or not. If False, only recognition is executed. Default is True.
            rec: use text recognition or not. If False, only detection is executed. Default is True.
            cls: use the angle classifier or not. Default is True. If True, text rotated by 180 degrees
                can be recognized. If no text is rotated by 180 degrees, use cls=False for better
                performance. Text rotated by 90 or 270 degrees can be recognized even if cls=False.
            bin: binarize the image to black and white. Default is False.
            inv: invert the image colors. Default is False.
            alpha_color: RGB color tuple used to replace transparent regions. Default is pure white.
            mfd_res: optional math-formula detection results used to split text boxes around formulas.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det == True:
            logger.error('When the input is a list of images, det must be False')
            exit(0)
        if cls == True and self.use_angle_cls == False:
            pass
            # logger.warning(
            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
            # )

        img = check_img(img)
        # for inferring a PDF file
        if isinstance(img, list):
            if self.page_num > len(img) or self.page_num == 0:
                self.page_num = len(img)
            imgs = img[:self.page_num]
        else:
            imgs = [img]

        def preprocess_image(_image):
            _image = alpha_to_color(_image, alpha_color)
            if inv:
                _image = cv2.bitwise_not(_image)
            if bin:
                _image = binarize_img(_image)
            return _image
        if det and rec:
            # Detection + recognition: run the full pipeline via __call__.
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
                if not dt_boxes and not rec_res:
                    ocr_res.append(None)
                    continue
                tmp_res = [[box.tolist(), res]
                           for box, res in zip(dt_boxes, rec_res)]
                ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            # Detection only.
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                if self.lang in ['ch'] and self.use_onnx:
                    dt_boxes, elapse = self.additional_ocr.text_detector(img)
                else:
                    dt_boxes, elapse = self.text_detector(img)
                if dt_boxes is None:
                    ocr_res.append(None)
                    continue
                dt_boxes = sorted_boxes(dt_boxes)
                # merge_det_boxes and update_det_boxes both convert each poly to a bbox and back,
                # so all heavily skewed text boxes need to be filtered out first.
                dt_boxes = merge_det_boxes(dt_boxes)
                if mfd_res:
                    bef = time.time()
                    dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                    aft = time.time()
                    logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                        len(dt_boxes), aft - bef))
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []
            for img in imgs:
                if not isinstance(img, list):
                    img = preprocess_image(img)
                    img = [img]
                if self.use_angle_cls and cls:
                    img, cls_res_tmp, elapse = self.text_classifier(img)
                    if not rec:
                        cls_res.append(cls_res_tmp)
                if self.lang in ['ch'] and self.use_onnx:
                    rec_res, elapse = self.additional_ocr.text_recognizer(img)
                else:
                    rec_res, elapse = self.text_recognizer(img)
                ocr_res.append(rec_res)
            if not rec:
                return cls_res
            return ocr_res
    def __call__(self, img, cls=True, mfd_res=None):
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

        if img is None:
            logger.debug("no valid image provided")
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        if self.lang in ['ch'] and self.use_onnx:
            dt_boxes, elapse = self.additional_ocr.text_detector(img)
        else:
            dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse

        if dt_boxes is None:
            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
            end = time.time()
            time_dict['all'] = end - start
            return None, None, time_dict
        else:
            logger.debug("dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), elapse))

        img_crop_list = []
        dt_boxes = sorted_boxes(dt_boxes)
        # merge_det_boxes and update_det_boxes both convert each poly to a bbox and back,
        # so all heavily skewed text boxes need to be filtered out first.
        dt_boxes = merge_det_boxes(dt_boxes)

        if mfd_res:
            bef = time.time()
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
            aft = time.time()
            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
                len(dt_boxes), aft - bef))

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            time_dict['cls'] = elapse
            logger.debug("cls num : {}, elapsed : {}".format(
                len(img_crop_list), elapse))

        if self.lang in ['ch'] and self.use_onnx:
            rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
        else:
            rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        logger.debug("rec_res num : {}, elapsed : {}".format(
            len(rec_res), elapse))

        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
                                   rec_res)

        # Drop results whose recognition score falls below drop_score.
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict['all'] = end - start
        return filter_boxes, filter_rec_res, time_dict
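

# A minimal usage sketch, not part of the original module: it assumes a local test image
# named "page.png" exists (hypothetical path) and that PaddleOCR accepts these common
# init kwargs (lang, use_angle_cls); adjust both to your installation.
if __name__ == "__main__":
    # On ARM machines without CUDA, __init__ transparently switches the detector and
    # recognizer to the ONNX models; no extra configuration is needed here.
    ocr_engine = ModifiedPaddleOCR(lang='ch', use_angle_cls=True)

    image = cv2.imread("page.png")  # hypothetical sample image path
    if image is not None:
        # Full detection + recognition; each page entry is a list of
        # [box_points, (text, score)] pairs, or None if nothing was found.
        result = ocr_engine.ocr(image, det=True, rec=True, cls=True)
        for page in result:
            for box, (text, score) in (page or []):
                logger.info("{:.2f}\t{}".format(score, text))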