ocr_utils.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import math
  2. import numpy as np
  3. from loguru import logger
  4. from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
  5. from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
  6. def bbox_to_points(bbox):
  7. """ 将bbox格式转换为四个顶点的数组 """
  8. x0, y0, x1, y1 = bbox
  9. return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
  10. def points_to_bbox(points):
  11. """ 将四个顶点的数组转换为bbox格式 """
  12. x0, y0 = points[0]
  13. x1, _ = points[1]
  14. _, y1 = points[2]
  15. return [x0, y0, x1, y1]
  16. def merge_intervals(intervals):
  17. # Sort the intervals based on the start value
  18. intervals.sort(key=lambda x: x[0])
  19. merged = []
  20. for interval in intervals:
  21. # If the list of merged intervals is empty or if the current
  22. # interval does not overlap with the previous, simply append it.
  23. if not merged or merged[-1][1] < interval[0]:
  24. merged.append(interval)
  25. else:
  26. # Otherwise, there is overlap, so we merge the current and previous intervals.
  27. merged[-1][1] = max(merged[-1][1], interval[1])
  28. return merged
  29. def remove_intervals(original, masks):
  30. # Merge all mask intervals
  31. merged_masks = merge_intervals(masks)
  32. result = []
  33. original_start, original_end = original
  34. for mask in merged_masks:
  35. mask_start, mask_end = mask
  36. # If the mask starts after the original range, ignore it
  37. if mask_start > original_end:
  38. continue
  39. # If the mask ends before the original range starts, ignore it
  40. if mask_end < original_start:
  41. continue
  42. # Remove the masked part from the original range
  43. if original_start < mask_start:
  44. result.append([original_start, mask_start - 1])
  45. original_start = max(mask_end + 1, original_start)
  46. # Add the remaining part of the original range, if any
  47. if original_start <= original_end:
  48. result.append([original_start, original_end])
  49. return result
  50. def update_det_boxes(dt_boxes, mfd_res):
  51. new_dt_boxes = []
  52. for text_box in dt_boxes:
  53. text_bbox = points_to_bbox(text_box)
  54. masks_list = []
  55. for mf_box in mfd_res:
  56. mf_bbox = mf_box['bbox']
  57. if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
  58. masks_list.append([mf_bbox[0], mf_bbox[2]])
  59. text_x_range = [text_bbox[0], text_bbox[2]]
  60. text_remove_mask_range = remove_intervals(text_x_range, masks_list)
  61. temp_dt_box = []
  62. for text_remove_mask in text_remove_mask_range:
  63. temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
  64. if len(temp_dt_box) > 0:
  65. new_dt_boxes.extend(temp_dt_box)
  66. return new_dt_boxes
  67. def merge_overlapping_spans(spans):
  68. """
  69. Merges overlapping spans on the same line.
  70. :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
  71. :return: A list of merged spans
  72. """
  73. # Return an empty list if the input spans list is empty
  74. if not spans:
  75. return []
  76. # Sort spans by their starting x-coordinate
  77. spans.sort(key=lambda x: x[0])
  78. # Initialize the list of merged spans
  79. merged = []
  80. for span in spans:
  81. # Unpack span coordinates
  82. x1, y1, x2, y2 = span
  83. # If the merged list is empty or there's no horizontal overlap, add the span directly
  84. if not merged or merged[-1][2] < x1:
  85. merged.append(span)
  86. else:
  87. # If there is horizontal overlap, merge the current span with the previous one
  88. last_span = merged.pop()
  89. # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
  90. x1 = min(last_span[0], x1)
  91. y1 = min(last_span[1], y1)
  92. x2 = max(last_span[2], x2)
  93. y2 = max(last_span[3], y2)
  94. # Add the merged span back to the list
  95. merged.append((x1, y1, x2, y2))
  96. # Return the list of merged spans
  97. return merged
  98. def merge_det_boxes(dt_boxes):
  99. """
  100. Merge detection boxes.
  101. This function takes a list of detected bounding boxes, each represented by four corner points.
  102. The goal is to merge these bounding boxes into larger text regions.
  103. Parameters:
  104. dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
  105. Returns:
  106. list: A list containing the merged text regions, where each region is represented by four corner points.
  107. """
  108. # Convert the detection boxes into a dictionary format with bounding boxes and type
  109. dt_boxes_dict_list = []
  110. angle_boxes_list = []
  111. for text_box in dt_boxes:
  112. text_bbox = points_to_bbox(text_box)
  113. if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]:
  114. angle_boxes_list.append(text_box)
  115. continue
  116. text_box_dict = {
  117. 'bbox': text_bbox,
  118. 'type': 'text',
  119. }
  120. dt_boxes_dict_list.append(text_box_dict)
  121. # Merge adjacent text regions into lines
  122. lines = merge_spans_to_line(dt_boxes_dict_list)
  123. # Initialize a new list for storing the merged text regions
  124. new_dt_boxes = []
  125. for line in lines:
  126. line_bbox_list = []
  127. for span in line:
  128. line_bbox_list.append(span['bbox'])
  129. # Merge overlapping text regions within the same line
  130. merged_spans = merge_overlapping_spans(line_bbox_list)
  131. # Convert the merged text regions back to point format and add them to the new detection box list
  132. for span in merged_spans:
  133. new_dt_boxes.append(bbox_to_points(span))
  134. new_dt_boxes.extend(angle_boxes_list)
  135. return new_dt_boxes
  136. def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
  137. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  138. # Adjust the coordinates of the formula area
  139. adjusted_mfdetrec_res = []
  140. for mf_res in single_page_mfdetrec_res:
  141. mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
  142. # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
  143. x0 = mf_xmin - xmin + paste_x
  144. y0 = mf_ymin - ymin + paste_y
  145. x1 = mf_xmax - xmin + paste_x
  146. y1 = mf_ymax - ymin + paste_y
  147. # Filter formula blocks outside the graph
  148. if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
  149. continue
  150. else:
  151. adjusted_mfdetrec_res.append({
  152. "bbox": [x0, y0, x1, y1],
  153. })
  154. return adjusted_mfdetrec_res
  155. def get_ocr_result_list(ocr_res, useful_list):
  156. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  157. ocr_result_list = []
  158. for box_ocr_res in ocr_res:
  159. p1, p2, p3, p4 = box_ocr_res[0]
  160. text, score = box_ocr_res[1]
  161. average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
  162. if average_angle_degrees > 0.5:
  163. # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
  164. # 与x轴的夹角超过0.5度,对边界做一下矫正
  165. # 计算几何中心
  166. x_center = sum(point[0] for point in box_ocr_res[0]) / 4
  167. y_center = sum(point[1] for point in box_ocr_res[0]) / 4
  168. new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
  169. new_width = p3[0] - p1[0]
  170. p1 = [x_center - new_width / 2, y_center - new_height / 2]
  171. p2 = [x_center + new_width / 2, y_center - new_height / 2]
  172. p3 = [x_center + new_width / 2, y_center + new_height / 2]
  173. p4 = [x_center - new_width / 2, y_center + new_height / 2]
  174. # Convert the coordinates back to the original coordinate system
  175. p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
  176. p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
  177. p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
  178. p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
  179. ocr_result_list.append({
  180. 'category_id': 15,
  181. 'poly': p1 + p2 + p3 + p4,
  182. 'score': float(round(score, 2)),
  183. 'text': text,
  184. })
  185. return ocr_result_list
  186. def calculate_angle_degrees(poly):
  187. # 定义对角线的顶点
  188. diagonal1 = (poly[0], poly[2])
  189. diagonal2 = (poly[1], poly[3])
  190. # 计算对角线的斜率
  191. def slope(p1, p2):
  192. return (p2[1] - p1[1]) / (p2[0] - p1[0]) if p2[0] != p1[0] else float('inf')
  193. slope1 = slope(diagonal1[0], diagonal1[1])
  194. slope2 = slope(diagonal2[0], diagonal2[1])
  195. # 计算对角线与x轴的夹角(以弧度为单位)
  196. angle1_radians = math.atan(slope1)
  197. angle2_radians = math.atan(slope2)
  198. # 将弧度转换为角度
  199. angle1_degrees = math.degrees(angle1_radians)
  200. angle2_degrees = math.degrees(angle2_radians)
  201. # 取两条对角线与x轴夹角的平均值
  202. average_angle_degrees = abs((angle1_degrees + angle2_degrees) / 2)
  203. # logger.info(f"average_angle_degrees: {average_angle_degrees}")
  204. return average_angle_degrees