# ppocr_291_mod.py
  1. import copy
  2. import time
  3. import cv2
  4. import numpy as np
  5. from paddleocr import PaddleOCR
  6. from paddleocr.paddleocr import check_img, logger
  7. from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
  8. from paddleocr.tools.infer.predict_system import sorted_boxes
  9. from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
  10. get_minarea_rect_crop
  11. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes
  12. class ModifiedPaddleOCR(PaddleOCR):
  13. def ocr(
  14. self,
  15. img,
  16. det=True,
  17. rec=True,
  18. cls=True,
  19. bin=False,
  20. inv=False,
  21. alpha_color=(255, 255, 255),
  22. slice={},
  23. mfd_res=None,
  24. ):
  25. """
  26. OCR with PaddleOCR
  27. Args:
  28. img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
  29. det: Use text detection or not. If False, only text recognition will be executed. Default is True.
  30. rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
  31. cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
  32. bin: Binarize image to black and white. Default is False.
  33. inv: Invert image colors. Default is False.
  34. alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
  35. slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.
  36. Returns:
  37. If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region.
  38. If det is True and rec is False, returns a list of detected bounding boxes for each image.
  39. If det is False and rec is True, returns a list of recognized text for each image.
  40. If both det and rec are False, returns a list of angle classification results for each image.
  41. Raises:
  42. AssertionError: If the input image is not of type ndarray, list, str, or bytes.
  43. SystemExit: If det is True and the input is a list of images.
  44. Note:
  45. - If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
  46. - For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
  47. - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
  48. """
  49. assert isinstance(img, (np.ndarray, list, str, bytes))
  50. if isinstance(img, list) and det == True:
  51. logger.error("When input a list of images, det must be false")
  52. exit(0)
  53. if cls == True and self.use_angle_cls == False:
  54. logger.warning(
  55. "Since the angle classifier is not initialized, it will not be used during the forward process"
  56. )
  57. img, flag_gif, flag_pdf = check_img(img, alpha_color)
  58. # for infer pdf file
  59. if isinstance(img, list) and flag_pdf:
  60. if self.page_num > len(img) or self.page_num == 0:
  61. imgs = img
  62. else:
  63. imgs = img[: self.page_num]
  64. else:
  65. imgs = [img]
  66. def preprocess_image(_image):
  67. _image = alpha_to_color(_image, alpha_color)
  68. if inv:
  69. _image = cv2.bitwise_not(_image)
  70. if bin:
  71. _image = binarize_img(_image)
  72. return _image
  73. if det and rec:
  74. ocr_res = []
  75. for img in imgs:
  76. img = preprocess_image(img)
  77. dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
  78. if not dt_boxes and not rec_res:
  79. ocr_res.append(None)
  80. continue
  81. tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
  82. ocr_res.append(tmp_res)
  83. return ocr_res
  84. elif det and not rec:
  85. ocr_res = []
  86. for img in imgs:
  87. img = preprocess_image(img)
  88. dt_boxes, elapse = self.text_detector(img)
  89. if dt_boxes.size == 0:
  90. ocr_res.append(None)
  91. continue
  92. tmp_res = [box.tolist() for box in dt_boxes]
  93. ocr_res.append(tmp_res)
  94. return ocr_res
  95. else:
  96. ocr_res = []
  97. cls_res = []
  98. for img in imgs:
  99. if not isinstance(img, list):
  100. img = preprocess_image(img)
  101. img = [img]
  102. if self.use_angle_cls and cls:
  103. img, cls_res_tmp, elapse = self.text_classifier(img)
  104. if not rec:
  105. cls_res.append(cls_res_tmp)
  106. rec_res, elapse = self.text_recognizer(img)
  107. ocr_res.append(rec_res)
  108. if not rec:
  109. return cls_res
  110. return ocr_res
  111. def __call__(self, img, cls=True, slice={}, mfd_res=None):
  112. time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
  113. if img is None:
  114. logger.debug("no valid image provided")
  115. return None, None, time_dict
  116. start = time.time()
  117. ori_im = img.copy()
  118. if slice:
  119. slice_gen = slice_generator(
  120. img,
  121. horizontal_stride=slice["horizontal_stride"],
  122. vertical_stride=slice["vertical_stride"],
  123. )
  124. elapsed = []
  125. dt_slice_boxes = []
  126. for slice_crop, v_start, h_start in slice_gen:
  127. dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
  128. if dt_boxes.size:
  129. dt_boxes[:, :, 0] += h_start
  130. dt_boxes[:, :, 1] += v_start
  131. dt_slice_boxes.append(dt_boxes)
  132. elapsed.append(elapse)
  133. dt_boxes = np.concatenate(dt_slice_boxes)
  134. dt_boxes = merge_fragmented(
  135. boxes=dt_boxes,
  136. x_threshold=slice["merge_x_thres"],
  137. y_threshold=slice["merge_y_thres"],
  138. )
  139. elapse = sum(elapsed)
  140. else:
  141. dt_boxes, elapse = self.text_detector(img)
  142. time_dict["det"] = elapse
  143. if dt_boxes is None:
  144. logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
  145. end = time.time()
  146. time_dict["all"] = end - start
  147. return None, None, time_dict
  148. else:
  149. logger.debug(
  150. "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
  151. )
  152. img_crop_list = []
  153. dt_boxes = sorted_boxes(dt_boxes)
  154. if mfd_res:
  155. bef = time.time()
  156. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  157. aft = time.time()
  158. logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
  159. len(dt_boxes), aft - bef))
  160. for bno in range(len(dt_boxes)):
  161. tmp_box = copy.deepcopy(dt_boxes[bno])
  162. if self.args.det_box_type == "quad":
  163. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  164. else:
  165. img_crop = get_minarea_rect_crop(ori_im, tmp_box)
  166. img_crop_list.append(img_crop)
  167. if self.use_angle_cls and cls:
  168. img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
  169. time_dict["cls"] = elapse
  170. logger.debug(
  171. "cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
  172. )
  173. if len(img_crop_list) > 1000:
  174. logger.debug(
  175. f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
  176. )
  177. rec_res, elapse = self.text_recognizer(img_crop_list)
  178. time_dict["rec"] = elapse
  179. logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
  180. if self.args.save_crop_res:
  181. self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
  182. filter_boxes, filter_rec_res = [], []
  183. for box, rec_result in zip(dt_boxes, rec_res):
  184. text, score = rec_result[0], rec_result[1]
  185. if score >= self.drop_score:
  186. filter_boxes.append(box)
  187. filter_rec_res.append(rec_result)
  188. end = time.time()
  189. time_dict["all"] = end - start
  190. return filter_boxes, filter_rec_res, time_dict