utils.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from typing import List, Optional, Tuple

import numpy as np

from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion
from ..setting import BLOCK_LABEL_MAP, XYCUT_SETTINGS
from ..utils import calculate_projection_overlap_ratio
  19. def get_nearest_edge_distance(
  20. bbox1: List[int],
  21. bbox2: List[int],
  22. weight: List[float] = [1.0, 1.0, 1.0, 1.0],
  23. ) -> Tuple[float]:
  24. """
  25. Calculate the nearest edge distance between two bounding boxes, considering directional weights.
  26. Args:
  27. bbox1 (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  28. bbox2 (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  29. weight (list, optional): directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
  30. Returns:
  31. float: The calculated minimum edge distance between the bounding boxes.
  32. """
  33. x1, y1, x2, y2 = bbox1
  34. x1_prime, y1_prime, x2_prime, y2_prime = bbox2
  35. min_x_distance, min_y_distance = 0, 0
  36. horizontal_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "horizontal")
  37. vertical_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "vertical")
  38. if horizontal_iou > 0 and vertical_iou > 0:
  39. return 0.0
  40. if horizontal_iou == 0:
  41. min_x_distance = min(abs(x1 - x2_prime), abs(x2 - x1_prime)) * (
  42. weight[0] if x2 < x1_prime else weight[1]
  43. )
  44. if vertical_iou == 0:
  45. min_y_distance = min(abs(y1 - y2_prime), abs(y2 - y1_prime)) * (
  46. weight[2] if y2 < y1_prime else weight[3]
  47. )
  48. return min_x_distance + min_y_distance
  49. def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
  50. """
  51. Generate a 1D projection histogram from bounding boxes along a specified axis.
  52. Args:
  53. boxes: A (N, 4) array of bounding boxes defined by [x_min, y_min, x_max, y_max].
  54. axis: Axis for projection; 0 for horizontal (x-axis), 1 for vertical (y-axis).
  55. Returns:
  56. A 1D numpy array representing the projection histogram based on bounding box intervals.
  57. """
  58. assert axis in [0, 1]
  59. if np.min(boxes[:, axis::2]) < 0:
  60. max_length = abs(np.min(boxes[:, axis::2]))
  61. else:
  62. max_length = np.max(boxes[:, axis::2])
  63. projection = np.zeros(max_length, dtype=int)
  64. # Increment projection histogram over the interval defined by each bounding box
  65. for start, end in boxes[:, axis::2]:
  66. start = abs(start)
  67. end = abs(end)
  68. projection[start:end] += 1
  69. return projection
  70. def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
  71. """
  72. Split the projection profile into segments based on specified thresholds.
  73. Args:
  74. arr_values: 1D array representing the projection profile.
  75. min_value: Minimum value threshold to consider a profile segment significant.
  76. min_gap: Minimum gap width to consider a separation between segments.
  77. Returns:
  78. A tuple of start and end indices for each segment that meets the criteria.
  79. """
  80. # Identify indices where the projection exceeds the minimum value
  81. significant_indices = np.where(arr_values > min_value)[0]
  82. if not len(significant_indices):
  83. return
  84. # Calculate gaps between significant indices
  85. index_diffs = significant_indices[1:] - significant_indices[:-1]
  86. gap_indices = np.where(index_diffs > min_gap)[0]
  87. # Determine start and end indices of segments
  88. segment_starts = np.insert(
  89. significant_indices[gap_indices + 1],
  90. 0,
  91. significant_indices[0],
  92. )
  93. segment_ends = np.append(
  94. significant_indices[gap_indices],
  95. significant_indices[-1] + 1,
  96. )
  97. return segment_starts, segment_ends
def recursive_yx_cut(
    boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
):
    """
    Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.

    Boxes are grouped into bands by splitting the Y-axis projection. A box belongs
    to the band that contains its y_min. Within each band the X-axis projection is
    split. If it yields a single segment, the band's original indices are appended
    to ``res`` in x_min order. Otherwise each X segment is processed recursively.

    Args:
        boxes: A (N, 4) array representing bounding boxes.
        indices: List of indices indicating the original position of boxes.
        res: List to store indices of the final segmented bounding boxes.
        min_gap (int): Minimum gap width to consider a separation between segments on the X-axis. Defaults to 1.

    Returns:
        None: This function modifies the `res` list in place.
    """
    assert len(boxes) == len(
        indices
    ), "The length of boxes and indices must be the same."

    # Sort by y_min for Y-axis projection
    y_sorted_indices = boxes[:, 1].argsort()
    y_sorted_boxes = boxes[y_sorted_indices]
    # Map the sort permutation onto the caller-supplied original indices.
    y_sorted_indices = np.array(indices)[y_sorted_indices]

    # Perform Y-axis projection
    y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
    y_intervals = split_projection_profile(y_projection, 0, 1)
    # split_projection_profile returns None when no cell is covered.
    if not y_intervals:
        return

    # Process each segment defined by Y-axis projection
    for y_start, y_end in zip(*y_intervals):
        # Select boxes within the current y interval (by their y_min)
        y_interval_indices = (y_start <= y_sorted_boxes[:, 1]) & (
            y_sorted_boxes[:, 1] < y_end
        )
        y_boxes_chunk = y_sorted_boxes[y_interval_indices]
        y_indices_chunk = y_sorted_indices[y_interval_indices]

        # Sort by x_min for X-axis projection
        x_sorted_indices = y_boxes_chunk[:, 0].argsort()
        x_sorted_boxes_chunk = y_boxes_chunk[x_sorted_indices]
        x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]

        # Perform X-axis projection
        x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
        x_intervals = split_projection_profile(x_projection, 0, min_gap)

        if not x_intervals:
            continue

        # If X-axis cannot be further segmented, add current indices to results
        if len(x_intervals[0]) == 1:
            res.extend(x_sorted_indices_chunk)
            continue

        # Negative x coordinates (presumably mirrored boxes used for right-to-left
        # ordering -- confirm with callers) are projected via abs(); reverse the
        # interval order in that case.
        if np.min(x_sorted_boxes_chunk[:, 0]) < 0:
            x_intervals = np.flip(x_intervals, axis=1)
        # Recursively process each segment defined by X-axis projection
        for x_start, x_end in zip(*x_intervals):
            # abs() mirrors the abs() applied in projection_by_bboxes.
            x_interval_indices = (x_start <= abs(x_sorted_boxes_chunk[:, 0])) & (
                abs(x_sorted_boxes_chunk[:, 0]) < x_end
            )
            # NOTE(review): min_gap is not forwarded, so deeper recursion levels
            # use the default of 1 -- confirm this is intended.
            recursive_yx_cut(
                x_sorted_boxes_chunk[x_interval_indices],
                x_sorted_indices_chunk[x_interval_indices],
                res,
            )
