# layout_utils.py
  1. """
  2. 布局处理工具模块
  3. 提供布局相关处理功能:
  4. - 重叠框检测与去重
  5. - 阅读顺序排序
  6. - IoU/重叠比例计算
  7. """
  8. from typing import Dict, List, Any
  9. from loguru import logger
  10. # 导入 MinerU 组件
  11. try:
  12. from mineru.utils.boxbase import calculate_iou, calculate_overlap_area_2_minbox_area_ratio
  13. MINERU_AVAILABLE = True
  14. except ImportError:
  15. MINERU_AVAILABLE = False
  16. calculate_iou = None
  17. calculate_overlap_area_2_minbox_area_ratio = None
  18. class LayoutUtils:
  19. """布局处理工具类"""
  20. @staticmethod
  21. def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
  22. """
  23. 计算两个 bbox 的 IoU(交并比)
  24. Args:
  25. bbox1: 第一个 bbox [x1, y1, x2, y2]
  26. bbox2: 第二个 bbox [x1, y1, x2, y2]
  27. Returns:
  28. IoU 值
  29. """
  30. if MINERU_AVAILABLE and calculate_iou is not None:
  31. return calculate_iou(bbox1, bbox2)
  32. # 备用实现
  33. x_left = max(bbox1[0], bbox2[0])
  34. y_top = max(bbox1[1], bbox2[1])
  35. x_right = min(bbox1[2], bbox2[2])
  36. y_bottom = min(bbox1[3], bbox2[3])
  37. if x_right < x_left or y_bottom < y_top:
  38. return 0.0
  39. intersection_area = (x_right - x_left) * (y_bottom - y_top)
  40. bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  41. bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  42. if bbox1_area == 0 or bbox2_area == 0:
  43. return 0.0
  44. return intersection_area / float(bbox1_area + bbox2_area - intersection_area)
  45. @staticmethod
  46. def calculate_overlap_ratio(bbox1: List[float], bbox2: List[float]) -> float:
  47. """
  48. 计算重叠面积占小框面积的比例
  49. Args:
  50. bbox1: 第一个 bbox [x1, y1, x2, y2]
  51. bbox2: 第二个 bbox [x1, y1, x2, y2]
  52. Returns:
  53. 重叠比例
  54. """
  55. if MINERU_AVAILABLE and calculate_overlap_area_2_minbox_area_ratio is not None:
  56. return calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
  57. # 备用实现
  58. x_left = max(bbox1[0], bbox2[0])
  59. y_top = max(bbox1[1], bbox2[1])
  60. x_right = min(bbox1[2], bbox2[2])
  61. y_bottom = min(bbox1[3], bbox2[3])
  62. if x_right < x_left or y_bottom < y_top:
  63. return 0.0
  64. intersection_area = (x_right - x_left) * (y_bottom - y_top)
  65. area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  66. area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  67. min_area = min(area1, area2)
  68. if min_area == 0:
  69. return 0.0
  70. return intersection_area / min_area
  71. @staticmethod
  72. def remove_overlapping_boxes(
  73. layout_results: List[Dict[str, Any]],
  74. iou_threshold: float = 0.8,
  75. overlap_ratio_threshold: float = 0.8
  76. ) -> List[Dict[str, Any]]:
  77. """
  78. 处理重叠的布局框(参考 MinerU 的去重策略)
  79. 策略:
  80. 1. 高 IoU 重叠:保留置信度高的框
  81. 2. 包含关系:小框被大框高度包含时,保留大框并扩展边界
  82. 3. 同类型优先合并
  83. Args:
  84. layout_results: Layout 检测结果列表
  85. iou_threshold: IoU 阈值,超过此值认为高度重叠
  86. overlap_ratio_threshold: 重叠面积占小框面积的比例阈值
  87. Returns:
  88. 去重后的布局结果列表
  89. """
  90. if not layout_results or len(layout_results) <= 1:
  91. return layout_results
  92. # 复制列表避免修改原数据
  93. results = [item.copy() for item in layout_results]
  94. need_remove = set()
  95. for i in range(len(results)):
  96. if i in need_remove:
  97. continue
  98. for j in range(i + 1, len(results)):
  99. if j in need_remove:
  100. continue
  101. bbox1 = results[i].get('bbox', [0, 0, 0, 0])
  102. bbox2 = results[j].get('bbox', [0, 0, 0, 0])
  103. if len(bbox1) < 4 or len(bbox2) < 4:
  104. continue
  105. # 计算 IoU
  106. iou = LayoutUtils.calculate_iou(bbox1, bbox2)
  107. if iou > iou_threshold:
  108. # 高度重叠,保留置信度高的
  109. score1 = results[i].get('confidence', results[i].get('score', 0))
  110. score2 = results[j].get('confidence', results[j].get('score', 0))
  111. if score1 >= score2:
  112. need_remove.add(j)
  113. else:
  114. need_remove.add(i)
  115. break # i 被移除,跳出内层循环
  116. else:
  117. # 检查包含关系
  118. overlap_ratio = LayoutUtils.calculate_overlap_ratio(bbox1, bbox2)
  119. if overlap_ratio > overlap_ratio_threshold:
  120. # 小框被大框高度包含
  121. area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  122. area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  123. if area1 <= area2:
  124. small_idx, large_idx = i, j
  125. else:
  126. small_idx, large_idx = j, i
  127. # 扩展大框的边界
  128. small_bbox = results[small_idx]['bbox']
  129. large_bbox = results[large_idx]['bbox']
  130. results[large_idx]['bbox'] = [
  131. min(small_bbox[0], large_bbox[0]),
  132. min(small_bbox[1], large_bbox[1]),
  133. max(small_bbox[2], large_bbox[2]),
  134. max(small_bbox[3], large_bbox[3])
  135. ]
  136. need_remove.add(small_idx)
  137. if small_idx == i:
  138. break # i 被移除,跳出内层循环
  139. # 返回去重后的结果
  140. return [results[i] for i in range(len(results)) if i not in need_remove]
  141. @staticmethod
  142. def sort_elements_by_reading_order(
  143. elements: List[Dict[str, Any]],
  144. y_tolerance: float = 15.0
  145. ) -> List[Dict[str, Any]]:
  146. """
  147. 根据阅读顺序对元素进行排序,并添加 reading_order 字段
  148. 排序规则:
  149. 1. 按Y坐标分行(考虑容差,Y坐标相近的元素视为同一行)
  150. 2. 同一行内按X坐标从左到右排序
  151. 3. 行与行之间按Y坐标从上到下排序
  152. Args:
  153. elements: 元素列表(坐标已转换为原始图片坐标系)
  154. y_tolerance: Y坐标容差,在此范围内的元素被视为同一行
  155. Returns:
  156. 排序后的元素列表,每个元素都添加了 reading_order 字段
  157. """
  158. if not elements:
  159. return elements
  160. # 为每个元素提取排序用的坐标
  161. elements_with_coords = []
  162. for elem in elements:
  163. bbox = elem.get('bbox', [0, 0, 0, 0])
  164. if len(bbox) >= 4:
  165. y_top = bbox[1] # 上边界
  166. x_left = bbox[0] # 左边界
  167. else:
  168. y_top = 0
  169. x_left = 0
  170. elements_with_coords.append((elem, y_top, x_left))
  171. # 先按Y坐标排序
  172. elements_with_coords.sort(key=lambda x: (x[1], x[2]))
  173. # 按Y坐标分行
  174. rows = []
  175. current_row = []
  176. current_row_y = None
  177. for elem, y_top, x_left in elements_with_coords:
  178. if current_row_y is None:
  179. # 第一个元素
  180. current_row.append((elem, x_left))
  181. current_row_y = y_top
  182. elif abs(y_top - current_row_y) <= y_tolerance:
  183. # 同一行
  184. current_row.append((elem, x_left))
  185. else:
  186. # 新的一行
  187. rows.append(current_row)
  188. current_row = [(elem, x_left)]
  189. current_row_y = y_top
  190. # 添加最后一行
  191. if current_row:
  192. rows.append(current_row)
  193. # 每行内按X坐标排序,然后展平
  194. sorted_elements = []
  195. reading_order = 0
  196. for row in rows:
  197. # 行内按X坐标排序
  198. row.sort(key=lambda x: x[1])
  199. for elem, _ in row:
  200. # 添加 reading_order 字段
  201. elem['reading_order'] = reading_order
  202. sorted_elements.append(elem)
  203. reading_order += 1
  204. logger.debug(f"📖 Elements sorted by reading order: {len(sorted_elements)} elements")
  205. return sorted_elements