ppocr_273_mod.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import copy
  2. import time
  3. import cv2
  4. import numpy as np
  5. from paddleocr import PaddleOCR
  6. from ppocr.utils.logging import get_logger
  7. from ppocr.utils.utility import alpha_to_color, binarize_img
  8. from tools.infer.predict_system import sorted_boxes
  9. from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
  10. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img
  11. logger = get_logger()
  12. class ModifiedPaddleOCR(PaddleOCR):
  13. def ocr(self,
  14. img,
  15. det=True,
  16. rec=True,
  17. cls=True,
  18. bin=False,
  19. inv=False,
  20. alpha_color=(255, 255, 255),
  21. mfd_res=None,
  22. ):
  23. """
  24. OCR with PaddleOCR
  25. args:
  26. img: img for OCR, support ndarray, img_path and list or ndarray
  27. det: use text detection or not. If False, only rec will be exec. Default is True
  28. rec: use text recognition or not. If False, only det will be exec. Default is True
  29. cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
  30. bin: binarize image to black and white. Default is False.
  31. inv: invert image colors. Default is False.
  32. alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
  33. """
  34. assert isinstance(img, (np.ndarray, list, str, bytes))
  35. if isinstance(img, list) and det == True:
  36. logger.error('When input a list of images, det must be false')
  37. exit(0)
  38. if cls == True and self.use_angle_cls == False:
  39. pass
  40. # logger.warning(
  41. # 'Since the angle classifier is not initialized, it will not be used during the forward process'
  42. # )
  43. img = check_img(img)
  44. # for infer pdf file
  45. if isinstance(img, list):
  46. if self.page_num > len(img) or self.page_num == 0:
  47. self.page_num = len(img)
  48. imgs = img[:self.page_num]
  49. else:
  50. imgs = [img]
  51. def preprocess_image(_image):
  52. _image = alpha_to_color(_image, alpha_color)
  53. if inv:
  54. _image = cv2.bitwise_not(_image)
  55. if bin:
  56. _image = binarize_img(_image)
  57. return _image
  58. if det and rec:
  59. ocr_res = []
  60. for img in imgs:
  61. img = preprocess_image(img)
  62. dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
  63. if not dt_boxes and not rec_res:
  64. ocr_res.append(None)
  65. continue
  66. tmp_res = [[box.tolist(), res]
  67. for box, res in zip(dt_boxes, rec_res)]
  68. ocr_res.append(tmp_res)
  69. return ocr_res
  70. elif det and not rec:
  71. ocr_res = []
  72. for img in imgs:
  73. img = preprocess_image(img)
  74. dt_boxes, elapse = self.text_detector(img)
  75. if dt_boxes is None:
  76. ocr_res.append(None)
  77. continue
  78. dt_boxes = sorted_boxes(dt_boxes)
  79. # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
  80. dt_boxes = merge_det_boxes(dt_boxes)
  81. if mfd_res:
  82. bef = time.time()
  83. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  84. aft = time.time()
  85. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  86. len(dt_boxes), aft - bef))
  87. tmp_res = [box.tolist() for box in dt_boxes]
  88. ocr_res.append(tmp_res)
  89. return ocr_res
  90. else:
  91. ocr_res = []
  92. cls_res = []
  93. for img in imgs:
  94. if not isinstance(img, list):
  95. img = preprocess_image(img)
  96. img = [img]
  97. if self.use_angle_cls and cls:
  98. img, cls_res_tmp, elapse = self.text_classifier(img)
  99. if not rec:
  100. cls_res.append(cls_res_tmp)
  101. rec_res, elapse = self.text_recognizer(img)
  102. ocr_res.append(rec_res)
  103. if not rec:
  104. return cls_res
  105. return ocr_res
  106. def __call__(self, img, cls=True, mfd_res=None):
  107. time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
  108. if img is None:
  109. logger.debug("no valid image provided")
  110. return None, None, time_dict
  111. start = time.time()
  112. ori_im = img.copy()
  113. dt_boxes, elapse = self.text_detector(img)
  114. time_dict['det'] = elapse
  115. if dt_boxes is None:
  116. logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
  117. end = time.time()
  118. time_dict['all'] = end - start
  119. return None, None, time_dict
  120. else:
  121. logger.debug("dt_boxes num : {}, elapsed : {}".format(
  122. len(dt_boxes), elapse))
  123. img_crop_list = []
  124. dt_boxes = sorted_boxes(dt_boxes)
  125. # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
  126. dt_boxes = merge_det_boxes(dt_boxes)
  127. if mfd_res:
  128. bef = time.time()
  129. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  130. aft = time.time()
  131. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  132. len(dt_boxes), aft - bef))
  133. for bno in range(len(dt_boxes)):
  134. tmp_box = copy.deepcopy(dt_boxes[bno])
  135. if self.args.det_box_type == "quad":
  136. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  137. else:
  138. img_crop = get_minarea_rect_crop(ori_im, tmp_box)
  139. img_crop_list.append(img_crop)
  140. if self.use_angle_cls and cls:
  141. img_crop_list, angle_list, elapse = self.text_classifier(
  142. img_crop_list)
  143. time_dict['cls'] = elapse
  144. logger.debug("cls num : {}, elapsed : {}".format(
  145. len(img_crop_list), elapse))
  146. rec_res, elapse = self.text_recognizer(img_crop_list)
  147. time_dict['rec'] = elapse
  148. logger.debug("rec_res num : {}, elapsed : {}".format(
  149. len(rec_res), elapse))
  150. if self.args.save_crop_res:
  151. self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
  152. rec_res)
  153. filter_boxes, filter_rec_res = [], []
  154. for box, rec_result in zip(dt_boxes, rec_res):
  155. text, score = rec_result
  156. if score >= self.drop_score:
  157. filter_boxes.append(box)
  158. filter_rec_res.append(rec_result)
  159. end = time.time()
  160. time_dict['all'] = end - start
  161. return filter_boxes, filter_rec_res, time_dict