def recursive_xy_cut(
    boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
):
    """
    Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.

    Boxes are grouped into columns by splitting the X-axis projection. A box belongs
    to the column that contains abs(x_min). Within each column the Y-axis
    projection is split using ``min_gap``. If it yields a single segment, the
    column's original indices are appended to ``res`` in y_min order. Otherwise
    each Y segment is processed recursively.

    Args:
        boxes: A (N, 4) array representing bounding boxes with [x_min, y_min, x_max, y_max].
        indices: A list of indices representing the position of boxes in the original data.
        res: A list to store indices of bounding boxes that meet the criteria.
        min_gap (int): Minimum gap width used when splitting the Y-axis projection
            inside each X segment (the X-axis split itself always uses 1). Defaults to 1.

    Returns:
        None: This function modifies the `res` list in place.
    """
    # Ensure boxes and indices have the same length
    assert len(boxes) == len(
        indices
    ), "The length of boxes and indices must be the same."

    # Sort by x_min to prepare for X-axis projection
    x_sorted_indices = boxes[:, 0].argsort()
    x_sorted_boxes = boxes[x_sorted_indices]
    # Map the sort permutation onto the caller-supplied original indices.
    x_sorted_indices = np.array(indices)[x_sorted_indices]

    # Perform X-axis projection
    x_projection = projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
    x_intervals = split_projection_profile(x_projection, 0, 1)
    # split_projection_profile returns None when no cell is covered.
    if not x_intervals:
        return

    # Negative x coordinates (presumably mirrored boxes for right-to-left
    # ordering -- confirm with callers) are projected via abs(); reverse the
    # interval order in that case.
    if np.min(x_sorted_boxes[:, 0]) < 0:
        x_intervals = np.flip(x_intervals, axis=1)
    # Process each segment defined by X-axis projection
    for x_start, x_end in zip(*x_intervals):
        # Select boxes within the current x interval (by abs(x_min))
        x_interval_indices = (x_start <= abs(x_sorted_boxes[:, 0])) & (
            abs(x_sorted_boxes[:, 0]) < x_end
        )
        x_boxes_chunk = x_sorted_boxes[x_interval_indices]
        x_indices_chunk = x_sorted_indices[x_interval_indices]

        # Sort selected boxes by y_min to prepare for Y-axis projection
        y_sorted_indices = x_boxes_chunk[:, 1].argsort()
        y_sorted_boxes_chunk = x_boxes_chunk[y_sorted_indices]
        y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]

        # Perform Y-axis projection
        y_projection = projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
        y_intervals = split_projection_profile(y_projection, 0, min_gap)

        if not y_intervals:
            continue

        # If Y-axis cannot be further segmented, add current indices to results
        if len(y_intervals[0]) == 1:
            res.extend(y_sorted_indices_chunk)
            continue

        # Recursively process each segment defined by Y-axis projection
        for y_start, y_end in zip(*y_intervals):
            y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
                y_sorted_boxes_chunk[:, 1] < y_end
            )
            # NOTE(review): min_gap is not forwarded, so deeper recursion levels
            # use the default of 1 -- confirm this is intended.
            recursive_xy_cut(
                y_sorted_boxes_chunk[y_interval_indices],
                y_sorted_indices_chunk[y_interval_indices],
                res,
            )
  215. def reference_insert(
  216. block: LayoutParsingBlock,
  217. sorted_blocks: List[LayoutParsingBlock],
  218. **kwargs,
  219. ):
  220. """
  221. Insert reference block into sorted blocks based on the distance between the block and the nearest sorted block.
  222. Args:
  223. block: The block to insert into the sorted blocks.
  224. sorted_blocks: The sorted blocks where the new block will be inserted.
  225. config: Configuration dictionary containing parameters related to the layout parsing.
  226. median_width: Median width of the document. Defaults to 0.0.
  227. Returns:
  228. sorted_blocks: The updated sorted blocks after insertion.
  229. """
  230. min_distance = float("inf")
  231. nearest_sorted_block_index = 0
  232. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  233. if sorted_block.bbox[3] <= block.bbox[1]:
  234. distance = -(sorted_block.bbox[2] * 10 + sorted_block.bbox[3])
  235. if distance < min_distance:
  236. min_distance = distance
  237. nearest_sorted_block_index = sorted_block_idx
  238. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  239. return sorted_blocks
  240. def manhattan_insert(
  241. block: LayoutParsingBlock,
  242. sorted_blocks: List[LayoutParsingBlock],
  243. **kwargs,
  244. ):
  245. """
  246. Insert a block into a sorted list of blocks based on the Manhattan distance between the block and the nearest sorted block.
  247. Args:
  248. block: The block to insert into the sorted blocks.
  249. sorted_blocks: The sorted blocks where the new block will be inserted.
  250. config: Configuration dictionary containing parameters related to the layout parsing.
  251. median_width: Median width of the document. Defaults to 0.0.
  252. Returns:
  253. sorted_blocks: The updated sorted blocks after insertion.
  254. """
  255. min_distance = float("inf")
  256. nearest_sorted_block_index = 0
  257. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  258. distance = _manhattan_distance(block.bbox, sorted_block.bbox)
  259. if distance < min_distance:
  260. min_distance = distance
  261. nearest_sorted_block_index = sorted_block_idx
  262. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  263. return sorted_blocks
  264. def weighted_distance_insert(
  265. block: LayoutParsingBlock,
  266. sorted_blocks: List[LayoutParsingBlock],
  267. region: LayoutParsingRegion,
  268. ):
  269. """
  270. Insert a block into a sorted list of blocks based on the weighted distance between the block and the nearest sorted block.
  271. Args:
  272. block: The block to insert into the sorted blocks.
  273. sorted_blocks: The sorted blocks where the new block will be inserted.
  274. config: Configuration dictionary containing parameters related to the layout parsing.
  275. median_width: Median width of the document. Defaults to 0.0.
  276. Returns:
  277. sorted_blocks: The updated sorted blocks after insertion.
  278. """
  279. tolerance_len = XYCUT_SETTINGS["edge_distance_compare_tolerance_len"]
  280. x1, y1, x2, y2 = block.bbox
  281. min_weighted_distance, min_edge_distance, min_up_edge_distance = (
  282. float("inf"),
  283. float("inf"),
  284. float("inf"),
  285. )
  286. nearest_sorted_block_index = 0
  287. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  288. x1_prime, y1_prime, x2_prime, y2_prime = sorted_block.bbox
  289. # Calculate edge distance
  290. weight = _get_weights(block.order_label, block.direction)
  291. edge_distance = get_nearest_edge_distance(block.bbox, sorted_block.bbox, weight)
  292. if block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
  293. disperse = max(1, region.text_line_width)
  294. tolerance_len = max(tolerance_len, disperse)
  295. if block.label == "abstract":
  296. tolerance_len *= 2
  297. edge_distance = max(0.1, edge_distance) * 10
  298. # Calculate up edge distances
  299. up_edge_distance = y1_prime if region.direction == "horizontal" else -x2_prime
  300. left_edge_distance = x1_prime if region.direction == "horizontal" else y1_prime
  301. is_below_sorted_block = (
  302. y2_prime < y1 if region.direction == "horizontal" else x1_prime > x2
  303. )
  304. if (
  305. block.label not in BLOCK_LABEL_MAP["unordered_labels"]
  306. or block.label in BLOCK_LABEL_MAP["doc_title_labels"]
  307. or block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]
  308. or block.label in BLOCK_LABEL_MAP["vision_labels"]
  309. ) and is_below_sorted_block:
  310. up_edge_distance = -up_edge_distance
  311. left_edge_distance = -left_edge_distance
  312. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  313. up_edge_distance = min_up_edge_distance
  314. # Calculate weighted distance
  315. weighted_distance = (
  316. +edge_distance
  317. * XYCUT_SETTINGS["distance_weight_map"].get("edge_weight", 10**4)
  318. + up_edge_distance
  319. * XYCUT_SETTINGS["distance_weight_map"].get("up_edge_weight", 1)
  320. + left_edge_distance
  321. * XYCUT_SETTINGS["distance_weight_map"].get("left_edge_weight", 0.0001)
  322. )
  323. min_edge_distance = min(edge_distance, min_edge_distance)
  324. min_up_edge_distance = min(up_edge_distance, min_up_edge_distance)
  325. if weighted_distance < min_weighted_distance:
  326. nearest_sorted_block_index = sorted_block_idx
  327. min_weighted_distance = weighted_distance
  328. if y1 > y1_prime or (y1 == y1_prime and x1 > x1_prime):
  329. nearest_sorted_block_index = sorted_block_idx + 1
  330. sorted_blocks.insert(nearest_sorted_block_index, block)
  331. return sorted_blocks
  332. def insert_child_blocks(
  333. block: LayoutParsingBlock,
  334. block_idx: int,
  335. sorted_blocks: List[LayoutParsingBlock],
  336. ) -> List[LayoutParsingBlock]:
  337. """
  338. Insert child blocks of a block into the sorted blocks list.
  339. Args:
  340. block: The parent block whose child blocks need to be inserted.
  341. block_idx: Index at which the parent block exists in the sorted blocks list.
  342. sorted_blocks: Sorted blocks list where the child blocks are to be inserted.
  343. Returns:
  344. sorted_blocks: Updated sorted blocks list after inserting child blocks.
  345. """
  346. if block.child_blocks:
  347. sub_blocks = block.get_child_blocks()
  348. sub_blocks.append(block)
  349. sub_blocks = sort_child_blocks(sub_blocks, sub_blocks[0].direction)
  350. sorted_blocks[block_idx] = sub_blocks[0]
  351. for block in sub_blocks[1:]:
  352. block_idx += 1
  353. sorted_blocks.insert(block_idx, block)
  354. return sorted_blocks
  355. def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock]:
  356. """
  357. Sort child blocks based on their bounding box coordinates.
  358. Args:
  359. blocks: A list of LayoutParsingBlock objects representing the child blocks.
  360. direction: direction of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'.
  361. Returns:
  362. sorted_blocks: A sorted list of LayoutParsingBlock objects.
  363. """
  364. if direction == "horizontal":
  365. # from top to bottom
  366. blocks.sort(
  367. key=lambda x: (
  368. x.bbox[1], # y_min
  369. x.bbox[0], # x_min
  370. x.bbox[1] ** 2 + x.bbox[0] ** 2, # distance with (0,0)
  371. ),
  372. )
  373. else:
  374. # from right to left
  375. blocks.sort(
  376. key=lambda x: (
  377. -x.bbox[0], # x_min
  378. x.bbox[1], # y_min
  379. x.bbox[1] ** 2 - x.bbox[0] ** 2, # distance with (max,0)
  380. ),
  381. )
  382. return blocks
  383. def _get_weights(label, direction="horizontal"):
  384. """Define weights based on the label and direction."""
  385. if label == "doc_title":
  386. return (
  387. [1, 0.1, 0.1, 1] if direction == "horizontal" else [0.2, 0.1, 1, 1]
  388. ) # left-down , right-left
  389. elif label in [
  390. "paragraph_title",
  391. "table_title",
  392. "abstract",
  393. "image",
  394. "seal",
  395. "chart",
  396. "figure",
  397. ]:
  398. return [1, 1, 0.1, 1] # down
  399. else:
  400. return [1, 1, 1, 0.1] # up
  401. def _manhattan_distance(
  402. point1: Tuple[float, float],
  403. point2: Tuple[float, float],
  404. weight_x: float = 1.0,
  405. weight_y: float = 1.0,
  406. ) -> float:
  407. """
  408. Calculate the weighted Manhattan distance between two points.
  409. Args:
  410. point1 (Tuple[float, float]): The first point as (x, y).
  411. point2 (Tuple[float, float]): The second point as (x, y).
  412. weight_x (float): The weight for the x-axis distance. Default is 1.0.
  413. weight_y (float): The weight for the y-axis distance. Default is 1.0.
  414. Returns:
  415. float: The weighted Manhattan distance between the two points.
  416. """
  417. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
