ocr_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. import cv2
  2. import numpy as np
  3. from loguru import logger
  4. from io import BytesIO
  5. from PIL import Image
  6. import base64
  7. from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
  8. from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
  9. from importlib.resources import files
  10. from paddleocr import PaddleOCR
  11. from ppocr.utils.utility import check_and_read
  12. def img_decode(content: bytes):
  13. np_arr = np.frombuffer(content, dtype=np.uint8)
  14. return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
  15. def check_img(img):
  16. if isinstance(img, bytes):
  17. img = img_decode(img)
  18. if isinstance(img, str):
  19. image_file = img
  20. img, flag_gif, flag_pdf = check_and_read(image_file)
  21. if not flag_gif and not flag_pdf:
  22. with open(image_file, 'rb') as f:
  23. img_str = f.read()
  24. img = img_decode(img_str)
  25. if img is None:
  26. try:
  27. buf = BytesIO()
  28. image = BytesIO(img_str)
  29. im = Image.open(image)
  30. rgb = im.convert('RGB')
  31. rgb.save(buf, 'jpeg')
  32. buf.seek(0)
  33. image_bytes = buf.read()
  34. data_base64 = str(base64.b64encode(image_bytes),
  35. encoding="utf-8")
  36. image_decode = base64.b64decode(data_base64)
  37. img_array = np.frombuffer(image_decode, np.uint8)
  38. img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
  39. except:
  40. logger.error("error in loading image:{}".format(image_file))
  41. return None
  42. if img is None:
  43. logger.error("error in loading image:{}".format(image_file))
  44. return None
  45. if isinstance(img, np.ndarray) and len(img.shape) == 2:
  46. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  47. return img
  48. def bbox_to_points(bbox):
  49. """ 将bbox格式转换为四个顶点的数组 """
  50. x0, y0, x1, y1 = bbox
  51. return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
  52. def points_to_bbox(points):
  53. """ 将四个顶点的数组转换为bbox格式 """
  54. x0, y0 = points[0]
  55. x1, _ = points[1]
  56. _, y1 = points[2]
  57. return [x0, y0, x1, y1]
  58. def merge_intervals(intervals):
  59. # Sort the intervals based on the start value
  60. intervals.sort(key=lambda x: x[0])
  61. merged = []
  62. for interval in intervals:
  63. # If the list of merged intervals is empty or if the current
  64. # interval does not overlap with the previous, simply append it.
  65. if not merged or merged[-1][1] < interval[0]:
  66. merged.append(interval)
  67. else:
  68. # Otherwise, there is overlap, so we merge the current and previous intervals.
  69. merged[-1][1] = max(merged[-1][1], interval[1])
  70. return merged
  71. def remove_intervals(original, masks):
  72. # Merge all mask intervals
  73. merged_masks = merge_intervals(masks)
  74. result = []
  75. original_start, original_end = original
  76. for mask in merged_masks:
  77. mask_start, mask_end = mask
  78. # If the mask starts after the original range, ignore it
  79. if mask_start > original_end:
  80. continue
  81. # If the mask ends before the original range starts, ignore it
  82. if mask_end < original_start:
  83. continue
  84. # Remove the masked part from the original range
  85. if original_start < mask_start:
  86. result.append([original_start, mask_start - 1])
  87. original_start = max(mask_end + 1, original_start)
  88. # Add the remaining part of the original range, if any
  89. if original_start <= original_end:
  90. result.append([original_start, original_end])
  91. return result
  92. def update_det_boxes(dt_boxes, mfd_res):
  93. new_dt_boxes = []
  94. angle_boxes_list = []
  95. for text_box in dt_boxes:
  96. if calculate_is_angle(text_box):
  97. angle_boxes_list.append(text_box)
  98. continue
  99. text_bbox = points_to_bbox(text_box)
  100. masks_list = []
  101. for mf_box in mfd_res:
  102. mf_bbox = mf_box['bbox']
  103. if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
  104. masks_list.append([mf_bbox[0], mf_bbox[2]])
  105. text_x_range = [text_bbox[0], text_bbox[2]]
  106. text_remove_mask_range = remove_intervals(text_x_range, masks_list)
  107. temp_dt_box = []
  108. for text_remove_mask in text_remove_mask_range:
  109. temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
  110. if len(temp_dt_box) > 0:
  111. new_dt_boxes.extend(temp_dt_box)
  112. new_dt_boxes.extend(angle_boxes_list)
  113. return new_dt_boxes
  114. def merge_overlapping_spans(spans):
  115. """
  116. Merges overlapping spans on the same line.
  117. :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
  118. :return: A list of merged spans
  119. """
  120. # Return an empty list if the input spans list is empty
  121. if not spans:
  122. return []
  123. # Sort spans by their starting x-coordinate
  124. spans.sort(key=lambda x: x[0])
  125. # Initialize the list of merged spans
  126. merged = []
  127. for span in spans:
  128. # Unpack span coordinates
  129. x1, y1, x2, y2 = span
  130. # If the merged list is empty or there's no horizontal overlap, add the span directly
  131. if not merged or merged[-1][2] < x1:
  132. merged.append(span)
  133. else:
  134. # If there is horizontal overlap, merge the current span with the previous one
  135. last_span = merged.pop()
  136. # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
  137. x1 = min(last_span[0], x1)
  138. y1 = min(last_span[1], y1)
  139. x2 = max(last_span[2], x2)
  140. y2 = max(last_span[3], y2)
  141. # Add the merged span back to the list
  142. merged.append((x1, y1, x2, y2))
  143. # Return the list of merged spans
  144. return merged
  145. def merge_det_boxes(dt_boxes):
  146. """
  147. Merge detection boxes.
  148. This function takes a list of detected bounding boxes, each represented by four corner points.
  149. The goal is to merge these bounding boxes into larger text regions.
  150. Parameters:
  151. dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
  152. Returns:
  153. list: A list containing the merged text regions, where each region is represented by four corner points.
  154. """
  155. # Convert the detection boxes into a dictionary format with bounding boxes and type
  156. dt_boxes_dict_list = []
  157. angle_boxes_list = []
  158. for text_box in dt_boxes:
  159. text_bbox = points_to_bbox(text_box)
  160. if calculate_is_angle(text_box):
  161. angle_boxes_list.append(text_box)
  162. continue
  163. text_box_dict = {
  164. 'bbox': text_bbox,
  165. 'type': 'text',
  166. }
  167. dt_boxes_dict_list.append(text_box_dict)
  168. # Merge adjacent text regions into lines
  169. lines = merge_spans_to_line(dt_boxes_dict_list)
  170. # Initialize a new list for storing the merged text regions
  171. new_dt_boxes = []
  172. for line in lines:
  173. line_bbox_list = []
  174. for span in line:
  175. line_bbox_list.append(span['bbox'])
  176. # Merge overlapping text regions within the same line
  177. merged_spans = merge_overlapping_spans(line_bbox_list)
  178. # Convert the merged text regions back to point format and add them to the new detection box list
  179. for span in merged_spans:
  180. new_dt_boxes.append(bbox_to_points(span))
  181. new_dt_boxes.extend(angle_boxes_list)
  182. return new_dt_boxes
  183. def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
  184. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  185. # Adjust the coordinates of the formula area
  186. adjusted_mfdetrec_res = []
  187. for mf_res in single_page_mfdetrec_res:
  188. mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
  189. # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
  190. x0 = mf_xmin - xmin + paste_x
  191. y0 = mf_ymin - ymin + paste_y
  192. x1 = mf_xmax - xmin + paste_x
  193. y1 = mf_ymax - ymin + paste_y
  194. # Filter formula blocks outside the graph
  195. if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
  196. continue
  197. else:
  198. adjusted_mfdetrec_res.append({
  199. "bbox": [x0, y0, x1, y1],
  200. })
  201. return adjusted_mfdetrec_res
  202. def get_ocr_result_list(ocr_res, useful_list):
  203. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  204. ocr_result_list = []
  205. for box_ocr_res in ocr_res:
  206. if len(box_ocr_res) == 2:
  207. p1, p2, p3, p4 = box_ocr_res[0]
  208. text, score = box_ocr_res[1]
  209. # logger.info(f"text: {text}, score: {score}")
  210. if score < 0.6: # 过滤低置信度的结果
  211. continue
  212. else:
  213. p1, p2, p3, p4 = box_ocr_res
  214. text, score = "", 1
  215. # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
  216. # if average_angle_degrees > 0.5:
  217. poly = [p1, p2, p3, p4]
  218. if calculate_is_angle(poly):
  219. # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
  220. # 与x轴的夹角超过0.5度,对边界做一下矫正
  221. # 计算几何中心
  222. x_center = sum(point[0] for point in poly) / 4
  223. y_center = sum(point[1] for point in poly) / 4
  224. new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
  225. new_width = p3[0] - p1[0]
  226. p1 = [x_center - new_width / 2, y_center - new_height / 2]
  227. p2 = [x_center + new_width / 2, y_center - new_height / 2]
  228. p3 = [x_center + new_width / 2, y_center + new_height / 2]
  229. p4 = [x_center - new_width / 2, y_center + new_height / 2]
  230. # Convert the coordinates back to the original coordinate system
  231. p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
  232. p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
  233. p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
  234. p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
  235. ocr_result_list.append({
  236. 'category_id': 15,
  237. 'poly': p1 + p2 + p3 + p4,
  238. 'score': float(round(score, 2)),
  239. 'text': text,
  240. })
  241. return ocr_result_list
  242. def calculate_is_angle(poly):
  243. p1, p2, p3, p4 = poly
  244. height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
  245. if 0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height:
  246. return False
  247. else:
  248. # logger.info((p3[1] - p1[1])/height)
  249. return True
  250. class ONNXModelSingleton:
  251. _instance = None
  252. _models = {}
  253. def __new__(cls, *args, **kwargs):
  254. if cls._instance is None:
  255. cls._instance = super().__new__(cls)
  256. return cls._instance
  257. def get_onnx_model(self, **kwargs):
  258. lang = kwargs.get('lang', None)
  259. det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
  260. use_dilation = kwargs.get('use_dilation', True)
  261. det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
  262. key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
  263. if key not in self._models:
  264. self._models[key] = onnx_model_init(key)
  265. return self._models[key]
  266. def onnx_model_init(key):
  267. if len(key) < 4:
  268. logger.error('Invalid key length, expected at least 4 elements')
  269. exit(1)
  270. try:
  271. resource_path = files("rapidocr_onnxruntime") / "models"
  272. additional_ocr_params = {
  273. "use_onnx": True,
  274. "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
  275. "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
  276. "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
  277. "det_db_box_thresh": key[1],
  278. "use_dilation": key[2],
  279. "det_db_unclip_ratio": key[3],
  280. }
  281. if key[0] is not None:
  282. additional_ocr_params["lang"] = key[0]
  283. # logger.info(f"additional_ocr_params: {additional_ocr_params}")
  284. onnx_model = PaddleOCR(**additional_ocr_params)
  285. if onnx_model is None:
  286. logger.error('model init failed')
  287. exit(1)
  288. else:
  289. return onnx_model
  290. except Exception as e:
  291. logger.exception(f'Error initializing model: {e}')
  292. exit(1)