utils.py 37 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, List, Tuple, Union
  15. import numpy as np
  16. from ..result_v2 import LayoutParsingBlock
  17. def calculate_projection_iou(
  18. bbox1: List[float], bbox2: List[float], direction: str = "horizontal"
  19. ) -> float:
  20. """
  21. Calculate the IoU of lines between two bounding boxes.
  22. Args:
  23. bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
  24. bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
  25. direction (str): direction of the projection, "horizontal" or "vertical".
  26. Returns:
  27. float: Line IoU. Returns 0 if there is no overlap.
  28. """
  29. start_index, end_index = 1, 3
  30. if direction == "horizontal":
  31. start_index, end_index = 0, 2
  32. intersection_start = max(bbox1[start_index], bbox2[start_index])
  33. intersection_end = min(bbox1[end_index], bbox2[end_index])
  34. overlap = intersection_end - intersection_start
  35. if overlap <= 0:
  36. return 0
  37. union_width = max(bbox1[end_index], bbox2[end_index]) - min(
  38. bbox1[start_index], bbox2[start_index]
  39. )
  40. return overlap / union_width if union_width > 0 else 0.0
  41. def calculate_iou(
  42. bbox1: Union[list, tuple],
  43. bbox2: Union[list, tuple],
  44. ) -> float:
  45. """
  46. Calculate the Intersection over Union (IoU) of two bounding boxes.
  47. Parameters:
  48. bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max]
  49. bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max]
  50. Returns:
  51. float: The IoU value between the two bounding boxes
  52. """
  53. x_min_inter = max(bbox1[0], bbox2[0])
  54. y_min_inter = max(bbox1[1], bbox2[1])
  55. x_max_inter = min(bbox1[2], bbox2[2])
  56. y_max_inter = min(bbox1[3], bbox2[3])
  57. inter_width = max(0, x_max_inter - x_min_inter)
  58. inter_height = max(0, y_max_inter - y_min_inter)
  59. inter_area = inter_width * inter_height
  60. bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  61. bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  62. union_area = bbox1_area + bbox2_area - inter_area
  63. if union_area == 0:
  64. return 0.0
  65. return inter_area / union_area
  66. def get_nearest_edge_distance(
  67. bbox1: List[int],
  68. bbox2: List[int],
  69. weight: List[float] = [1.0, 1.0, 1.0, 1.0],
  70. ) -> Tuple[float]:
  71. """
  72. Calculate the nearest edge distance between two bounding boxes, considering directional weights.
  73. Args:
  74. bbox1 (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  75. bbox2 (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  76. weight (list, optional): Directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
  77. Returns:
  78. float: The calculated minimum edge distance between the bounding boxes.
  79. """
  80. x1, y1, x2, y2 = bbox1
  81. x1_prime, y1_prime, x2_prime, y2_prime = bbox2
  82. min_x_distance, min_y_distance = 0, 0
  83. horizontal_iou = calculate_projection_iou(bbox1, bbox2, "horizontal")
  84. vertical_iou = calculate_projection_iou(bbox1, bbox2, "vertical")
  85. if horizontal_iou > 0 and vertical_iou > 0:
  86. return 0.0
  87. if horizontal_iou == 0:
  88. min_x_distance = min(abs(x1 - x2_prime), abs(x2 - x1_prime)) * (
  89. weight[0] if x2 < x1_prime else weight[1]
  90. )
  91. if vertical_iou == 0:
  92. min_y_distance = min(abs(y1 - y2_prime), abs(y2 - y1_prime)) * (
  93. weight[2] if y2 < y1_prime else weight[3]
  94. )
  95. return min_x_distance + min_y_distance
  96. def _projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
  97. """
  98. Generate a 1D projection histogram from bounding boxes along a specified axis.
  99. Args:
  100. boxes: A (N, 4) array of bounding boxes defined by [x_min, y_min, x_max, y_max].
  101. axis: Axis for projection; 0 for horizontal (x-axis), 1 for vertical (y-axis).
  102. Returns:
  103. A 1D numpy array representing the projection histogram based on bounding box intervals.
  104. """
  105. assert axis in [0, 1]
  106. max_length = np.max(boxes[:, axis::2])
  107. projection = np.zeros(max_length, dtype=int)
  108. # Increment projection histogram over the interval defined by each bounding box
  109. for start, end in boxes[:, axis::2]:
  110. projection[start:end] += 1
  111. return projection
  112. def _split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
  113. """
  114. Split the projection profile into segments based on specified thresholds.
  115. Args:
  116. arr_values: 1D array representing the projection profile.
  117. min_value: Minimum value threshold to consider a profile segment significant.
  118. min_gap: Minimum gap width to consider a separation between segments.
  119. Returns:
  120. A tuple of start and end indices for each segment that meets the criteria.
  121. """
  122. # Identify indices where the projection exceeds the minimum value
  123. significant_indices = np.where(arr_values > min_value)[0]
  124. if not len(significant_indices):
  125. return
  126. # Calculate gaps between significant indices
  127. index_diffs = significant_indices[1:] - significant_indices[:-1]
  128. gap_indices = np.where(index_diffs > min_gap)[0]
  129. # Determine start and end indices of segments
  130. segment_starts = np.insert(
  131. significant_indices[gap_indices + 1],
  132. 0,
  133. significant_indices[0],
  134. )
  135. segment_ends = np.append(
  136. significant_indices[gap_indices],
  137. significant_indices[-1] + 1,
  138. )
  139. return segment_starts, segment_ends
  140. def recursive_yx_cut(
  141. boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  142. ):
  143. """
  144. Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.
  145. Args:
  146. boxes: A (N, 4) array representing bounding boxes.
  147. indices: List of indices indicating the original position of boxes.
  148. res: List to store indices of the final segmented bounding boxes.
  149. min_gap (int): Minimum gap width to consider a separation between segments on the X-axis. Defaults to 1.
  150. Returns:
  151. None: This function modifies the `res` list in place.
  152. """
  153. assert len(boxes) == len(
  154. indices
  155. ), "The length of boxes and indices must be the same."
  156. # Sort by y_min for Y-axis projection
  157. y_sorted_indices = boxes[:, 1].argsort()
  158. y_sorted_boxes = boxes[y_sorted_indices]
  159. y_sorted_indices = np.array(indices)[y_sorted_indices]
  160. # Perform Y-axis projection
  161. y_projection = _projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
  162. y_intervals = _split_projection_profile(y_projection, 0, 1)
  163. if not y_intervals:
  164. return
  165. # Process each segment defined by Y-axis projection
  166. for y_start, y_end in zip(*y_intervals):
  167. # Select boxes within the current y interval
  168. y_interval_indices = (y_start <= y_sorted_boxes[:, 1]) & (
  169. y_sorted_boxes[:, 1] < y_end
  170. )
  171. y_boxes_chunk = y_sorted_boxes[y_interval_indices]
  172. y_indices_chunk = y_sorted_indices[y_interval_indices]
  173. # Sort by x_min for X-axis projection
  174. x_sorted_indices = y_boxes_chunk[:, 0].argsort()
  175. x_sorted_boxes_chunk = y_boxes_chunk[x_sorted_indices]
  176. x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]
  177. # Perform X-axis projection
  178. x_projection = _projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
  179. x_intervals = _split_projection_profile(x_projection, 0, min_gap)
  180. if not x_intervals:
  181. continue
  182. # If X-axis cannot be further segmented, add current indices to results
  183. if len(x_intervals[0]) == 1:
  184. res.extend(x_sorted_indices_chunk)
  185. continue
  186. # Recursively process each segment defined by X-axis projection
  187. for x_start, x_end in zip(*x_intervals):
  188. x_interval_indices = (x_start <= x_sorted_boxes_chunk[:, 0]) & (
  189. x_sorted_boxes_chunk[:, 0] < x_end
  190. )
  191. recursive_yx_cut(
  192. x_sorted_boxes_chunk[x_interval_indices],
  193. x_sorted_indices_chunk[x_interval_indices],
  194. res,
  195. )
  196. def recursive_xy_cut(
  197. boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  198. ):
  199. """
  200. Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.
  201. Args:
  202. boxes: A (N, 4) array representing bounding boxes with [x_min, y_min, x_max, y_max].
  203. indices: A list of indices representing the position of boxes in the original data.
  204. res: A list to store indices of bounding boxes that meet the criteria.
  205. min_gap (int): Minimum gap width to consider a separation between segments on the X-axis. Defaults to 1.
  206. Returns:
  207. None: This function modifies the `res` list in place.
  208. """
  209. # Ensure boxes and indices have the same length
  210. assert len(boxes) == len(
  211. indices
  212. ), "The length of boxes and indices must be the same."
  213. # Sort by x_min to prepare for X-axis projection
  214. x_sorted_indices = boxes[:, 0].argsort()
  215. x_sorted_boxes = boxes[x_sorted_indices]
  216. x_sorted_indices = np.array(indices)[x_sorted_indices]
  217. # Perform X-axis projection
  218. x_projection = _projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
  219. x_intervals = _split_projection_profile(x_projection, 0, 1)
  220. if not x_intervals:
  221. return
  222. # Process each segment defined by X-axis projection
  223. for x_start, x_end in zip(*x_intervals):
  224. # Select boxes within the current x interval
  225. x_interval_indices = (x_start <= x_sorted_boxes[:, 0]) & (
  226. x_sorted_boxes[:, 0] < x_end
  227. )
  228. x_boxes_chunk = x_sorted_boxes[x_interval_indices]
  229. x_indices_chunk = x_sorted_indices[x_interval_indices]
  230. # Sort selected boxes by y_min to prepare for Y-axis projection
  231. y_sorted_indices = x_boxes_chunk[:, 1].argsort()
  232. y_sorted_boxes_chunk = x_boxes_chunk[y_sorted_indices]
  233. y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]
  234. # Perform Y-axis projection
  235. y_projection = _projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
  236. y_intervals = _split_projection_profile(y_projection, 0, min_gap)
  237. if not y_intervals:
  238. continue
  239. # If Y-axis cannot be further segmented, add current indices to results
  240. if len(y_intervals[0]) == 1:
  241. res.extend(y_sorted_indices_chunk)
  242. continue
  243. # Recursively process each segment defined by Y-axis projection
  244. for y_start, y_end in zip(*y_intervals):
  245. y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
  246. y_sorted_boxes_chunk[:, 1] < y_end
  247. )
  248. recursive_xy_cut(
  249. y_sorted_boxes_chunk[y_interval_indices],
  250. y_sorted_indices_chunk[y_interval_indices],
  251. res,
  252. )
  253. def reference_insert(
  254. block: LayoutParsingBlock,
  255. sorted_blocks: List[LayoutParsingBlock],
  256. config: Dict,
  257. median_width: float = 0.0,
  258. ):
  259. """
  260. Insert reference block into sorted blocks based on the distance between the block and the nearest sorted block.
  261. Args:
  262. block: The block to insert into the sorted blocks.
  263. sorted_blocks: The sorted blocks where the new block will be inserted.
  264. config: Configuration dictionary containing parameters related to the layout parsing.
  265. median_width: Median width of the document. Defaults to 0.0.
  266. Returns:
  267. sorted_blocks: The updated sorted blocks after insertion.
  268. """
  269. min_distance = float("inf")
  270. nearest_sorted_block_index = 0
  271. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  272. if sorted_block.bbox[3] <= block.bbox[1]:
  273. distance = -(sorted_block.bbox[2] * 10 + sorted_block.bbox[3])
  274. if distance < min_distance:
  275. min_distance = distance
  276. nearest_sorted_block_index = sorted_block_idx
  277. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  278. return sorted_blocks
  279. def manhattan_insert(
  280. block: LayoutParsingBlock,
  281. sorted_blocks: List[LayoutParsingBlock],
  282. config: Dict,
  283. median_width: float = 0.0,
  284. ):
  285. """
  286. Insert a block into a sorted list of blocks based on the Manhattan distance between the block and the nearest sorted block.
  287. Args:
  288. block: The block to insert into the sorted blocks.
  289. sorted_blocks: The sorted blocks where the new block will be inserted.
  290. config: Configuration dictionary containing parameters related to the layout parsing.
  291. median_width: Median width of the document. Defaults to 0.0.
  292. Returns:
  293. sorted_blocks: The updated sorted blocks after insertion.
  294. """
  295. min_distance = float("inf")
  296. nearest_sorted_block_index = 0
  297. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  298. distance = _manhattan_distance(block.bbox, sorted_block.bbox)
  299. if distance < min_distance:
  300. min_distance = distance
  301. nearest_sorted_block_index = sorted_block_idx
  302. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  303. return sorted_blocks
  304. def weighted_distance_insert(
  305. block: LayoutParsingBlock,
  306. sorted_blocks: List[LayoutParsingBlock],
  307. config: Dict,
  308. median_width: float = 0.0,
  309. ):
  310. """
  311. Insert a block into a sorted list of blocks based on the weighted distance between the block and the nearest sorted block.
  312. Args:
  313. block: The block to insert into the sorted blocks.
  314. sorted_blocks: The sorted blocks where the new block will be inserted.
  315. config: Configuration dictionary containing parameters related to the layout parsing.
  316. median_width: Median width of the document. Defaults to 0.0.
  317. Returns:
  318. sorted_blocks: The updated sorted blocks after insertion.
  319. """
  320. doc_title_labels = config.get("doc_title_labels", [])
  321. paragraph_title_labels = config.get("paragraph_title_labels", [])
  322. vision_labels = config.get("vision_labels", [])
  323. xy_cut_block_labels = config.get("xy_cut_block_labels", [])
  324. tolerance_len = config.get("tolerance_len", 2)
  325. x1, y1, x2, y2 = block.bbox
  326. min_weighted_distance, min_edge_distance, min_up_edge_distance = (
  327. float("inf"),
  328. float("inf"),
  329. float("inf"),
  330. )
  331. nearest_sorted_block_index = 0
  332. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  333. x1_prime, y1_prime, x2_prime, y2_prime = sorted_block.bbox
  334. # Calculate edge distance
  335. weight = _get_weights(block.region_label, block.direction)
  336. edge_distance = get_nearest_edge_distance(block.bbox, sorted_block.bbox, weight)
  337. if block.label in doc_title_labels:
  338. disperse = max(1, median_width)
  339. tolerance_len = max(tolerance_len, disperse)
  340. if block.label == "abstract":
  341. tolerance_len *= 2
  342. edge_distance = max(0.1, edge_distance) * 10
  343. # Calculate up edge distances
  344. up_edge_distance = y1_prime
  345. left_edge_distance = x1_prime
  346. if (
  347. block.label in xy_cut_block_labels
  348. or block.label in doc_title_labels
  349. or block.label in paragraph_title_labels
  350. or block.label in vision_labels
  351. ) and y1 > y2_prime:
  352. up_edge_distance = -y2_prime
  353. left_edge_distance = -x2_prime
  354. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  355. up_edge_distance = min_up_edge_distance
  356. # Calculate weighted distance
  357. weighted_distance = (
  358. +edge_distance * config.get("edge_weight", 10**4)
  359. + up_edge_distance * config.get("up_edge_weight", 1)
  360. + left_edge_distance * config.get("left_edge_weight", 0.0001)
  361. )
  362. min_edge_distance = min(edge_distance, min_edge_distance)
  363. min_up_edge_distance = min(up_edge_distance, min_up_edge_distance)
  364. if weighted_distance < min_weighted_distance:
  365. nearest_sorted_block_index = sorted_block_idx
  366. min_weighted_distance = weighted_distance
  367. if y1 > y1_prime or (y1 == y1_prime and x1 > x1_prime):
  368. nearest_sorted_block_index = sorted_block_idx + 1
  369. sorted_blocks.insert(nearest_sorted_block_index, block)
  370. return sorted_blocks
  371. def insert_child_blocks(
  372. block: LayoutParsingBlock,
  373. block_idx: int,
  374. sorted_blocks: List[LayoutParsingBlock],
  375. ) -> List[LayoutParsingBlock]:
  376. """
  377. Insert child blocks of a block into the sorted blocks list.
  378. Args:
  379. block: The parent block whose child blocks need to be inserted.
  380. block_idx: Index at which the parent block exists in the sorted blocks list.
  381. sorted_blocks: Sorted blocks list where the child blocks are to be inserted.
  382. Returns:
  383. sorted_blocks: Updated sorted blocks list after inserting child blocks.
  384. """
  385. if block.child_blocks:
  386. sub_blocks = block.get_child_blocks()
  387. sub_blocks.append(block)
  388. sub_blocks = sort_child_blocks(sub_blocks, block.direction)
  389. sorted_blocks[block_idx] = sub_blocks[0]
  390. for block in sub_blocks[1:]:
  391. block_idx += 1
  392. sorted_blocks.insert(block_idx, block)
  393. return sorted_blocks
  394. def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock]:
  395. """
  396. Sort child blocks based on their bounding box coordinates.
  397. Args:
  398. blocks: A list of LayoutParsingBlock objects representing the child blocks.
  399. direction: Orientation of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'.
  400. Returns:
  401. sorted_blocks: A sorted list of LayoutParsingBlock objects.
  402. """
  403. if direction == "horizontal":
  404. # from top to bottom
  405. blocks.sort(
  406. key=lambda x: (
  407. x.bbox[1], # y_min
  408. x.bbox[0], # x_min
  409. x.bbox[1] ** 2 + x.bbox[0] ** 2, # distance with (0,0)
  410. ),
  411. reverse=False,
  412. )
  413. else:
  414. # from right to left
  415. blocks.sort(
  416. key=lambda x: (
  417. x.bbox[0], # x_min
  418. x.bbox[1], # y_min
  419. x.bbox[1] ** 2 + x.bbox[0] ** 2, # distance with (0,0)
  420. ),
  421. reverse=True,
  422. )
  423. return blocks
  424. def _get_weights(label, dircetion="horizontal"):
  425. """Define weights based on the label and orientation."""
  426. if label == "doc_title":
  427. return (
  428. [1, 0.1, 0.1, 1] if dircetion == "horizontal" else [0.2, 0.1, 1, 1]
  429. ) # left-down , right-left
  430. elif label in [
  431. "paragraph_title",
  432. "table_title",
  433. "abstract",
  434. "image",
  435. "seal",
  436. "chart",
  437. "figure",
  438. ]:
  439. return [1, 1, 0.1, 1] # down
  440. else:
  441. return [1, 1, 1, 0.1] # up
  442. def _manhattan_distance(
  443. point1: Tuple[float, float],
  444. point2: Tuple[float, float],
  445. weight_x: float = 1.0,
  446. weight_y: float = 1.0,
  447. ) -> float:
  448. """
  449. Calculate the weighted Manhattan distance between two points.
  450. Args:
  451. point1 (Tuple[float, float]): The first point as (x, y).
  452. point2 (Tuple[float, float]): The second point as (x, y).
  453. weight_x (float): The weight for the x-axis distance. Default is 1.0.
  454. weight_y (float): The weight for the y-axis distance. Default is 1.0.
  455. Returns:
  456. float: The weighted Manhattan distance between the two points.
  457. """
  458. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
  459. def sort_blocks(blocks, median_width=None, reverse=False):
  460. """
  461. Sort blocks based on their y_min, x_min and distance with (0,0).
  462. Args:
  463. blocks (list): list of blocks to be sorted.
  464. median_width (int): the median width of the text blocks.
  465. reverse (bool, optional): whether to sort in descending order. Default is False.
  466. Returns:
  467. list: a list of sorted blocks.
  468. """
  469. if median_width is None:
  470. median_width = 1
  471. blocks.sort(
  472. key=lambda x: (
  473. x.bbox[1] // 10, # y_min
  474. x.bbox[0] // median_width, # x_min
  475. x.bbox[1] ** 2 + x.bbox[0] ** 2, # distance with (0,0)
  476. ),
  477. reverse=reverse,
  478. )
  479. return blocks
  480. def get_cut_blocks(
  481. blocks, cut_direction, cut_coordinates, overall_region_box, mask_labels=[]
  482. ):
  483. """
  484. Cut blocks based on the given cut direction and coordinates.
  485. Args:
  486. blocks (list): list of blocks to be cut.
  487. cut_direction (str): cut direction, either "horizontal" or "vertical".
  488. cut_coordinates (list): list of cut coordinates.
  489. overall_region_box (list): the overall region box that contains all blocks.
  490. Returns:
  491. list: a list of tuples containing the cutted blocks and their corresponding mean width。
  492. """
  493. cuted_list = []
  494. # filter out mask blocks,including header, footer, unordered and child_blocks
  495. # 0: horizontal, 1: vertical
  496. cut_aixis = 0 if cut_direction == "horizontal" else 1
  497. blocks.sort(key=lambda x: x.bbox[cut_aixis + 2])
  498. overall_max_axis_coordinate = overall_region_box[cut_aixis + 2]
  499. cut_coordinates.append(overall_max_axis_coordinate)
  500. cut_coordinates = list(set(cut_coordinates))
  501. cut_coordinates.sort()
  502. cut_idx = 0
  503. for cut_coordinate in cut_coordinates:
  504. group_blocks = []
  505. block_idx = cut_idx
  506. while block_idx < len(blocks):
  507. block = blocks[block_idx]
  508. if block.bbox[cut_aixis + 2] > cut_coordinate:
  509. break
  510. elif block.region_label not in mask_labels:
  511. group_blocks.append(block)
  512. block_idx += 1
  513. cut_idx = block_idx
  514. if group_blocks:
  515. cuted_list.append(group_blocks)
  516. return cuted_list
  517. def split_sub_region_blocks(
  518. blocks: List[LayoutParsingBlock],
  519. config: Dict,
  520. ) -> List:
  521. """
  522. Split blocks into sub regions based on the all layout region bbox.
  523. Args:
  524. blocks (List[LayoutParsingBlock]): A list of blocks.
  525. config (Dict): Configuration dictionary.
  526. Returns:
  527. List: A list of lists of blocks, each representing a sub region.
  528. """
  529. region_bbox = config.get("all_layout_region_box", None)
  530. x1, y1, x2, y2 = region_bbox
  531. region_width = x2 - x1
  532. region_height = y2 - y1
  533. if region_width < region_height:
  534. return [(blocks, region_bbox)]
  535. all_boxes = np.array([block.bbox for block in blocks])
  536. discontinuous = calculate_discontinuous_projection(all_boxes, direction="vertical")
  537. if len(discontinuous) > 1:
  538. cut_coordinates = []
  539. region_boxes = []
  540. current_interval = discontinuous[0]
  541. for x1, x2 in discontinuous[1:]:
  542. if x1 - current_interval[1] > 100:
  543. cut_coordinates.extend([x1, x2])
  544. region_boxes.append([x1, y1, x2, y2])
  545. current_interval = [x1, x2]
  546. region_blocks = get_cut_blocks(blocks, "vertical", cut_coordinates, region_bbox)
  547. return [region_info for region_info in zip(region_blocks, region_boxes)]
  548. else:
  549. return [(blocks, region_bbox)]
  550. def get_adjacent_blocks_by_direction(
  551. blocks: List[LayoutParsingBlock],
  552. block_idx: int,
  553. ref_block_idxes: List[int],
  554. iou_threshold,
  555. ) -> List:
  556. """
  557. Get the adjacent blocks with the same direction as the current block.
  558. Args:
  559. block (LayoutParsingBlock): The current block.
  560. blocks (List[LayoutParsingBlock]): A list of all blocks.
  561. ref_block_idxes (List[int]): A list of indices of reference blocks.
  562. iou_threshold (float): The IOU threshold to determine if two blocks are considered adjacent.
  563. Returns:
  564. Int: The index of the previous block with same direction.
  565. Int: The index of the following block with same direction.
  566. """
  567. min_prev_block_distance = float("inf")
  568. prev_block_index = None
  569. min_post_block_distance = float("inf")
  570. post_block_index = None
  571. block = blocks[block_idx]
  572. child_labels = [
  573. "vision_footnote",
  574. "sub_paragraph_title",
  575. "doc_title_text",
  576. "vision_title",
  577. ]
  578. # find the nearest text block with same direction to the current block
  579. for ref_block_idx in ref_block_idxes:
  580. ref_block = blocks[ref_block_idx]
  581. ref_block_direction = ref_block.direction
  582. if ref_block.region_label in child_labels:
  583. continue
  584. match_block_iou = calculate_projection_iou(
  585. block.bbox,
  586. ref_block.bbox,
  587. ref_block_direction,
  588. )
  589. child_match_distance_tolerance_len = block.short_side_length / 10
  590. if block.region_label == "vision":
  591. if ref_block.num_of_lines == 1:
  592. gap_tolerance_len = ref_block.short_side_length * 2
  593. else:
  594. gap_tolerance_len = block.short_side_length / 10
  595. else:
  596. gap_tolerance_len = block.short_side_length * 2
  597. if match_block_iou >= iou_threshold:
  598. prev_distance = (
  599. block.secondary_direction_start_coordinate
  600. - ref_block.secondary_direction_end_coordinate
  601. + child_match_distance_tolerance_len
  602. ) // 5 + ref_block.start_coordinate / 5000
  603. next_distance = (
  604. ref_block.secondary_direction_start_coordinate
  605. - block.secondary_direction_end_coordinate
  606. + child_match_distance_tolerance_len
  607. ) // 5 + ref_block.start_coordinate / 5000
  608. if (
  609. ref_block.secondary_direction_end_coordinate
  610. <= block.secondary_direction_start_coordinate
  611. + child_match_distance_tolerance_len
  612. and prev_distance < min_prev_block_distance
  613. ):
  614. min_prev_block_distance = prev_distance
  615. if (
  616. block.secondary_direction_start_coordinate
  617. - ref_block.secondary_direction_end_coordinate
  618. < gap_tolerance_len
  619. ):
  620. prev_block_index = ref_block_idx
  621. elif (
  622. ref_block.secondary_direction_start_coordinate
  623. > block.secondary_direction_end_coordinate
  624. - child_match_distance_tolerance_len
  625. and next_distance < min_post_block_distance
  626. ):
  627. min_post_block_distance = next_distance
  628. if (
  629. ref_block.secondary_direction_start_coordinate
  630. - block.secondary_direction_end_coordinate
  631. < gap_tolerance_len
  632. ):
  633. post_block_index = ref_block_idx
  634. diff_dist = abs(min_prev_block_distance - min_post_block_distance)
  635. # if the difference in distance is too large, only consider the nearest one
  636. if diff_dist * 5 > block.short_side_length:
  637. if min_prev_block_distance < min_post_block_distance:
  638. post_block_index = None
  639. else:
  640. prev_block_index = None
  641. return prev_block_index, post_block_index
  642. def update_doc_title_child_blocks(
  643. blocks: List[LayoutParsingBlock],
  644. block: LayoutParsingBlock,
  645. prev_idx: int,
  646. post_idx: int,
  647. config: dict,
  648. ) -> None:
  649. """
  650. Update the child blocks of a document title block.
  651. The child blocks need to meet the following conditions:
  652. 1. They must be adjacent
  653. 2. They must have the same direction as the parent block.
  654. 3. Their short side length should be less than 80% of the parent's short side length.
  655. 4. Their long side length should be less than 150% of the parent's long side length.
  656. 5. The child block must be text block.
  657. Args:
  658. blocks (List[LayoutParsingBlock]): overall blocks.
  659. block (LayoutParsingBlock): document title block.
  660. prev_idx (int): previous block index, None if not exist.
  661. post_idx (int): post block index, None if not exist.
  662. config (dict): configurations.
  663. Returns:
  664. None
  665. """
  666. for idx in [prev_idx, post_idx]:
  667. if idx is None:
  668. continue
  669. ref_block = blocks[idx]
  670. with_seem_direction = ref_block.direction == block.direction
  671. short_side_length_condition = (
  672. ref_block.short_side_length < block.short_side_length * 0.8
  673. )
  674. long_side_length_condition = (
  675. ref_block.long_side_length < block.long_side_length
  676. or ref_block.long_side_length > 1.5 * block.long_side_length
  677. )
  678. if (
  679. with_seem_direction
  680. and short_side_length_condition
  681. and long_side_length_condition
  682. and ref_block.num_of_lines < 3
  683. ):
  684. ref_block.region_label = "doc_title_text"
  685. block.append_child_block(ref_block)
  686. config["text_block_idxes"].remove(idx)
  687. def update_paragraph_title_child_blocks(
  688. blocks: List[LayoutParsingBlock],
  689. block: LayoutParsingBlock,
  690. prev_idx: int,
  691. post_idx: int,
  692. config: dict,
  693. ) -> None:
  694. """
  695. Update the child blocks of a paragraph title block.
  696. The child blocks need to meet the following conditions:
  697. 1. They must be adjacent
  698. 2. They must have the same direction as the parent block.
  699. 3. The child block must be paragraph title block.
  700. Args:
  701. blocks (List[LayoutParsingBlock]): overall blocks.
  702. block (LayoutParsingBlock): document title block.
  703. prev_idx (int): previous block index, None if not exist.
  704. post_idx (int): post block index, None if not exist.
  705. config (dict): configurations.
  706. Returns:
  707. None
  708. """
  709. paragraph_title_labels = config.get("paragraph_title_labels", [])
  710. for idx in [prev_idx, post_idx]:
  711. if idx is None:
  712. continue
  713. ref_block = blocks[idx]
  714. with_seem_direction = ref_block.direction == block.direction
  715. if with_seem_direction and ref_block.label in paragraph_title_labels:
  716. ref_block.region_label = "sub_paragraph_title"
  717. block.append_child_block(ref_block)
  718. config["paragraph_title_block_idxes"].remove(idx)
  719. def update_vision_child_blocks(
  720. blocks: List[LayoutParsingBlock],
  721. block: LayoutParsingBlock,
  722. ref_block_idxes: List[int],
  723. prev_idx: int,
  724. post_idx: int,
  725. config: dict,
  726. ) -> None:
  727. """
  728. Update the child blocks of a paragraph title block.
  729. The child blocks need to meet the following conditions:
  730. - For Both:
  731. 1. They must be adjacent
  732. 2. The child block must be vision_title or text block.
  733. - For vision_title:
  734. 1. The distance between the child block and the parent block should be less than 1/2 of the parent's height.
  735. - For text block:
  736. 1. The distance between the child block and the parent block should be less than 15.
  737. 2. The child short_side_length should be less than the parent's short side length.
  738. 3. The child long_side_length should be less than 50% of the parent's long side length.
  739. 4. The difference between their centers is very small.
  740. Args:
  741. blocks (List[LayoutParsingBlock]): overall blocks.
  742. block (LayoutParsingBlock): document title block.
  743. ref_block_idxes (List[int]): A list of indices of reference blocks.
  744. prev_idx (int): previous block index, None if not exist.
  745. post_idx (int): post block index, None if not exist.
  746. config (dict): configurations.
  747. Returns:
  748. None
  749. """
  750. vision_title_labels = config.get("vision_title_labels", [])
  751. text_labels = config.get("text_labels", [])
  752. for idx in [prev_idx, post_idx]:
  753. if idx is None:
  754. continue
  755. ref_block = blocks[idx]
  756. nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox)
  757. block_center = block.get_centroid()
  758. ref_block_center = ref_block.get_centroid()
  759. if ref_block.label in vision_title_labels and nearest_edge_distance <= min(
  760. block.height * 0.5, ref_block.height * 2
  761. ):
  762. ref_block.region_label = "vision_title"
  763. block.append_child_block(ref_block)
  764. config["vision_title_block_idxes"].remove(idx)
  765. elif (
  766. nearest_edge_distance <= 15
  767. and ref_block.short_side_length < block.short_side_length
  768. and ref_block.long_side_length < 0.5 * block.long_side_length
  769. and ref_block.direction == block.direction
  770. and (
  771. abs(block_center[0] - ref_block_center[0]) < 10
  772. or (
  773. block.bbox[0] - ref_block.bbox[0] < 10
  774. and ref_block.num_of_lines == 1
  775. )
  776. or (
  777. block.bbox[2] - ref_block.bbox[2] < 10
  778. and ref_block.num_of_lines == 1
  779. )
  780. )
  781. ):
  782. has_vision_footnote = False
  783. if len(block.child_blocks) > 0:
  784. for child_block in block.child_blocks:
  785. if child_block.label in text_labels:
  786. has_vision_footnote = True
  787. if not has_vision_footnote:
  788. ref_block.region_label = "vision_footnote"
  789. block.append_child_block(ref_block)
  790. config["text_block_idxes"].remove(idx)
  791. def calculate_discontinuous_projection(boxes, direction="horizontal") -> List:
  792. """
  793. Calculate the discontinuous projection of boxes along the specified direction.
  794. Args:
  795. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  796. direction (str): Direction along which to perform the projection ('horizontal' or 'vertical').
  797. Returns:
  798. list: List of tuples representing the merged intervals.
  799. """
  800. if direction == "horizontal":
  801. intervals = boxes[:, [0, 2]]
  802. elif direction == "vertical":
  803. intervals = boxes[:, [1, 3]]
  804. else:
  805. raise ValueError("Direction must be 'horizontal' or 'vertical'")
  806. intervals = intervals[np.argsort(intervals[:, 0])]
  807. merged_intervals = []
  808. current_start, current_end = intervals[0]
  809. for start, end in intervals[1:]:
  810. if start <= current_end:
  811. current_end = max(current_end, end)
  812. else:
  813. merged_intervals.append((current_start, current_end))
  814. current_start, current_end = start, end
  815. merged_intervals.append((current_start, current_end))
  816. return merged_intervals
  817. def shrink_overlapping_boxes(
  818. boxes, direction="horizontal", min_threshold=0, max_threshold=0.1
  819. ) -> List:
  820. """
  821. Shrink overlapping boxes along the specified direction.
  822. Args:
  823. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  824. direction (str): Direction along which to perform the shrinking ('horizontal' or 'vertical').
  825. min_threshold (float): Minimum threshold for shrinking. Default is 0.
  826. max_threshold (float): Maximum threshold for shrinking. Default is 0.2.
  827. Returns:
  828. list: List of tuples representing the merged intervals.
  829. """
  830. current_block = boxes[0]
  831. for block in boxes[1:]:
  832. x1, y1, x2, y2 = current_block.bbox
  833. x1_prime, y1_prime, x2_prime, y2_prime = block.bbox
  834. cut_iou = calculate_projection_iou(
  835. current_block.bbox, block.bbox, direction=direction
  836. )
  837. match_iou = calculate_projection_iou(
  838. current_block.bbox,
  839. block.bbox,
  840. direction="horizontal" if direction == "vertical" else "vertical",
  841. )
  842. if direction == "vertical":
  843. if (
  844. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  845. or y2 == y1_prime
  846. or abs(y2 - y1_prime) <= 3
  847. ):
  848. overlap_y_min = max(y1, y1_prime)
  849. overlap_y_max = min(y2, y2_prime)
  850. split_y = int((overlap_y_min + overlap_y_max) / 2)
  851. overlap_y_min = split_y - 1
  852. overlap_y_max = split_y + 1
  853. current_block.bbox = [x1, y1, x2, overlap_y_min]
  854. block.bbox = [x1_prime, overlap_y_max, x2_prime, y2_prime]
  855. else:
  856. if (
  857. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  858. or x2 == x1_prime
  859. or abs(x2 - x1_prime) <= 3
  860. ):
  861. overlap_x_min = max(x1, x1_prime)
  862. overlap_x_max = min(x2, x2_prime)
  863. split_x = int((overlap_x_min + overlap_x_max) / 2)
  864. overlap_x_min = split_x - 1
  865. overlap_x_max = split_x + 1
  866. current_block.bbox = [x1, y1, overlap_x_min, y2]
  867. block.bbox = [overlap_x_max, y1_prime, x2_prime, y2_prime]
  868. current_block = block
  869. return boxes