# NOTE(review): this definition is shadowed by the identically named
# `sort_normal_blocks` defined immediately after it in this file, so this
# version is never called. Consider deleting it. Note that its vertical-branch
# tie-breaker differs from the later definition.
def sort_normal_blocks(blocks, text_line_height, text_line_width, region_direction):
    """
    Sort blocks in place by line-quantized position and return the list.

    Coordinates are bucketed with ``//`` by text line height/width before sorting.
    """
    if region_direction == "horizontal":
        blocks.sort(
            key=lambda x: (
                x.bbox[1] // text_line_height,
                x.bbox[0] // text_line_width,
                x.bbox[1] ** 2 + x.bbox[0] ** 2,
            ),
        )
    else:
        blocks.sort(
            key=lambda x: (
                -x.bbox[0] // text_line_width,
                x.bbox[1] // text_line_height,
                x.bbox[1] ** 2 - x.bbox[2] ** 2,  # distance with (max,0)
            ),
        )
    return blocks
  436. def sort_normal_blocks(blocks, text_line_height, text_line_width, region_direction):
  437. if region_direction == "horizontal":
  438. blocks.sort(
  439. key=lambda x: (
  440. x.bbox[1] // text_line_height,
  441. x.bbox[0] // text_line_width,
  442. x.bbox[1] ** 2 + x.bbox[0] ** 2,
  443. ),
  444. )
  445. else:
  446. blocks.sort(
  447. key=lambda x: (
  448. -x.bbox[0] // text_line_width,
  449. x.bbox[1] // text_line_height,
  450. -(x.bbox[2] ** 2 + x.bbox[1] ** 2),
  451. ),
  452. )
  453. return blocks
  454. def get_cut_blocks(blocks, cut_direction, cut_coordinates, mask_labels=[]):
  455. """
  456. Cut blocks based on the given cut direction and coordinates.
  457. Args:
  458. blocks (list): list of blocks to be cut.
  459. cut_direction (str): cut direction, either "horizontal" or "vertical".
  460. cut_coordinates (list): list of cut coordinates.
  461. Returns:
  462. list: a list of tuples containing the cutted blocks and their corresponding mean width。
  463. """
  464. cuted_list = []
  465. # filter out mask blocks,including header, footer, unordered and child_blocks
  466. # 0: horizontal, 1: vertical
  467. cut_aixis = 0 if cut_direction == "horizontal" else 1
  468. blocks.sort(key=lambda x: x.bbox[cut_aixis + 2])
  469. cut_coordinates.append(float("inf"))
  470. cut_coordinates = list(set(cut_coordinates))
  471. cut_coordinates.sort()
  472. cut_idx = 0
  473. for cut_coordinate in cut_coordinates:
  474. group_blocks = []
  475. block_idx = cut_idx
  476. while block_idx < len(blocks):
  477. block = blocks[block_idx]
  478. if block.bbox[cut_aixis + 2] > cut_coordinate:
  479. break
  480. elif block.order_label not in mask_labels:
  481. group_blocks.append(block)
  482. block_idx += 1
  483. cut_idx = block_idx
  484. if group_blocks:
  485. cuted_list.append(group_blocks)
  486. return cuted_list
  487. def add_split_block(
  488. blocks: List[LayoutParsingBlock], region_bbox: List[int]
  489. ) -> List[LayoutParsingBlock]:
  490. block_bboxes = np.array([block.bbox for block in blocks])
  491. discontinuous = calculate_discontinuous_projection(
  492. block_bboxes, direction="vertical"
  493. )
  494. current_interval = discontinuous[0]
  495. for interval in discontinuous[1:]:
  496. gap_len = interval[0] - current_interval[1]
  497. if gap_len > 40:
  498. x1, _, x2, __ = region_bbox
  499. y1 = current_interval[1] + 5
  500. y2 = interval[0] - 5
  501. bbox = [x1, y1, x2, y2]
  502. split_block = LayoutParsingBlock(label="split", bbox=bbox)
  503. blocks.append(split_block)
  504. current_interval = interval
  505. def get_nearest_blocks(
  506. block: LayoutParsingBlock,
  507. ref_blocks: List[LayoutParsingBlock],
  508. overlap_threshold,
  509. direction="horizontal",
  510. ) -> List:
  511. """
  512. Get the adjacent blocks with the same direction as the current block.
  513. Args:
  514. block (LayoutParsingBlock): The current block.
  515. blocks (List[LayoutParsingBlock]): A list of all blocks.
  516. ref_block_idxes (List[int]): A list of indices of reference blocks.
  517. iou_threshold (float): The IOU threshold to determine if two blocks are considered adjacent.
  518. Returns:
  519. Int: The index of the previous block with same direction.
  520. Int: The index of the following block with same direction.
  521. """
  522. prev_blocks: List[LayoutParsingBlock] = []
  523. post_blocks: List[LayoutParsingBlock] = []
  524. sort_index = 1 if direction == "horizontal" else 0
  525. for ref_block in ref_blocks:
  526. if ref_block.index == block.index:
  527. continue
  528. overlap_ratio = calculate_projection_overlap_ratio(
  529. block.bbox, ref_block.bbox, direction, mode="small"
  530. )
  531. if overlap_ratio > overlap_threshold:
  532. if ref_block.bbox[sort_index] <= block.bbox[sort_index]:
  533. prev_blocks.append(ref_block)
  534. else:
  535. post_blocks.append(ref_block)
  536. if prev_blocks:
  537. prev_blocks.sort(key=lambda x: x.bbox[sort_index], reverse=True)
  538. if post_blocks:
  539. post_blocks.sort(key=lambda x: x.bbox[sort_index])
  540. return prev_blocks, post_blocks
