ppocr_273_mod.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import copy
  2. import time
  3. import cv2
  4. import numpy as np
  5. from paddleocr import PaddleOCR
  6. from paddleocr.paddleocr import check_img, logger
  7. from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
  8. from paddleocr.tools.infer.predict_system import sorted_boxes
  9. from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
  10. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes
  11. class ModifiedPaddleOCR(PaddleOCR):
  12. def ocr(self,
  13. img,
  14. det=True,
  15. rec=True,
  16. cls=True,
  17. bin=False,
  18. inv=False,
  19. alpha_color=(255, 255, 255),
  20. mfd_res=None,
  21. ):
  22. """
  23. OCR with PaddleOCR
  24. args:
  25. img: img for OCR, support ndarray, img_path and list or ndarray
  26. det: use text detection or not. If False, only rec will be exec. Default is True
  27. rec: use text recognition or not. If False, only det will be exec. Default is True
  28. cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
  29. bin: binarize image to black and white. Default is False.
  30. inv: invert image colors. Default is False.
  31. alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
  32. """
  33. assert isinstance(img, (np.ndarray, list, str, bytes))
  34. if isinstance(img, list) and det == True:
  35. logger.error('When input a list of images, det must be false')
  36. exit(0)
  37. if cls == True and self.use_angle_cls == False:
  38. pass
  39. # logger.warning(
  40. # 'Since the angle classifier is not initialized, it will not be used during the forward process'
  41. # )
  42. img = check_img(img)
  43. # for infer pdf file
  44. if isinstance(img, list):
  45. if self.page_num > len(img) or self.page_num == 0:
  46. self.page_num = len(img)
  47. imgs = img[:self.page_num]
  48. else:
  49. imgs = [img]
  50. def preprocess_image(_image):
  51. _image = alpha_to_color(_image, alpha_color)
  52. if inv:
  53. _image = cv2.bitwise_not(_image)
  54. if bin:
  55. _image = binarize_img(_image)
  56. return _image
  57. if det and rec:
  58. ocr_res = []
  59. for img in imgs:
  60. img = preprocess_image(img)
  61. dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
  62. if not dt_boxes and not rec_res:
  63. ocr_res.append(None)
  64. continue
  65. tmp_res = [[box.tolist(), res]
  66. for box, res in zip(dt_boxes, rec_res)]
  67. ocr_res.append(tmp_res)
  68. return ocr_res
  69. elif det and not rec:
  70. ocr_res = []
  71. for img in imgs:
  72. img = preprocess_image(img)
  73. dt_boxes, elapse = self.text_detector(img)
  74. if dt_boxes is None:
  75. ocr_res.append(None)
  76. continue
  77. dt_boxes = sorted_boxes(dt_boxes)
  78. # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
  79. dt_boxes = merge_det_boxes(dt_boxes)
  80. if mfd_res:
  81. bef = time.time()
  82. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  83. aft = time.time()
  84. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  85. len(dt_boxes), aft - bef))
  86. tmp_res = [box.tolist() for box in dt_boxes]
  87. ocr_res.append(tmp_res)
  88. return ocr_res
  89. else:
  90. ocr_res = []
  91. cls_res = []
  92. for img in imgs:
  93. if not isinstance(img, list):
  94. img = preprocess_image(img)
  95. img = [img]
  96. if self.use_angle_cls and cls:
  97. img, cls_res_tmp, elapse = self.text_classifier(img)
  98. if not rec:
  99. cls_res.append(cls_res_tmp)
  100. rec_res, elapse = self.text_recognizer(img)
  101. ocr_res.append(rec_res)
  102. if not rec:
  103. return cls_res
  104. return ocr_res
  105. def __call__(self, img, cls=True, mfd_res=None):
  106. time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
  107. if img is None:
  108. logger.debug("no valid image provided")
  109. return None, None, time_dict
  110. start = time.time()
  111. ori_im = img.copy()
  112. dt_boxes, elapse = self.text_detector(img)
  113. time_dict['det'] = elapse
  114. if dt_boxes is None:
  115. logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
  116. end = time.time()
  117. time_dict['all'] = end - start
  118. return None, None, time_dict
  119. else:
  120. logger.debug("dt_boxes num : {}, elapsed : {}".format(
  121. len(dt_boxes), elapse))
  122. img_crop_list = []
  123. dt_boxes = sorted_boxes(dt_boxes)
  124. # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
  125. dt_boxes = merge_det_boxes(dt_boxes)
  126. if mfd_res:
  127. bef = time.time()
  128. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  129. aft = time.time()
  130. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  131. len(dt_boxes), aft - bef))
  132. for bno in range(len(dt_boxes)):
  133. tmp_box = copy.deepcopy(dt_boxes[bno])
  134. if self.args.det_box_type == "quad":
  135. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  136. else:
  137. img_crop = get_minarea_rect_crop(ori_im, tmp_box)
  138. img_crop_list.append(img_crop)
  139. if self.use_angle_cls and cls:
  140. img_crop_list, angle_list, elapse = self.text_classifier(
  141. img_crop_list)
  142. time_dict['cls'] = elapse
  143. logger.debug("cls num : {}, elapsed : {}".format(
  144. len(img_crop_list), elapse))
  145. rec_res, elapse = self.text_recognizer(img_crop_list)
  146. time_dict['rec'] = elapse
  147. logger.debug("rec_res num : {}, elapsed : {}".format(
  148. len(rec_res), elapse))
  149. if self.args.save_crop_res:
  150. self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
  151. rec_res)
  152. filter_boxes, filter_rec_res = [], []
  153. for box, rec_result in zip(dt_boxes, rec_res):
  154. text, score = rec_result
  155. if score >= self.drop_score:
  156. filter_boxes.append(box)
  157. filter_rec_res.append(rec_result)
  158. end = time.time()
  159. time_dict['all'] = end - start
  160. return filter_boxes, filter_rec_res, time_dict