ppocr_273_mod.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import copy
  2. import time
  3. import cv2
  4. import numpy as np
  5. from paddleocr import PaddleOCR
  6. from paddleocr.paddleocr import check_img, logger
  7. from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
  8. from paddleocr.tools.infer.predict_system import sorted_boxes
  9. from paddleocr.tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
  10. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes
  11. class ModifiedPaddleOCR(PaddleOCR):
  12. def ocr(self,
  13. img,
  14. det=True,
  15. rec=True,
  16. cls=True,
  17. bin=False,
  18. inv=False,
  19. alpha_color=(255, 255, 255),
  20. mfd_res=None,
  21. ):
  22. """
  23. OCR with PaddleOCR
  24. args:
  25. img: img for OCR, support ndarray, img_path and list or ndarray
  26. det: use text detection or not. If False, only rec will be exec. Default is True
  27. rec: use text recognition or not. If False, only det will be exec. Default is True
  28. cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
  29. bin: binarize image to black and white. Default is False.
  30. inv: invert image colors. Default is False.
  31. alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
  32. """
  33. assert isinstance(img, (np.ndarray, list, str, bytes))
  34. if isinstance(img, list) and det == True:
  35. logger.error('When input a list of images, det must be false')
  36. exit(0)
  37. if cls == True and self.use_angle_cls == False:
  38. pass
  39. # logger.warning(
  40. # 'Since the angle classifier is not initialized, it will not be used during the forward process'
  41. # )
  42. img = check_img(img)
  43. # for infer pdf file
  44. if isinstance(img, list):
  45. if self.page_num > len(img) or self.page_num == 0:
  46. self.page_num = len(img)
  47. imgs = img[:self.page_num]
  48. else:
  49. imgs = [img]
  50. def preprocess_image(_image):
  51. _image = alpha_to_color(_image, alpha_color)
  52. if inv:
  53. _image = cv2.bitwise_not(_image)
  54. if bin:
  55. _image = binarize_img(_image)
  56. return _image
  57. if det and rec:
  58. ocr_res = []
  59. for idx, img in enumerate(imgs):
  60. img = preprocess_image(img)
  61. dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
  62. if not dt_boxes and not rec_res:
  63. ocr_res.append(None)
  64. continue
  65. tmp_res = [[box.tolist(), res]
  66. for box, res in zip(dt_boxes, rec_res)]
  67. ocr_res.append(tmp_res)
  68. return ocr_res
  69. elif det and not rec:
  70. ocr_res = []
  71. for idx, img in enumerate(imgs):
  72. img = preprocess_image(img)
  73. dt_boxes, elapse = self.text_detector(img)
  74. if not dt_boxes:
  75. ocr_res.append(None)
  76. continue
  77. tmp_res = [box.tolist() for box in dt_boxes]
  78. ocr_res.append(tmp_res)
  79. return ocr_res
  80. else:
  81. ocr_res = []
  82. cls_res = []
  83. for idx, img in enumerate(imgs):
  84. if not isinstance(img, list):
  85. img = preprocess_image(img)
  86. img = [img]
  87. if self.use_angle_cls and cls:
  88. img, cls_res_tmp, elapse = self.text_classifier(img)
  89. if not rec:
  90. cls_res.append(cls_res_tmp)
  91. rec_res, elapse = self.text_recognizer(img)
  92. ocr_res.append(rec_res)
  93. if not rec:
  94. return cls_res
  95. return ocr_res
  96. def __call__(self, img, cls=True, mfd_res=None):
  97. time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
  98. if img is None:
  99. logger.debug("no valid image provided")
  100. return None, None, time_dict
  101. start = time.time()
  102. ori_im = img.copy()
  103. dt_boxes, elapse = self.text_detector(img)
  104. time_dict['det'] = elapse
  105. if dt_boxes is None:
  106. logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
  107. end = time.time()
  108. time_dict['all'] = end - start
  109. return None, None, time_dict
  110. else:
  111. logger.debug("dt_boxes num : {}, elapsed : {}".format(
  112. len(dt_boxes), elapse))
  113. img_crop_list = []
  114. dt_boxes = sorted_boxes(dt_boxes)
  115. # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
  116. dt_boxes = merge_det_boxes(dt_boxes)
  117. if mfd_res:
  118. bef = time.time()
  119. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  120. aft = time.time()
  121. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  122. len(dt_boxes), aft - bef))
  123. for bno in range(len(dt_boxes)):
  124. tmp_box = copy.deepcopy(dt_boxes[bno])
  125. if self.args.det_box_type == "quad":
  126. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  127. else:
  128. img_crop = get_minarea_rect_crop(ori_im, tmp_box)
  129. img_crop_list.append(img_crop)
  130. if self.use_angle_cls and cls:
  131. img_crop_list, angle_list, elapse = self.text_classifier(
  132. img_crop_list)
  133. time_dict['cls'] = elapse
  134. logger.debug("cls num : {}, elapsed : {}".format(
  135. len(img_crop_list), elapse))
  136. rec_res, elapse = self.text_recognizer(img_crop_list)
  137. time_dict['rec'] = elapse
  138. logger.debug("rec_res num : {}, elapsed : {}".format(
  139. len(rec_res), elapse))
  140. if self.args.save_crop_res:
  141. self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
  142. rec_res)
  143. filter_boxes, filter_rec_res = [], []
  144. for box, rec_result in zip(dt_boxes, rec_res):
  145. text, score = rec_result
  146. if score >= self.drop_score:
  147. filter_boxes.append(box)
  148. filter_rec_res.append(rec_result)
  149. end = time.time()
  150. time_dict['all'] = end - start
  151. return filter_boxes, filter_rec_res, time_dict