utils.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from typing import List, Optional, Tuple

import numpy as np

from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion
from ..setting import BLOCK_LABEL_MAP, XYCUT_SETTINGS
from ..utils import calculate_projection_overlap_ratio
  19. def get_nearest_edge_distance(
  20. bbox1: List[int],
  21. bbox2: List[int],
  22. weight: List[float] = [1.0, 1.0, 1.0, 1.0],
  23. ) -> Tuple[float]:
  24. """
  25. Calculate the nearest edge distance between two bounding boxes, considering directional weights.
  26. Args:
  27. bbox1 (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  28. bbox2 (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  29. weight (list, optional): directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
  30. Returns:
  31. float: The calculated minimum edge distance between the bounding boxes.
  32. """
  33. x1, y1, x2, y2 = bbox1
  34. x1_prime, y1_prime, x2_prime, y2_prime = bbox2
  35. min_x_distance, min_y_distance = 0, 0
  36. horizontal_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "horizontal")
  37. vertical_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "vertical")
  38. if horizontal_iou > 0 and vertical_iou > 0:
  39. return 0.0
  40. if horizontal_iou == 0:
  41. min_x_distance = min(abs(x1 - x2_prime), abs(x2 - x1_prime)) * (
  42. weight[0] if x2 < x1_prime else weight[1]
  43. )
  44. if vertical_iou == 0:
  45. min_y_distance = min(abs(y1 - y2_prime), abs(y2 - y1_prime)) * (
  46. weight[2] if y2 < y1_prime else weight[3]
  47. )
  48. return min_x_distance + min_y_distance
  49. def _projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
  50. """
  51. Generate a 1D projection histogram from bounding boxes along a specified axis.
  52. Args:
  53. boxes: A (N, 4) array of bounding boxes defined by [x_min, y_min, x_max, y_max].
  54. axis: Axis for projection; 0 for horizontal (x-axis), 1 for vertical (y-axis).
  55. Returns:
  56. A 1D numpy array representing the projection histogram based on bounding box intervals.
  57. """
  58. assert axis in [0, 1]
  59. max_length = np.max(boxes[:, axis::2])
  60. if max_length < 0:
  61. max_length = abs(np.min(boxes[:, axis::2]))
  62. projection = np.zeros(max_length, dtype=int)
  63. # Increment projection histogram over the interval defined by each bounding box
  64. for start, end in boxes[:, axis::2]:
  65. start = abs(start)
  66. end = abs(end)
  67. projection[start:end] += 1
  68. return projection
  69. def _split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
  70. """
  71. Split the projection profile into segments based on specified thresholds.
  72. Args:
  73. arr_values: 1D array representing the projection profile.
  74. min_value: Minimum value threshold to consider a profile segment significant.
  75. min_gap: Minimum gap width to consider a separation between segments.
  76. Returns:
  77. A tuple of start and end indices for each segment that meets the criteria.
  78. """
  79. # Identify indices where the projection exceeds the minimum value
  80. significant_indices = np.where(arr_values > min_value)[0]
  81. if not len(significant_indices):
  82. return
  83. # Calculate gaps between significant indices
  84. index_diffs = significant_indices[1:] - significant_indices[:-1]
  85. gap_indices = np.where(index_diffs > min_gap)[0]
  86. # Determine start and end indices of segments
  87. segment_starts = np.insert(
  88. significant_indices[gap_indices + 1],
  89. 0,
  90. significant_indices[0],
  91. )
  92. segment_ends = np.append(
  93. significant_indices[gap_indices],
  94. significant_indices[-1] + 1,
  95. )
  96. return segment_starts, segment_ends
  97. def recursive_yx_cut(
  98. boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  99. ):
  100. """
  101. Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.
  102. Args:
  103. boxes: A (N, 4) array representing bounding boxes.
  104. indices: List of indices indicating the original position of boxes.
  105. res: List to store indices of the final segmented bounding boxes.
  106. min_gap (int): Minimum gap width to consider a separation between segments on the X-axis. Defaults to 1.
  107. Returns:
  108. None: This function modifies the `res` list in place.
  109. """
  110. assert len(boxes) == len(
  111. indices
  112. ), "The length of boxes and indices must be the same."
  113. # Sort by y_min for Y-axis projection
  114. y_sorted_indices = boxes[:, 1].argsort()
  115. y_sorted_boxes = boxes[y_sorted_indices]
  116. y_sorted_indices = np.array(indices)[y_sorted_indices]
  117. # Perform Y-axis projection
  118. y_projection = _projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
  119. y_intervals = _split_projection_profile(y_projection, 0, 1)
  120. if not y_intervals:
  121. return
  122. # Process each segment defined by Y-axis projection
  123. for y_start, y_end in zip(*y_intervals):
  124. # Select boxes within the current y interval
  125. y_interval_indices = (y_start <= y_sorted_boxes[:, 1]) & (
  126. y_sorted_boxes[:, 1] < y_end
  127. )
  128. y_boxes_chunk = y_sorted_boxes[y_interval_indices]
  129. y_indices_chunk = y_sorted_indices[y_interval_indices]
  130. # Sort by x_min for X-axis projection
  131. x_sorted_indices = y_boxes_chunk[:, 0].argsort()
  132. x_sorted_boxes_chunk = y_boxes_chunk[x_sorted_indices]
  133. x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]
  134. # Perform X-axis projection
  135. x_projection = _projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
  136. x_intervals = _split_projection_profile(x_projection, 0, min_gap)
  137. if not x_intervals:
  138. continue
  139. # If X-axis cannot be further segmented, add current indices to results
  140. if len(x_intervals[0]) == 1:
  141. res.extend(x_sorted_indices_chunk)
  142. continue
  143. if np.min(x_sorted_boxes_chunk[:, 0]) < 0:
  144. x_intervals = np.flip(x_intervals, axis=1)
  145. # Recursively process each segment defined by X-axis projection
  146. for x_start, x_end in zip(*x_intervals):
  147. x_interval_indices = (x_start <= abs(x_sorted_boxes_chunk[:, 0])) & (
  148. abs(x_sorted_boxes_chunk[:, 0]) < x_end
  149. )
  150. recursive_yx_cut(
  151. x_sorted_boxes_chunk[x_interval_indices],
  152. x_sorted_indices_chunk[x_interval_indices],
  153. res,
  154. )
  155. def recursive_xy_cut(
  156. boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  157. ):
  158. """
  159. Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.
  160. Args:
  161. boxes: A (N, 4) array representing bounding boxes with [x_min, y_min, x_max, y_max].
  162. indices: A list of indices representing the position of boxes in the original data.
  163. res: A list to store indices of bounding boxes that meet the criteria.
  164. min_gap (int): Minimum gap width to consider a separation between segments on the X-axis. Defaults to 1.
  165. Returns:
  166. None: This function modifies the `res` list in place.
  167. """
  168. # Ensure boxes and indices have the same length
  169. assert len(boxes) == len(
  170. indices
  171. ), "The length of boxes and indices must be the same."
  172. # Sort by x_min to prepare for X-axis projection
  173. x_sorted_indices = boxes[:, 0].argsort()
  174. x_sorted_boxes = boxes[x_sorted_indices]
  175. x_sorted_indices = np.array(indices)[x_sorted_indices]
  176. # Perform X-axis projection
  177. x_projection = _projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
  178. x_intervals = _split_projection_profile(x_projection, 0, 1)
  179. if not x_intervals:
  180. return
  181. if np.min(x_sorted_boxes[:, 0]) < 0:
  182. x_intervals = np.flip(x_intervals, axis=1)
  183. # Process each segment defined by X-axis projection
  184. for x_start, x_end in zip(*x_intervals):
  185. # Select boxes within the current x interval
  186. x_interval_indices = (x_start <= abs(x_sorted_boxes[:, 0])) & (
  187. abs(x_sorted_boxes[:, 0]) < x_end
  188. )
  189. x_boxes_chunk = x_sorted_boxes[x_interval_indices]
  190. x_indices_chunk = x_sorted_indices[x_interval_indices]
  191. # Sort selected boxes by y_min to prepare for Y-axis projection
  192. y_sorted_indices = x_boxes_chunk[:, 1].argsort()
  193. y_sorted_boxes_chunk = x_boxes_chunk[y_sorted_indices]
  194. y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]
  195. # Perform Y-axis projection
  196. y_projection = _projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
  197. y_intervals = _split_projection_profile(y_projection, 0, min_gap)
  198. if not y_intervals:
  199. continue
  200. # If Y-axis cannot be further segmented, add current indices to results
  201. if len(y_intervals[0]) == 1:
  202. res.extend(y_sorted_indices_chunk)
  203. continue
  204. # Recursively process each segment defined by Y-axis projection
  205. for y_start, y_end in zip(*y_intervals):
  206. y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
  207. y_sorted_boxes_chunk[:, 1] < y_end
  208. )
  209. recursive_xy_cut(
  210. y_sorted_boxes_chunk[y_interval_indices],
  211. y_sorted_indices_chunk[y_interval_indices],
  212. res,
  213. )
  214. def reference_insert(
  215. block: LayoutParsingBlock,
  216. sorted_blocks: List[LayoutParsingBlock],
  217. **kwargs,
  218. ):
  219. """
  220. Insert reference block into sorted blocks based on the distance between the block and the nearest sorted block.
  221. Args:
  222. block: The block to insert into the sorted blocks.
  223. sorted_blocks: The sorted blocks where the new block will be inserted.
  224. config: Configuration dictionary containing parameters related to the layout parsing.
  225. median_width: Median width of the document. Defaults to 0.0.
  226. Returns:
  227. sorted_blocks: The updated sorted blocks after insertion.
  228. """
  229. min_distance = float("inf")
  230. nearest_sorted_block_index = 0
  231. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  232. if sorted_block.bbox[3] <= block.bbox[1]:
  233. distance = -(sorted_block.bbox[2] * 10 + sorted_block.bbox[3])
  234. if distance < min_distance:
  235. min_distance = distance
  236. nearest_sorted_block_index = sorted_block_idx
  237. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  238. return sorted_blocks
  239. def manhattan_insert(
  240. block: LayoutParsingBlock,
  241. sorted_blocks: List[LayoutParsingBlock],
  242. **kwargs,
  243. ):
  244. """
  245. Insert a block into a sorted list of blocks based on the Manhattan distance between the block and the nearest sorted block.
  246. Args:
  247. block: The block to insert into the sorted blocks.
  248. sorted_blocks: The sorted blocks where the new block will be inserted.
  249. config: Configuration dictionary containing parameters related to the layout parsing.
  250. median_width: Median width of the document. Defaults to 0.0.
  251. Returns:
  252. sorted_blocks: The updated sorted blocks after insertion.
  253. """
  254. min_distance = float("inf")
  255. nearest_sorted_block_index = 0
  256. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  257. distance = _manhattan_distance(block.bbox, sorted_block.bbox)
  258. if distance < min_distance:
  259. min_distance = distance
  260. nearest_sorted_block_index = sorted_block_idx
  261. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  262. return sorted_blocks
  263. def weighted_distance_insert(
  264. block: LayoutParsingBlock,
  265. sorted_blocks: List[LayoutParsingBlock],
  266. region: LayoutParsingRegion,
  267. ):
  268. """
  269. Insert a block into a sorted list of blocks based on the weighted distance between the block and the nearest sorted block.
  270. Args:
  271. block: The block to insert into the sorted blocks.
  272. sorted_blocks: The sorted blocks where the new block will be inserted.
  273. config: Configuration dictionary containing parameters related to the layout parsing.
  274. median_width: Median width of the document. Defaults to 0.0.
  275. Returns:
  276. sorted_blocks: The updated sorted blocks after insertion.
  277. """
  278. tolerance_len = XYCUT_SETTINGS["edge_distance_compare_tolerance_len"]
  279. x1, y1, x2, y2 = block.bbox
  280. min_weighted_distance, min_edge_distance, min_up_edge_distance = (
  281. float("inf"),
  282. float("inf"),
  283. float("inf"),
  284. )
  285. nearest_sorted_block_index = 0
  286. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  287. x1_prime, y1_prime, x2_prime, y2_prime = sorted_block.bbox
  288. # Calculate edge distance
  289. weight = _get_weights(block.order_label, block.direction)
  290. edge_distance = get_nearest_edge_distance(block.bbox, sorted_block.bbox, weight)
  291. if block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
  292. disperse = max(1, region.text_line_width)
  293. tolerance_len = max(tolerance_len, disperse)
  294. if block.label == "abstract":
  295. tolerance_len *= 2
  296. edge_distance = max(0.1, edge_distance) * 10
  297. # Calculate up edge distances
  298. up_edge_distance = y1_prime if region.direction == "horizontal" else -x2_prime
  299. left_edge_distance = x1_prime if region.direction == "horizontal" else y1_prime
  300. is_below_sorted_block = (
  301. y2_prime < y1 if region.direction == "horizontal" else x1_prime > x2
  302. )
  303. if (
  304. block.label not in BLOCK_LABEL_MAP["unordered_labels"]
  305. or block.label in BLOCK_LABEL_MAP["doc_title_labels"]
  306. or block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]
  307. or block.label in BLOCK_LABEL_MAP["vision_labels"]
  308. ) and is_below_sorted_block:
  309. up_edge_distance = -up_edge_distance
  310. left_edge_distance = -left_edge_distance
  311. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  312. up_edge_distance = min_up_edge_distance
  313. # Calculate weighted distance
  314. weighted_distance = (
  315. +edge_distance
  316. * XYCUT_SETTINGS["distance_weight_map"].get("edge_weight", 10**4)
  317. + up_edge_distance
  318. * XYCUT_SETTINGS["distance_weight_map"].get("up_edge_weight", 1)
  319. + left_edge_distance
  320. * XYCUT_SETTINGS["distance_weight_map"].get("left_edge_weight", 0.0001)
  321. )
  322. min_edge_distance = min(edge_distance, min_edge_distance)
  323. min_up_edge_distance = min(up_edge_distance, min_up_edge_distance)
  324. if weighted_distance < min_weighted_distance:
  325. nearest_sorted_block_index = sorted_block_idx
  326. min_weighted_distance = weighted_distance
  327. if y1 > y1_prime or (y1 == y1_prime and x1 > x1_prime):
  328. nearest_sorted_block_index = sorted_block_idx + 1
  329. sorted_blocks.insert(nearest_sorted_block_index, block)
  330. return sorted_blocks
  331. def insert_child_blocks(
  332. block: LayoutParsingBlock,
  333. block_idx: int,
  334. sorted_blocks: List[LayoutParsingBlock],
  335. ) -> List[LayoutParsingBlock]:
  336. """
  337. Insert child blocks of a block into the sorted blocks list.
  338. Args:
  339. block: The parent block whose child blocks need to be inserted.
  340. block_idx: Index at which the parent block exists in the sorted blocks list.
  341. sorted_blocks: Sorted blocks list where the child blocks are to be inserted.
  342. Returns:
  343. sorted_blocks: Updated sorted blocks list after inserting child blocks.
  344. """
  345. if block.child_blocks:
  346. sub_blocks = block.get_child_blocks()
  347. sub_blocks.append(block)
  348. sub_blocks = sort_child_blocks(sub_blocks, sub_blocks[0].direction)
  349. sorted_blocks[block_idx] = sub_blocks[0]
  350. for block in sub_blocks[1:]:
  351. block_idx += 1
  352. sorted_blocks.insert(block_idx, block)
  353. return sorted_blocks
  354. def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock]:
  355. """
  356. Sort child blocks based on their bounding box coordinates.
  357. Args:
  358. blocks: A list of LayoutParsingBlock objects representing the child blocks.
  359. direction: direction of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'.
  360. Returns:
  361. sorted_blocks: A sorted list of LayoutParsingBlock objects.
  362. """
  363. if direction == "horizontal":
  364. # from top to bottom
  365. blocks.sort(
  366. key=lambda x: (
  367. x.bbox[1], # y_min
  368. x.bbox[0], # x_min
  369. x.bbox[1] ** 2 + x.bbox[0] ** 2, # distance with (0,0)
  370. ),
  371. )
  372. else:
  373. # from right to left
  374. blocks.sort(
  375. key=lambda x: (
  376. -x.bbox[0], # x_min
  377. x.bbox[1], # y_min
  378. x.bbox[1] ** 2 - x.bbox[0] ** 2, # distance with (max,0)
  379. ),
  380. )
  381. return blocks
  382. def _get_weights(label, direction="horizontal"):
  383. """Define weights based on the label and direction."""
  384. if label == "doc_title":
  385. return (
  386. [1, 0.1, 0.1, 1] if direction == "horizontal" else [0.2, 0.1, 1, 1]
  387. ) # left-down , right-left
  388. elif label in [
  389. "paragraph_title",
  390. "table_title",
  391. "abstract",
  392. "image",
  393. "seal",
  394. "chart",
  395. "figure",
  396. ]:
  397. return [1, 1, 0.1, 1] # down
  398. else:
  399. return [1, 1, 1, 0.1] # up
  400. def _manhattan_distance(
  401. point1: Tuple[float, float],
  402. point2: Tuple[float, float],
  403. weight_x: float = 1.0,
  404. weight_y: float = 1.0,
  405. ) -> float:
  406. """
  407. Calculate the weighted Manhattan distance between two points.
  408. Args:
  409. point1 (Tuple[float, float]): The first point as (x, y).
  410. point2 (Tuple[float, float]): The second point as (x, y).
  411. weight_x (float): The weight for the x-axis distance. Default is 1.0.
  412. weight_y (float): The weight for the y-axis distance. Default is 1.0.
  413. Returns:
  414. float: The weighted Manhattan distance between the two points.
  415. """
  416. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
  417. def sort_normal_blocks(blocks, text_line_height, text_line_width, region_direction):
  418. if region_direction == "horizontal":
  419. blocks.sort(
  420. key=lambda x: (
  421. x.bbox[1] // text_line_height,
  422. x.bbox[0] // text_line_width,
  423. x.bbox[1] ** 2 + x.bbox[0] ** 2,
  424. ),
  425. )
  426. else:
  427. blocks.sort(
  428. key=lambda x: (
  429. -x.bbox[0] // text_line_width,
  430. x.bbox[1] // text_line_height,
  431. x.bbox[1] ** 2 - x.bbox[2] ** 2, # distance with (max,0)
  432. ),
  433. )
  434. return blocks
  435. def sort_normal_blocks(blocks, text_line_height, text_line_width, region_direction):
  436. if region_direction == "horizontal":
  437. blocks.sort(
  438. key=lambda x: (
  439. x.bbox[1] // text_line_height,
  440. x.bbox[0] // text_line_width,
  441. x.bbox[1] ** 2 + x.bbox[0] ** 2,
  442. ),
  443. )
  444. else:
  445. blocks.sort(
  446. key=lambda x: (
  447. -x.bbox[0] // text_line_width,
  448. x.bbox[1] // text_line_height,
  449. -(x.bbox[2] ** 2 + x.bbox[1] ** 2),
  450. ),
  451. )
  452. return blocks
  453. def get_cut_blocks(
  454. blocks, cut_direction, cut_coordinates, overall_region_box, mask_labels=[]
  455. ):
  456. """
  457. Cut blocks based on the given cut direction and coordinates.
  458. Args:
  459. blocks (list): list of blocks to be cut.
  460. cut_direction (str): cut direction, either "horizontal" or "vertical".
  461. cut_coordinates (list): list of cut coordinates.
  462. overall_region_box (list): the overall region box that contains all blocks.
  463. Returns:
  464. list: a list of tuples containing the cutted blocks and their corresponding mean width。
  465. """
  466. cuted_list = []
  467. # filter out mask blocks,including header, footer, unordered and child_blocks
  468. # 0: horizontal, 1: vertical
  469. cut_aixis = 0 if cut_direction == "horizontal" else 1
  470. blocks.sort(key=lambda x: x.bbox[cut_aixis + 2])
  471. cut_coordinates.append(float("inf"))
  472. cut_coordinates = list(set(cut_coordinates))
  473. cut_coordinates.sort()
  474. cut_idx = 0
  475. for cut_coordinate in cut_coordinates:
  476. group_blocks = []
  477. block_idx = cut_idx
  478. while block_idx < len(blocks):
  479. block = blocks[block_idx]
  480. if block.bbox[cut_aixis + 2] > cut_coordinate:
  481. break
  482. elif block.order_label not in mask_labels:
  483. group_blocks.append(block)
  484. block_idx += 1
  485. cut_idx = block_idx
  486. if group_blocks:
  487. cuted_list.append(group_blocks)
  488. return cuted_list
  489. def add_split_block(
  490. blocks: List[LayoutParsingBlock], region_bbox: List[int]
  491. ) -> List[LayoutParsingBlock]:
  492. block_bboxes = np.array([block.bbox for block in blocks])
  493. discontinuous = calculate_discontinuous_projection(
  494. block_bboxes, direction="vertical"
  495. )
  496. current_interval = discontinuous[0]
  497. for interval in discontinuous[1:]:
  498. gap_len = interval[0] - current_interval[1]
  499. if gap_len > 40:
  500. x1, _, x2, __ = region_bbox
  501. y1 = current_interval[1] + 5
  502. y2 = interval[0] - 5
  503. bbox = [x1, y1, x2, y2]
  504. split_block = LayoutParsingBlock(label="split", bbox=bbox)
  505. blocks.append(split_block)
  506. current_interval = interval
  507. def get_nearest_blocks(
  508. block: LayoutParsingBlock,
  509. ref_blocks: List[LayoutParsingBlock],
  510. overlap_threshold,
  511. direction="horizontal",
  512. ) -> List:
  513. """
  514. Get the adjacent blocks with the same direction as the current block.
  515. Args:
  516. block (LayoutParsingBlock): The current block.
  517. blocks (List[LayoutParsingBlock]): A list of all blocks.
  518. ref_block_idxes (List[int]): A list of indices of reference blocks.
  519. iou_threshold (float): The IOU threshold to determine if two blocks are considered adjacent.
  520. Returns:
  521. Int: The index of the previous block with same direction.
  522. Int: The index of the following block with same direction.
  523. """
  524. prev_blocks: List[LayoutParsingBlock] = []
  525. post_blocks: List[LayoutParsingBlock] = []
  526. sort_index = 1 if direction == "horizontal" else 0
  527. for ref_block in ref_blocks:
  528. if ref_block.index == block.index:
  529. continue
  530. overlap_ratio = calculate_projection_overlap_ratio(
  531. block.bbox, ref_block.bbox, direction, mode="small"
  532. )
  533. if overlap_ratio > overlap_threshold:
  534. if ref_block.bbox[sort_index] <= block.bbox[sort_index]:
  535. prev_blocks.append(ref_block)
  536. else:
  537. post_blocks.append(ref_block)
  538. if prev_blocks:
  539. prev_blocks.sort(key=lambda x: x.bbox[sort_index], reverse=True)
  540. if post_blocks:
  541. post_blocks.sort(key=lambda x: x.bbox[sort_index])
  542. return prev_blocks, post_blocks
