ppocr_273_mod.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import copy
  2. import platform
  3. import time
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from paddleocr import PaddleOCR
  8. from ppocr.utils.logging import get_logger
  9. from ppocr.utils.utility import alpha_to_color, binarize_img
  10. from tools.infer.predict_system import sorted_boxes
  11. from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
  12. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
  13. ONNXModelSingleton
# Module-level logger obtained from ppocr's get_logger(); ModifiedPaddleOCR
# uses it for its error and debug messages.
logger = get_logger()
  15. class ModifiedPaddleOCR(PaddleOCR):
  16. def __init__(self, *args, **kwargs):
  17. super().__init__(*args, **kwargs)
  18. # 在cpu架构为arm且不支持cuda时调用onnx、
  19. if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
  20. self.use_onnx = True
  21. onnx_model_manager = ONNXModelSingleton()
  22. self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
  23. else:
  24. self.use_onnx = False
  25. def ocr(self,
  26. img,
  27. det=True,
  28. rec=True,
  29. cls=True,
  30. bin=False,
  31. inv=False,
  32. alpha_color=(255, 255, 255),
  33. mfd_res=None,
  34. ):
  35. """
  36. OCR with PaddleOCR
  37. args:
  38. img: img for OCR, support ndarray, img_path and list or ndarray
  39. det: use text detection or not. If False, only rec will be exec. Default is True
  40. rec: use text recognition or not. If False, only det will be exec. Default is True
  41. cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
  42. bin: binarize image to black and white. Default is False.
  43. inv: invert image colors. Default is False.
  44. alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
  45. """
  46. assert isinstance(img, (np.ndarray, list, str, bytes))
  47. if isinstance(img, list) and det == True:
  48. logger.error('When input a list of images, det must be false')
  49. exit(0)
  50. if cls == True and self.use_angle_cls == False:
  51. pass
  52. # logger.warning(
  53. # 'Since the angle classifier is not initialized, it will not be used during the forward process'
  54. # )
  55. img = check_img(img)
  56. # for infer pdf file
  57. if isinstance(img, list):
  58. if self.page_num > len(img) or self.page_num == 0:
  59. self.page_num = len(img)
  60. imgs = img[:self.page_num]
  61. else:
  62. imgs = [img]
  63. def preprocess_image(_image):
  64. _image = alpha_to_color(_image, alpha_color)
  65. if inv:
  66. _image = cv2.bitwise_not(_image)
  67. if bin:
  68. _image = binarize_img(_image)
  69. return _image
  70. if det and rec:
  71. ocr_res = []
  72. for img in imgs:
  73. img = preprocess_image(img)
  74. dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
  75. if not dt_boxes and not rec_res:
  76. ocr_res.append(None)
  77. continue
  78. tmp_res = [[box.tolist(), res]
  79. for box, res in zip(dt_boxes, rec_res)]
  80. ocr_res.append(tmp_res)
  81. return ocr_res
  82. elif det and not rec:
  83. ocr_res = []
  84. for img in imgs:
  85. img = preprocess_image(img)
  86. if self.use_onnx:
  87. dt_boxes, elapse = self.additional_ocr.text_detector(img)
  88. else:
  89. dt_boxes, elapse = self.text_detector(img)
  90. if dt_boxes is None:
  91. ocr_res.append(None)
  92. continue
  93. dt_boxes = sorted_boxes(dt_boxes)
  94. # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
  95. dt_boxes = merge_det_boxes(dt_boxes)
  96. if mfd_res:
  97. bef = time.time()
  98. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  99. aft = time.time()
  100. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  101. len(dt_boxes), aft - bef))
  102. tmp_res = [box.tolist() for box in dt_boxes]
  103. ocr_res.append(tmp_res)
  104. return ocr_res
  105. else:
  106. ocr_res = []
  107. cls_res = []
  108. for img in imgs:
  109. if not isinstance(img, list):
  110. img = preprocess_image(img)
  111. img = [img]
  112. if self.use_angle_cls and cls:
  113. img, cls_res_tmp, elapse = self.text_classifier(img)
  114. if not rec:
  115. cls_res.append(cls_res_tmp)
  116. if self.use_onnx:
  117. rec_res, elapse = self.additional_ocr.text_recognizer(img)
  118. else:
  119. rec_res, elapse = self.text_recognizer(img)
  120. ocr_res.append(rec_res)
  121. if not rec:
  122. return cls_res
  123. return ocr_res
  124. def __call__(self, img, cls=True, mfd_res=None):
  125. time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
  126. if img is None:
  127. logger.debug("no valid image provided")
  128. return None, None, time_dict
  129. start = time.time()
  130. ori_im = img.copy()
  131. if self.use_onnx:
  132. dt_boxes, elapse = self.additional_ocr.text_detector(img)
  133. else:
  134. dt_boxes, elapse = self.text_detector(img)
  135. time_dict['det'] = elapse
  136. if dt_boxes is None:
  137. logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
  138. end = time.time()
  139. time_dict['all'] = end - start
  140. return None, None, time_dict
  141. else:
  142. logger.debug("dt_boxes num : {}, elapsed : {}".format(
  143. len(dt_boxes), elapse))
  144. img_crop_list = []
  145. dt_boxes = sorted_boxes(dt_boxes)
  146. # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
  147. dt_boxes = merge_det_boxes(dt_boxes)
  148. if mfd_res:
  149. bef = time.time()
  150. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  151. aft = time.time()
  152. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  153. len(dt_boxes), aft - bef))
  154. for bno in range(len(dt_boxes)):
  155. tmp_box = copy.deepcopy(dt_boxes[bno])
  156. if self.args.det_box_type == "quad":
  157. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  158. else:
  159. img_crop = get_minarea_rect_crop(ori_im, tmp_box)
  160. img_crop_list.append(img_crop)
  161. if self.use_angle_cls and cls:
  162. img_crop_list, angle_list, elapse = self.text_classifier(
  163. img_crop_list)
  164. time_dict['cls'] = elapse
  165. logger.debug("cls num : {}, elapsed : {}".format(
  166. len(img_crop_list), elapse))
  167. if self.use_onnx:
  168. rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
  169. else:
  170. rec_res, elapse = self.text_recognizer(img_crop_list)
  171. time_dict['rec'] = elapse
  172. logger.debug("rec_res num : {}, elapsed : {}".format(
  173. len(rec_res), elapse))
  174. if self.args.save_crop_res:
  175. self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
  176. rec_res)
  177. filter_boxes, filter_rec_res = [], []
  178. for box, rec_result in zip(dt_boxes, rec_res):
  179. text, score = rec_result
  180. if score >= self.drop_score:
  181. filter_boxes.append(box)
  182. filter_rec_res.append(rec_result)
  183. end = time.time()
  184. time_dict['all'] = end - start
  185. return filter_boxes, filter_rec_res, time_dict