def get_adjacent_blocks_by_direction(
    blocks: List[LayoutParsingBlock],
    block_idx: int,
    ref_block_idxes: List[int],
    iou_threshold,
) -> List:
    """
    Find the nearest preceding and following reference blocks of ``blocks[block_idx]``.

    A reference block is considered only if its ``order_label`` is not one of the
    child labels listed below. Its projection overlap with the current block must
    also be at least ``iou_threshold``; the overlap is computed along the
    reference block's own ``direction``. Distances are measured on the
    ``secondary_direction_*`` coordinates and are bucketed (``// 5``) with a small
    ``start_coordinate / 5000`` tie-breaker.

    Args:
        blocks (List[LayoutParsingBlock]): A list of all blocks.
        block_idx (int): Index of the current block in ``blocks``.
        ref_block_idxes (List[int]): Indices (into ``blocks``) of candidate reference blocks.
        iou_threshold (float): Minimum projection overlap ratio for a candidate.

    Returns:
        tuple: ``(prev_block_index, post_block_index)``. Each is an index into
        ``blocks`` or None: None when no candidate lies within the gap tolerance,
        or when the other side is much closer.
    """
    min_prev_block_distance = float("inf")
    prev_block_index = None
    min_post_block_distance = float("inf")
    post_block_index = None
    block = blocks[block_idx]
    # Blocks already attached as children are never adjacency candidates.
    child_labels = [
        "vision_footnote",
        "sub_paragraph_title",
        "doc_title_text",
        "vision_title",
    ]

    # find the nearest text block with same direction to the current block
    for ref_block_idx in ref_block_idxes:
        ref_block = blocks[ref_block_idx]
        ref_block_direction = ref_block.direction
        if ref_block.order_label in child_labels:
            continue
        match_block_iou = calculate_projection_overlap_ratio(
            block.bbox,
            ref_block.bbox,
            ref_block_direction,
        )

        # Allowed overlap when deciding whether ref_block is before/after block.
        child_match_distance_tolerance_len = block.short_side_length / 10

        # Maximum gap for a candidate to be recorded as adjacent.
        if block.order_label == "vision":
            if ref_block.num_of_lines == 1:
                gap_tolerance_len = ref_block.short_side_length * 2
            else:
                gap_tolerance_len = block.short_side_length / 10
        else:
            gap_tolerance_len = block.short_side_length * 2

        if match_block_iou >= iou_threshold:
            prev_distance = (
                block.secondary_direction_start_coordinate
                - ref_block.secondary_direction_end_coordinate
                + child_match_distance_tolerance_len
            ) // 5 + ref_block.start_coordinate / 5000
            next_distance = (
                ref_block.secondary_direction_start_coordinate
                - block.secondary_direction_end_coordinate
                + child_match_distance_tolerance_len
            ) // 5 + ref_block.start_coordinate / 5000
            # ref_block ends before block starts (within tolerance): preceding candidate.
            if (
                ref_block.secondary_direction_end_coordinate
                <= block.secondary_direction_start_coordinate
                + child_match_distance_tolerance_len
                and prev_distance < min_prev_block_distance
            ):
                min_prev_block_distance = prev_distance
                # The minimum distance is tracked even when the gap is too large;
                # only the index assignment requires the gap tolerance.
                if (
                    block.secondary_direction_start_coordinate
                    - ref_block.secondary_direction_end_coordinate
                    < gap_tolerance_len
                ):
                    prev_block_index = ref_block_idx
            # ref_block starts after block ends (within tolerance): following candidate.
            elif (
                ref_block.secondary_direction_start_coordinate
                > block.secondary_direction_end_coordinate
                - child_match_distance_tolerance_len
                and next_distance < min_post_block_distance
            ):
                min_post_block_distance = next_distance
                if (
                    ref_block.secondary_direction_start_coordinate
                    - block.secondary_direction_end_coordinate
                    < gap_tolerance_len
                ):
                    post_block_index = ref_block_idx

    # When both minima are still inf this is nan and the check below is False.
    diff_dist = abs(min_prev_block_distance - min_post_block_distance)

    # if the difference in distance is too large, only consider the nearest one
    if diff_dist * 5 > block.short_side_length:
        if min_prev_block_distance < min_post_block_distance:
            post_block_index = None
        else:
            prev_block_index = None

    return prev_block_index, post_block_index
  633. def update_doc_title_child_blocks(
  634. block: LayoutParsingBlock,
  635. region: LayoutParsingRegion,
  636. ) -> None:
  637. """
  638. Update the child blocks of a document title block.
  639. The child blocks need to meet the following conditions:
  640. 1. They must be adjacent
  641. 2. They must have the same direction as the parent block.
  642. 3. Their short side length should be less than 80% of the parent's short side length.
  643. 4. Their long side length should be less than 150% of the parent's long side length.
  644. 5. The child block must be text block.
  645. 6. The nearest edge distance should be less than 2 times of the text line height.
  646. Args:
  647. blocks (List[LayoutParsingBlock]): overall blocks.
  648. block (LayoutParsingBlock): document title block.
  649. prev_idx (int): previous block index, None if not exist.
  650. post_idx (int): post block index, None if not exist.
  651. config (dict): configurations.
  652. Returns:
  653. None
  654. """
  655. ref_blocks = [region.block_map[idx] for idx in region.normal_text_block_idxes]
  656. overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
  657. prev_blocks, post_blocks = get_nearest_blocks(
  658. block, ref_blocks, overlap_threshold, block.direction
  659. )
  660. prev_block = None
  661. post_block = None
  662. if prev_blocks:
  663. prev_block = prev_blocks[0]
  664. if post_blocks:
  665. post_block = post_blocks[0]
  666. for ref_block in [prev_block, post_block]:
  667. if ref_block is None:
  668. continue
  669. with_seem_direction = ref_block.direction == block.direction
  670. short_side_length_condition = (
  671. ref_block.short_side_length < block.short_side_length * 0.8
  672. )
  673. long_side_length_condition = (
  674. ref_block.long_side_length < block.long_side_length
  675. or ref_block.long_side_length > 1.5 * block.long_side_length
  676. )
  677. nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox)
  678. if (
  679. with_seem_direction
  680. and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
  681. and short_side_length_condition
  682. and long_side_length_condition
  683. and ref_block.num_of_lines < 3
  684. and nearest_edge_distance < ref_block.text_line_height * 2
  685. ):
  686. ref_block.order_label = "doc_title_text"
  687. block.append_child_block(ref_block)
  688. region.normal_text_block_idxes.remove(ref_block.index)
  689. def update_paragraph_title_child_blocks(
  690. block: LayoutParsingBlock,
  691. region: LayoutParsingRegion,
  692. ) -> None:
  693. """
  694. Update the child blocks of a paragraph title block.
  695. The child blocks need to meet the following conditions:
  696. 1. They must be adjacent
  697. 2. They must have the same direction as the parent block.
  698. 3. The child block must be paragraph title block.
  699. Args:
  700. blocks (List[LayoutParsingBlock]): overall blocks.
  701. block (LayoutParsingBlock): document title block.
  702. prev_idx (int): previous block index, None if not exist.
  703. post_idx (int): post block index, None if not exist.
  704. config (dict): configurations.
  705. Returns:
  706. None
  707. """
  708. if block.order_label == "sub_paragraph_title":
  709. return
  710. ref_blocks = [
  711. region.block_map[idx]
  712. for idx in region.paragraph_title_block_idxes + region.normal_text_block_idxes
  713. ]
  714. overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
  715. prev_blocks, post_blocks = get_nearest_blocks(
  716. block, ref_blocks, overlap_threshold, block.direction
  717. )
  718. for ref_blocks in [prev_blocks, post_blocks]:
  719. for ref_block in ref_blocks:
  720. if ref_block.label not in BLOCK_LABEL_MAP["paragraph_title_labels"]:
  721. break
  722. min_text_line_height = min(
  723. block.text_line_height, ref_block.text_line_height
  724. )
  725. nearest_edge_distance = get_nearest_edge_distance(
  726. block.bbox, ref_block.bbox
  727. )
  728. with_seem_direction = ref_block.direction == block.direction
  729. if (
  730. with_seem_direction
  731. and nearest_edge_distance <= min_text_line_height * 1.5
  732. ):
  733. ref_block.order_label = "sub_paragraph_title"
  734. block.append_child_block(ref_block)
  735. region.paragraph_title_block_idxes.remove(ref_block.index)