def get_adjacent_blocks_by_direction(
    blocks: List[LayoutParsingBlock],
    block_idx: int,
    ref_block_idxes: List[int],
    iou_threshold,
) -> Tuple[Optional[int], Optional[int]]:
    """
    Get the adjacent blocks with the same direction as the current block.

    The current block is ``blocks[block_idx]``. Reference blocks whose
    ``order_label`` is one of the child labels listed below are skipped. The
    remaining ones are considered when their projection overlap ratio with the
    current block, computed in the reference block's own direction, is
    >= ``iou_threshold``.

    For a considered reference block, a "previous" and a "next" distance along
    the secondary direction are computed. If the reference ends before the
    current block starts (within a small tolerance) and improves the best
    previous distance, that distance is recorded. Its index becomes the
    previous candidate only if the gap is below a label-dependent tolerance.
    The "next" side is handled symmetrically.

    Finally, if the best previous and next distances differ by more than
    ``block.short_side_length / 5``, only the nearer side is kept.

    Args:
        blocks (List[LayoutParsingBlock]): A list of all blocks.
        block_idx (int): Index of the current block in ``blocks``.
        ref_block_idxes (List[int]): Indices (into ``blocks``) of the reference blocks.
        iou_threshold (float): Minimum projection overlap ratio for a reference block to be considered.

    Returns:
        Int or None: The index of the previous block with same direction, or None.
        Int or None: The index of the following block with same direction, or None.
    """
    min_prev_block_distance = float("inf")
    prev_block_index = None
    min_post_block_distance = float("inf")
    post_block_index = None
    block = blocks[block_idx]
    # Reference blocks with these order labels are never returned as neighbours.
    child_labels = [
        "vision_footnote",
        "sub_paragraph_title",
        "doc_title_text",
        "vision_title",
    ]

    # find the nearest text block with same direction to the current block
    for ref_block_idx in ref_block_idxes:
        ref_block = blocks[ref_block_idx]
        ref_block_direction = ref_block.direction
        if ref_block.order_label in child_labels:
            continue
        match_block_iou = calculate_projection_overlap_ratio(
            block.bbox,
            ref_block.bbox,
            ref_block_direction,
        )

        # Allowed overlap when deciding whether a reference lies before/after the block.
        child_match_distance_tolerance_len = block.short_side_length / 10

        # Maximum gap for a reference to be accepted as the neighbour; vision
        # blocks use a tighter limit unless the reference is a single line.
        if block.order_label == "vision":
            if ref_block.num_of_lines == 1:
                gap_tolerance_len = ref_block.short_side_length * 2
            else:
                gap_tolerance_len = block.short_side_length / 10
        else:
            gap_tolerance_len = block.short_side_length * 2

        if match_block_iou >= iou_threshold:
            # Gap floor-divided into steps of 5, plus a small start-coordinate
            # term (presumably a tie-breaker between equal steps).
            prev_distance = (
                block.secondary_direction_start_coordinate
                - ref_block.secondary_direction_end_coordinate
                + child_match_distance_tolerance_len
            ) // 5 + ref_block.start_coordinate / 5000
            next_distance = (
                ref_block.secondary_direction_start_coordinate
                - block.secondary_direction_end_coordinate
                + child_match_distance_tolerance_len
            ) // 5 + ref_block.start_coordinate / 5000
            # Reference ends before the block starts (within tolerance).
            if (
                ref_block.secondary_direction_end_coordinate
                <= block.secondary_direction_start_coordinate
                + child_match_distance_tolerance_len
                and prev_distance < min_prev_block_distance
            ):
                min_prev_block_distance = prev_distance
                if (
                    block.secondary_direction_start_coordinate
                    - ref_block.secondary_direction_end_coordinate
                    < gap_tolerance_len
                ):
                    prev_block_index = ref_block_idx
            # Reference starts after the block ends (within tolerance).
            elif (
                ref_block.secondary_direction_start_coordinate
                > block.secondary_direction_end_coordinate
                - child_match_distance_tolerance_len
                and next_distance < min_post_block_distance
            ):
                min_post_block_distance = next_distance
                if (
                    ref_block.secondary_direction_start_coordinate
                    - block.secondary_direction_end_coordinate
                    < gap_tolerance_len
                ):
                    post_block_index = ref_block_idx

    # If both distances are still inf this is nan, and the comparison below is False.
    diff_dist = abs(min_prev_block_distance - min_post_block_distance)

    # if the difference in distance is too large, only consider the nearest one
    if diff_dist * 5 > block.short_side_length:
        if min_prev_block_distance < min_post_block_distance:
            post_block_index = None
        else:
            prev_block_index = None

    return prev_block_index, post_block_index
  635. def update_doc_title_child_blocks(
  636. block: LayoutParsingBlock,
  637. region: LayoutParsingRegion,
  638. ) -> None:
  639. """
  640. Update the child blocks of a document title block.
  641. The child blocks need to meet the following conditions:
  642. 1. They must be adjacent
  643. 2. They must have the same direction as the parent block.
  644. 3. Their short side length should be less than 80% of the parent's short side length.
  645. 4. Their long side length should be less than 150% of the parent's long side length.
  646. 5. The child block must be text block.
  647. 6. The nearest edge distance should be less than 2 times of the text line height.
  648. Args:
  649. blocks (List[LayoutParsingBlock]): overall blocks.
  650. block (LayoutParsingBlock): document title block.
  651. prev_idx (int): previous block index, None if not exist.
  652. post_idx (int): post block index, None if not exist.
  653. config (dict): configurations.
  654. Returns:
  655. None
  656. """
  657. ref_blocks = [region.block_map[idx] for idx in region.normal_text_block_idxes]
  658. overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
  659. prev_blocks, post_blocks = get_nearest_blocks(
  660. block, ref_blocks, overlap_threshold, block.direction
  661. )
  662. prev_block = None
  663. post_block = None
  664. if prev_blocks:
  665. prev_block = prev_blocks[0]
  666. if post_blocks:
  667. post_block = post_blocks[0]
  668. for ref_block in [prev_block, post_block]:
  669. if ref_block is None:
  670. continue
  671. with_seem_direction = ref_block.direction == block.direction
  672. short_side_length_condition = (
  673. ref_block.short_side_length < block.short_side_length * 0.8
  674. )
  675. long_side_length_condition = (
  676. ref_block.long_side_length < block.long_side_length
  677. or ref_block.long_side_length > 1.5 * block.long_side_length
  678. )
  679. nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox)
  680. if (
  681. with_seem_direction
  682. and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
  683. and short_side_length_condition
  684. and long_side_length_condition
  685. and ref_block.num_of_lines < 3
  686. and nearest_edge_distance < ref_block.text_line_height * 2
  687. ):
  688. ref_block.order_label = "doc_title_text"
  689. block.append_child_block(ref_block)
  690. region.normal_text_block_idxes.remove(ref_block.index)
  691. def update_paragraph_title_child_blocks(
  692. block: LayoutParsingBlock,
  693. region: LayoutParsingRegion,
  694. ) -> None:
  695. """
  696. Update the child blocks of a paragraph title block.
  697. The child blocks need to meet the following conditions:
  698. 1. They must be adjacent
  699. 2. They must have the same direction as the parent block.
  700. 3. The child block must be paragraph title block.
  701. Args:
  702. blocks (List[LayoutParsingBlock]): overall blocks.
  703. block (LayoutParsingBlock): document title block.
  704. prev_idx (int): previous block index, None if not exist.
  705. post_idx (int): post block index, None if not exist.
  706. config (dict): configurations.
  707. Returns:
  708. None
  709. """
  710. if block.order_label == "sub_paragraph_title":
  711. return
  712. ref_blocks = [
  713. region.block_map[idx]
  714. for idx in region.paragraph_title_block_idxes + region.normal_text_block_idxes
  715. ]
  716. overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
  717. prev_blocks, post_blocks = get_nearest_blocks(
  718. block, ref_blocks, overlap_threshold, block.direction
  719. )
  720. for ref_blocks in [prev_blocks, post_blocks]:
  721. for ref_block in ref_blocks:
  722. if ref_block.label not in BLOCK_LABEL_MAP["paragraph_title_labels"]:
  723. break
  724. min_text_line_height = min(
  725. block.text_line_height, ref_block.text_line_height
  726. )
  727. nearest_edge_distance = get_nearest_edge_distance(
  728. block.bbox, ref_block.bbox
  729. )
  730. with_seem_direction = ref_block.direction == block.direction
  731. if (
  732. with_seem_direction
  733. and nearest_edge_distance <= min_text_line_height * 1.5
  734. ):
  735. ref_block.order_label = "sub_paragraph_title"
  736. block.append_child_block(ref_block)
  737. region.paragraph_title_block_idxes.remove(ref_block.index)
def update_vision_child_blocks(
    block: LayoutParsingBlock,
    region: LayoutParsingRegion,
) -> None:
    """
    Attach vision titles and at most one vision footnote to a vision block.

    Candidates come from ``region.normal_text_block_idxes`` and
    ``region.vision_title_block_idxes``. ``get_nearest_blocks`` is queried
    first along ``block.direction``, then along ``block.secondary_direction``.
    The second query is skipped once a vision title has been attached.

    Vision titles:
        A candidate labelled in ``BLOCK_LABEL_MAP["vision_title_labels"]`` is
        attached when its nearest edge distance is at most twice its own text
        line height. Several titles may be attached. Each gets
        ``order_label = "vision_title"`` and is removed from
        ``region.vision_title_block_idxes``.

    Vision footnote (text-labelled candidate; at most one per call):
        Base requirements in both lists: same ``direction`` as ``block`` and a
        shorter long side than ``block``.

        From the "previous" list it is accepted if either:
          (a) all of these hold:
              - nearest edge distance <= 2x ``block.text_line_height``;
              - narrower short side than ``block``;
              - long side < 0.5x ``block``'s long side;
              - first centroid coordinates differ by < 10; or
          (b) it is a single-line block passing one of the two edge checks
              (``bbox[0]`` / ``bbox[2]`` differences < 10).

        From the "post" list it must satisfy all of the (a) size and distance
        conditions, and additionally either the centroid check or one of the
        single-line edge checks.

        The accepted footnote gets ``order_label = "vision_footnote"`` and is
        removed from ``region.normal_text_block_idxes``.

    Scanning:
        Each list is scanned in order and stops at the first text candidate,
        whether or not it was accepted. The "previous" scan also stops at any
        candidate that is neither text nor a vision title. Once a footnote has
        been attached, the "post" scan stops at the first text candidate
        without examining it.

    Args:
        block (LayoutParsingBlock): the vision block; mutated via
            ``append_child_block``.
        region (LayoutParsingRegion): region owning the candidate blocks; its
            ``normal_text_block_idxes`` and ``vision_title_block_idxes`` lists
            are mutated.

    Returns:
        None
    """
    ref_blocks = [
        region.block_map[idx]
        for idx in region.normal_text_block_idxes + region.vision_title_block_idxes
    ]
    overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
    # Shared across both directions: at most one footnote is attached per call.
    has_vision_footnote = False
    has_vision_title = False
    # The secondary direction is only searched when the primary direction
    # produced no vision title (see the break at the end of this loop).
    for direction in [block.direction, block.secondary_direction]:
        prev_blocks, post_blocks = get_nearest_blocks(
            block, ref_blocks, overlap_threshold, direction
        )
        # Candidates in the "previous" list.
        for ref_block in prev_blocks:
            # Stop at the first candidate that is neither text nor a vision title.
            if (
                ref_block.label
                not in BLOCK_LABEL_MAP["text_labels"]
                + BLOCK_LABEL_MAP["vision_title_labels"]
            ):
                break
            nearest_edge_distance = get_nearest_edge_distance(
                block.bbox, ref_block.bbox
            )
            block_center = block.get_centroid()
            ref_block_center = ref_block.get_centroid()
            if (
                ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
                and nearest_edge_distance <= ref_block.text_line_height * 2
            ):
                has_vision_title = True
                ref_block.order_label = "vision_title"
                block.append_child_block(ref_block)
                region.vision_title_block_idxes.remove(ref_block.index)
            if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
                if (
                    not has_vision_footnote
                    and ref_block.direction == block.direction
                    and ref_block.long_side_length < block.long_side_length
                ):
                    # NOTE(review): the two single-line clauses below compare
                    # bbox[0] / bbox[2] without abs(). They are therefore true
                    # whenever the candidate's edge lies right of the block's
                    # edge minus 10, not only when the edges are aligned.
                    # Confirm whether abs() was intended.
                    if (
                        (
                            nearest_edge_distance <= block.text_line_height * 2
                            and ref_block.short_side_length < block.short_side_length
                            and ref_block.long_side_length
                            < 0.5 * block.long_side_length
                            and abs(block_center[0] - ref_block_center[0]) < 10
                        )
                        or (
                            block.bbox[0] - ref_block.bbox[0] < 10
                            and ref_block.num_of_lines == 1
                        )
                        or (
                            block.bbox[2] - ref_block.bbox[2] < 10
                            and ref_block.num_of_lines == 1
                        )
                    ):
                        has_vision_footnote = True
                        ref_block.order_label = "vision_footnote"
                        block.append_child_block(ref_block)
                        region.normal_text_block_idxes.remove(ref_block.index)
                # Only the nearest text candidate is examined; stop scanning
                # this list whether or not it was accepted.
                break
        # Candidates in the "post" list.
        for ref_block in post_blocks:
            # A footnote was already attached: no further text can qualify.
            if (
                has_vision_footnote
                and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
            ):
                break
            nearest_edge_distance = get_nearest_edge_distance(
                block.bbox, ref_block.bbox
            )
            block_center = block.get_centroid()
            ref_block_center = ref_block.get_centroid()
            if (
                ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
                and nearest_edge_distance <= ref_block.text_line_height * 2
            ):
                has_vision_title = True
                ref_block.order_label = "vision_title"
                block.append_child_block(ref_block)
                region.vision_title_block_idxes.remove(ref_block.index)
            if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
                # NOTE(review): same missing-abs() concern as in the
                # "previous" scan for the bbox[0] / bbox[2] clauses.
                if (
                    not has_vision_footnote
                    and nearest_edge_distance <= block.text_line_height * 2
                    and ref_block.short_side_length < block.short_side_length
                    and ref_block.long_side_length < 0.5 * block.long_side_length
                    and ref_block.direction == block.direction
                    and (
                        abs(block_center[0] - ref_block_center[0]) < 10
                        or (
                            block.bbox[0] - ref_block.bbox[0] < 10
                            and ref_block.num_of_lines == 1
                        )
                        or (
                            block.bbox[2] - ref_block.bbox[2] < 10
                            and ref_block.num_of_lines == 1
                        )
                    )
                ):
                    has_vision_footnote = True
                    ref_block.order_label = "vision_footnote"
                    block.append_child_block(ref_block)
                    region.normal_text_block_idxes.remove(ref_block.index)
                # Only the nearest text candidate is examined.
                break
        # A title found in this direction: skip the secondary direction.
        if has_vision_title:
            break
  869. def calculate_discontinuous_projection(
  870. boxes, direction="horizontal", return_num=False
  871. ) -> List:
  872. """
  873. Calculate the discontinuous projection of boxes along the specified direction.
  874. Args:
  875. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  876. direction (str): direction along which to perform the projection ('horizontal' or 'vertical').
  877. Returns:
  878. list: List of tuples representing the merged intervals.
  879. """
  880. boxes = np.array(boxes)
  881. if direction == "horizontal":
  882. intervals = boxes[:, [0, 2]]
  883. elif direction == "vertical":
  884. intervals = boxes[:, [1, 3]]
  885. else:
  886. raise ValueError("direction must be 'horizontal' or 'vertical'")
  887. intervals = intervals[np.argsort(intervals[:, 0])]
  888. merged_intervals = []
  889. num = 1
  890. current_start, current_end = intervals[0]
  891. num_list = []
  892. for start, end in intervals[1:]:
  893. if start <= current_end:
  894. num += 1
  895. current_end = max(current_end, end)
  896. else:
  897. num_list.append(num)
  898. merged_intervals.append((current_start, current_end))
  899. num = 1
  900. current_start, current_end = start, end
  901. num_list.append(num)
  902. merged_intervals.append((current_start, current_end))
  903. if return_num:
  904. return merged_intervals, num_list
  905. return merged_intervals
  906. def is_projection_consistent(blocks, intervals, direction="horizontal"):
  907. for interval in intervals:
  908. if direction == "horizontal":
  909. start_index, stop_index = 0, 2
  910. interval_box = [interval[0], 0, interval[1], 1]
  911. else:
  912. start_index, stop_index = 1, 3
  913. interval_box = [0, interval[0], 1, interval[1]]
  914. same_interval_bboxes = []
  915. for block in blocks:
  916. overlap_ratio = calculate_projection_overlap_ratio(
  917. interval_box, block.bbox, direction=direction
  918. )
  919. if overlap_ratio > 0 and block.label in BLOCK_LABEL_MAP["text_labels"]:
  920. same_interval_bboxes.append(block.bbox)
  921. start_coordinates = [bbox[start_index] for bbox in same_interval_bboxes]
  922. if start_coordinates:
  923. min_start_coordinate = min(start_coordinates)
  924. max_start_coordinate = max(start_coordinates)
  925. is_start_consistent = (
  926. False
  927. if max_start_coordinate - min_start_coordinate
  928. >= abs(interval[0] - interval[1]) * 0.05
  929. else True
  930. )
  931. stop_coordinates = [bbox[stop_index] for bbox in same_interval_bboxes]
  932. min_stop_coordinate = min(stop_coordinates)
  933. max_stop_coordinate = max(stop_coordinates)
  934. if (
  935. max_stop_coordinate - min_stop_coordinate
  936. >= abs(interval[0] - interval[1]) * 0.05
  937. and is_start_consistent
  938. ):
  939. return False
  940. return True
  941. def shrink_overlapping_boxes(
  942. boxes, direction="horizontal", min_threshold=0, max_threshold=0.1
  943. ) -> List:
  944. """
  945. Shrink overlapping boxes along the specified direction.
  946. Args:
  947. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  948. direction (str): direction along which to perform the shrinking ('horizontal' or 'vertical').
  949. min_threshold (float): Minimum threshold for shrinking. Default is 0.
  950. max_threshold (float): Maximum threshold for shrinking. Default is 0.2.
  951. Returns:
  952. list: List of tuples representing the merged intervals.
  953. """
  954. current_block = boxes[0]
  955. for block in boxes[1:]:
  956. x1, y1, x2, y2 = current_block.bbox
  957. x1_prime, y1_prime, x2_prime, y2_prime = block.bbox
  958. cut_iou = calculate_projection_overlap_ratio(
  959. current_block.bbox, block.bbox, direction=direction
  960. )
  961. match_iou = calculate_projection_overlap_ratio(
  962. current_block.bbox,
  963. block.bbox,
  964. direction="horizontal" if direction == "vertical" else "vertical",
  965. )
  966. if direction == "vertical":
  967. if (
  968. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  969. or y2 == y1_prime
  970. or abs(y2 - y1_prime) <= 3
  971. ):
  972. overlap_y_min = max(y1, y1_prime)
  973. overlap_y_max = min(y2, y2_prime)
  974. split_y = int((overlap_y_min + overlap_y_max) / 2)
  975. overlap_y_min = split_y - 1
  976. overlap_y_max = split_y + 1
  977. current_block.bbox = [x1, y1, x2, overlap_y_min]
  978. block.bbox = [x1_prime, overlap_y_max, x2_prime, y2_prime]
  979. else:
  980. if (
  981. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  982. or x2 == x1_prime
  983. or abs(x2 - x1_prime) <= 3
  984. ):
  985. overlap_x_min = max(x1, x1_prime)
  986. overlap_x_max = min(x2, x2_prime)
  987. split_x = int((overlap_x_min + overlap_x_max) / 2)
  988. overlap_x_min = split_x - 1
  989. overlap_x_max = split_x + 1
  990. current_block.bbox = [x1, y1, overlap_x_min, y2]
  991. block.bbox = [overlap_x_max, y1_prime, x2_prime, y2_prime]
  992. current_block = block
  993. return boxes