def update_vision_child_blocks(
    block: LayoutParsingBlock,
    region: LayoutParsingRegion,
) -> None:
    """
    Attach vision titles and at most one vision footnote to a vision block.

    Candidates are the region's normal-text blocks and vision-title blocks.
    They are searched with ``get_nearest_blocks`` first along
    ``block.direction`` and then along ``block.secondary_direction``. The
    second pass is skipped if a vision title was attached in the first.

    Rules applied to each scanned candidate:
        - Vision title (label in ``BLOCK_LABEL_MAP["vision_title_labels"]``).
          It is attached with ``order_label = "vision_title"`` when its nearest
          edge distance to ``block`` is at most 2 times its own text line
          height. It is then removed from ``region.vision_title_block_idxes``.
        - Text block (label in ``BLOCK_LABEL_MAP["text_labels"]``). At most
          one text block per call is attached, with
          ``order_label = "vision_footnote"``, and it is removed from
          ``region.normal_text_block_idxes``. The acceptance tests differ
          between the preceding and the following side (see inline comments).

    Scanning the preceding side stops at the first candidate that is neither
    a text nor a vision-title block, and after the first text block. Scanning
    the following side passes over other labels. It stops at the first text
    block, which is only evaluated while no footnote has been attached yet.

    Args:
        block (LayoutParsingBlock): vision block (the parent).
        region (LayoutParsingRegion): region that owns the candidate blocks;
            its ``normal_text_block_idxes`` and ``vision_title_block_idxes``
            lists are modified in place.

    Returns:
        None
    """
    # Built once: blocks attached in the first pass may reappear in the
    # second; the has_vision_* flags below prevent attaching them twice.
    ref_blocks = [
        region.block_map[idx]
        for idx in region.normal_text_block_idxes + region.vision_title_block_idxes
    ]
    overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
    has_vision_footnote = False
    has_vision_title = False
    for direction in [block.direction, block.secondary_direction]:
        prev_blocks, post_blocks = get_nearest_blocks(
            block, ref_blocks, overlap_threshold, direction
        )
        # Preceding side.
        for ref_block in prev_blocks:
            if (
                ref_block.label
                not in BLOCK_LABEL_MAP["text_labels"]
                + BLOCK_LABEL_MAP["vision_title_labels"]
            ):
                break
            nearest_edge_distance = get_nearest_edge_distance(
                block.bbox, ref_block.bbox
            )
            block_center = block.get_centroid()
            ref_block_center = ref_block.get_centroid()
            if (
                ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
                and nearest_edge_distance <= ref_block.text_line_height * 2
            ):
                has_vision_title = True
                ref_block.order_label = "vision_title"
                block.append_child_block(ref_block)
                region.vision_title_block_idxes.remove(ref_block.index)
            if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
                # Footnote above/before: same direction, long side shorter
                # than the parent's, and any one of the three tests below.
                if (
                    not has_vision_footnote
                    and ref_block.direction == block.direction
                    and ref_block.long_side_length < block.long_side_length
                ):
                    # NOTE(review): only the first alternative checks the
                    # distance. The last two use signed (not abs) differences:
                    # they hold whenever ref_block.bbox[0] > block.bbox[0] - 10
                    # (resp. bbox[2]) for a single-line block. Confirm intended.
                    if (
                        (
                            nearest_edge_distance <= block.text_line_height * 2
                            and ref_block.short_side_length < block.short_side_length
                            and ref_block.long_side_length
                            < 0.5 * block.long_side_length
                            and abs(block_center[0] - ref_block_center[0]) < 10
                        )
                        or (
                            block.bbox[0] - ref_block.bbox[0] < 10
                            and ref_block.num_of_lines == 1
                        )
                        or (
                            block.bbox[2] - ref_block.bbox[2] < 10
                            and ref_block.num_of_lines == 1
                        )
                    ):
                        has_vision_footnote = True
                        ref_block.order_label = "vision_footnote"
                        block.append_child_block(ref_block)
                        region.normal_text_block_idxes.remove(ref_block.index)
                # Stop at the first text block, accepted or not.
                break
        # Following side.
        for ref_block in post_blocks:
            if (
                has_vision_footnote
                and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
            ):
                break
            nearest_edge_distance = get_nearest_edge_distance(
                block.bbox, ref_block.bbox
            )
            block_center = block.get_centroid()
            ref_block_center = ref_block.get_centroid()
            if (
                ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
                and nearest_edge_distance <= ref_block.text_line_height * 2
            ):
                has_vision_title = True
                ref_block.order_label = "vision_title"
                block.append_child_block(ref_block)
                region.vision_title_block_idxes.remove(ref_block.index)
            if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
                # Footnote below/after: close, smaller on both sides, same
                # direction, and either centred (index 0 of get_centroid())
                # or a single line passing one of the signed edge tests.
                if (
                    not has_vision_footnote
                    and nearest_edge_distance <= block.text_line_height * 2
                    and ref_block.short_side_length < block.short_side_length
                    and ref_block.long_side_length < 0.5 * block.long_side_length
                    and ref_block.direction == block.direction
                    and (
                        abs(block_center[0] - ref_block_center[0]) < 10
                        or (
                            block.bbox[0] - ref_block.bbox[0] < 10
                            and ref_block.num_of_lines == 1
                        )
                        or (
                            block.bbox[2] - ref_block.bbox[2] < 10
                            and ref_block.num_of_lines == 1
                        )
                    )
                ):
                    has_vision_footnote = True
                    ref_block.order_label = "vision_footnote"
                    block.append_child_block(ref_block)
                    region.normal_text_block_idxes.remove(ref_block.index)
                # Stop at the first text block, accepted or not.
                break
        # A title found along this direction ends the search.
        if has_vision_title:
            break
  867. def calculate_discontinuous_projection(
  868. boxes, direction="horizontal", return_num=False
  869. ) -> List:
  870. """
  871. Calculate the discontinuous projection of boxes along the specified direction.
  872. Args:
  873. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  874. direction (str): direction along which to perform the projection ('horizontal' or 'vertical').
  875. Returns:
  876. list: List of tuples representing the merged intervals.
  877. """
  878. boxes = np.array(boxes)
  879. if direction == "horizontal":
  880. intervals = boxes[:, [0, 2]]
  881. elif direction == "vertical":
  882. intervals = boxes[:, [1, 3]]
  883. else:
  884. raise ValueError("direction must be 'horizontal' or 'vertical'")
  885. intervals = intervals[np.argsort(intervals[:, 0])]
  886. merged_intervals = []
  887. num = 1
  888. current_start, current_end = intervals[0]
  889. num_list = []
  890. for start, end in intervals[1:]:
  891. if start <= current_end:
  892. num += 1
  893. current_end = max(current_end, end)
  894. else:
  895. num_list.append(num)
  896. merged_intervals.append((current_start, current_end))
  897. num = 1
  898. current_start, current_end = start, end
  899. num_list.append(num)
  900. merged_intervals.append((current_start, current_end))
  901. if return_num:
  902. return merged_intervals, num_list
  903. return merged_intervals
  904. def is_projection_consistent(blocks, intervals, direction="horizontal"):
  905. for interval in intervals:
  906. if direction == "horizontal":
  907. start_index, stop_index = 0, 2
  908. interval_box = [interval[0], 0, interval[1], 1]
  909. else:
  910. start_index, stop_index = 1, 3
  911. interval_box = [0, interval[0], 1, interval[1]]
  912. same_interval_bboxes = []
  913. for block in blocks:
  914. overlap_ratio = calculate_projection_overlap_ratio(
  915. interval_box, block.bbox, direction=direction
  916. )
  917. if overlap_ratio > 0 and block.label in BLOCK_LABEL_MAP["text_labels"]:
  918. same_interval_bboxes.append(block.bbox)
  919. start_coordinates = [bbox[start_index] for bbox in same_interval_bboxes]
  920. if start_coordinates:
  921. min_start_coordinate = min(start_coordinates)
  922. max_start_coordinate = max(start_coordinates)
  923. is_start_consistent = (
  924. False
  925. if max_start_coordinate - min_start_coordinate
  926. >= abs(interval[0] - interval[1]) * 0.05
  927. else True
  928. )
  929. stop_coordinates = [bbox[stop_index] for bbox in same_interval_bboxes]
  930. min_stop_coordinate = min(stop_coordinates)
  931. max_stop_coordinate = max(stop_coordinates)
  932. if (
  933. max_stop_coordinate - min_stop_coordinate
  934. >= abs(interval[0] - interval[1]) * 0.05
  935. and is_start_consistent
  936. ):
  937. return False
  938. return True
  939. def shrink_overlapping_boxes(
  940. boxes, direction="horizontal", min_threshold=0, max_threshold=0.1
  941. ) -> List:
  942. """
  943. Shrink overlapping boxes along the specified direction.
  944. Args:
  945. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  946. direction (str): direction along which to perform the shrinking ('horizontal' or 'vertical').
  947. min_threshold (float): Minimum threshold for shrinking. Default is 0.
  948. max_threshold (float): Maximum threshold for shrinking. Default is 0.2.
  949. Returns:
  950. list: List of tuples representing the merged intervals.
  951. """
  952. current_block = boxes[0]
  953. for block in boxes[1:]:
  954. x1, y1, x2, y2 = current_block.bbox
  955. x1_prime, y1_prime, x2_prime, y2_prime = block.bbox
  956. cut_iou = calculate_projection_overlap_ratio(
  957. current_block.bbox, block.bbox, direction=direction
  958. )
  959. match_iou = calculate_projection_overlap_ratio(
  960. current_block.bbox,
  961. block.bbox,
  962. direction="horizontal" if direction == "vertical" else "vertical",
  963. )
  964. if direction == "vertical":
  965. if (
  966. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  967. or y2 == y1_prime
  968. or abs(y2 - y1_prime) <= 3
  969. ):
  970. overlap_y_min = max(y1, y1_prime)
  971. overlap_y_max = min(y2, y2_prime)
  972. split_y = int((overlap_y_min + overlap_y_max) / 2)
  973. overlap_y_min = split_y - 1
  974. overlap_y_max = split_y + 1
  975. current_block.bbox = [x1, y1, x2, overlap_y_min]
  976. block.bbox = [x1_prime, overlap_y_max, x2_prime, y2_prime]
  977. else:
  978. if (
  979. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  980. or x2 == x1_prime
  981. or abs(x2 - x1_prime) <= 3
  982. ):
  983. overlap_x_min = max(x1, x1_prime)
  984. overlap_x_max = min(x2, x2_prime)
  985. split_x = int((overlap_x_min + overlap_x_max) / 2)
  986. overlap_x_min = split_x - 1
  987. overlap_x_max = split_x + 1
  988. current_block.bbox = [x1, y1, overlap_x_min, y2]
  989. block.bbox = [overlap_x_max, y1_prime, x2_prime, y2_prime]
  990. current_block = block
  991. return boxes