utils.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Tuple
  15. import numpy as np
  16. from ..layout_objects import LayoutBlock, LayoutRegion
  17. from ..setting import BLOCK_LABEL_MAP, XYCUT_SETTINGS
  18. from ..utils import (
  19. calculate_overlap_ratio,
  20. calculate_projection_overlap_ratio,
  21. get_seg_flag,
  22. )
  23. def get_nearest_edge_distance(
  24. bbox1: List[int],
  25. bbox2: List[int],
  26. weight: List[float] = [1.0, 1.0, 1.0, 1.0],
  27. ) -> Tuple[float]:
  28. """
  29. Calculate the nearest edge distance between two bounding boxes, considering directional weights.
  30. Args:
  31. bbox1 (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  32. bbox2 (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  33. weight (list, optional): directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
  34. Returns:
  35. float: The calculated minimum edge distance between the bounding boxes.
  36. """
  37. x1, y1, x2, y2 = bbox1
  38. x1_prime, y1_prime, x2_prime, y2_prime = bbox2
  39. min_x_distance, min_y_distance = 0, 0
  40. horizontal_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "horizontal")
  41. vertical_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "vertical")
  42. if horizontal_iou > 0 and vertical_iou > 0:
  43. return 0.0
  44. if horizontal_iou == 0:
  45. min_x_distance = min(abs(x1 - x2_prime), abs(x2 - x1_prime)) * (
  46. weight[0] if x2 < x1_prime else weight[1]
  47. )
  48. if vertical_iou == 0:
  49. min_y_distance = min(abs(y1 - y2_prime), abs(y2 - y1_prime)) * (
  50. weight[2] if y2 < y1_prime else weight[3]
  51. )
  52. return min_x_distance + min_y_distance
  53. def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
  54. """
  55. Generate a 1D projection histogram from bounding boxes along a specified axis.
  56. Args:
  57. boxes: A (N, 4) array of bounding boxes defined by [x_min, y_min, x_max, y_max].
  58. axis: Axis for projection; 0 for horizontal (x-axis), 1 for vertical (y-axis).
  59. Returns:
  60. A 1D numpy array representing the projection histogram based on bounding box intervals.
  61. """
  62. assert axis in [0, 1]
  63. if np.min(boxes[:, axis::2]) < 0:
  64. max_length = abs(np.min(boxes[:, axis::2]))
  65. else:
  66. max_length = np.max(boxes[:, axis::2])
  67. projection = np.zeros(max_length, dtype=int)
  68. # Increment projection histogram over the interval defined by each bounding box
  69. for start, end in boxes[:, axis::2]:
  70. start = abs(start)
  71. end = abs(end)
  72. projection[start:end] += 1
  73. return projection
  74. def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
  75. """
  76. Split the projection profile into segments based on specified thresholds.
  77. Args:
  78. arr_values: 1D array representing the projection profile.
  79. min_value: Minimum value threshold to consider a profile segment significant.
  80. min_gap: Minimum gap width to consider a separation between segments.
  81. Returns:
  82. A tuple of start and end indices for each segment that meets the criteria.
  83. """
  84. # Identify indices where the projection exceeds the minimum value
  85. significant_indices = np.where(arr_values > min_value)[0]
  86. if not len(significant_indices):
  87. return
  88. # Calculate gaps between significant indices
  89. index_diffs = significant_indices[1:] - significant_indices[:-1]
  90. gap_indices = np.where(index_diffs > min_gap)[0]
  91. # Determine start and end indices of segments
  92. segment_starts = np.insert(
  93. significant_indices[gap_indices + 1],
  94. 0,
  95. significant_indices[0],
  96. )
  97. segment_ends = np.append(
  98. significant_indices[gap_indices],
  99. significant_indices[-1] + 1,
  100. )
  101. return segment_starts, segment_ends
  102. def recursive_yx_cut(
  103. boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  104. ):
  105. """
  106. Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.
  107. Args:
  108. boxes: A (N, 4) array representing bounding boxes.
  109. indices: List of indices indicating the original position of boxes.
  110. res: List to store indices of the final segmented bounding boxes.
  111. min_gap (int): Minimum gap width to consider a separation between segments on the X-axis. Defaults to 1.
  112. Returns:
  113. None: This function modifies the `res` list in place.
  114. """
  115. assert len(boxes) == len(
  116. indices
  117. ), "The length of boxes and indices must be the same."
  118. # Sort by y_min for Y-axis projection
  119. y_sorted_indices = boxes[:, 1].argsort()
  120. y_sorted_boxes = boxes[y_sorted_indices]
  121. y_sorted_indices = np.array(indices)[y_sorted_indices]
  122. # Perform Y-axis projection
  123. y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
  124. y_intervals = split_projection_profile(y_projection, 0, 1)
  125. if not y_intervals:
  126. return
  127. # Process each segment defined by Y-axis projection
  128. for y_start, y_end in zip(*y_intervals):
  129. # Select boxes within the current y interval
  130. y_interval_indices = (y_start <= y_sorted_boxes[:, 1]) & (
  131. y_sorted_boxes[:, 1] < y_end
  132. )
  133. y_boxes_chunk = y_sorted_boxes[y_interval_indices]
  134. y_indices_chunk = y_sorted_indices[y_interval_indices]
  135. # Sort by x_min for X-axis projection
  136. x_sorted_indices = y_boxes_chunk[:, 0].argsort()
  137. x_sorted_boxes_chunk = y_boxes_chunk[x_sorted_indices]
  138. x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]
  139. # Perform X-axis projection
  140. x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
  141. x_intervals = split_projection_profile(x_projection, 0, min_gap)
  142. if not x_intervals:
  143. continue
  144. # If X-axis cannot be further segmented, add current indices to results
  145. if len(x_intervals[0]) == 1:
  146. res.extend(x_sorted_indices_chunk)
  147. continue
  148. if np.min(x_sorted_boxes_chunk[:, 0]) < 0:
  149. x_intervals = np.flip(x_intervals, axis=1)
  150. # Recursively process each segment defined by X-axis projection
  151. for x_start, x_end in zip(*x_intervals):
  152. x_interval_indices = (x_start <= abs(x_sorted_boxes_chunk[:, 0])) & (
  153. abs(x_sorted_boxes_chunk[:, 0]) < x_end
  154. )
  155. recursive_yx_cut(
  156. x_sorted_boxes_chunk[x_interval_indices],
  157. x_sorted_indices_chunk[x_interval_indices],
  158. res,
  159. )
  160. def recursive_xy_cut(
  161. boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  162. ):
  163. """
  164. Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.
  165. Args:
  166. boxes: A (N, 4) array representing bounding boxes with [x_min, y_min, x_max, y_max].
  167. indices: A list of indices representing the position of boxes in the original data.
  168. res: A list to store indices of bounding boxes that meet the criteria.
  169. min_gap (int): Minimum gap width to consider a separation between segments on the X-axis. Defaults to 1.
  170. Returns:
  171. None: This function modifies the `res` list in place.
  172. """
  173. # Ensure boxes and indices have the same length
  174. assert len(boxes) == len(
  175. indices
  176. ), "The length of boxes and indices must be the same."
  177. # Sort by x_min to prepare for X-axis projection
  178. x_sorted_indices = boxes[:, 0].argsort()
  179. x_sorted_boxes = boxes[x_sorted_indices]
  180. x_sorted_indices = np.array(indices)[x_sorted_indices]
  181. # Perform X-axis projection
  182. x_projection = projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
  183. x_intervals = split_projection_profile(x_projection, 0, 1)
  184. if not x_intervals:
  185. return
  186. if np.min(x_sorted_boxes[:, 0]) < 0:
  187. x_intervals = np.flip(x_intervals, axis=1)
  188. # Process each segment defined by X-axis projection
  189. for x_start, x_end in zip(*x_intervals):
  190. # Select boxes within the current x interval
  191. x_interval_indices = (x_start <= abs(x_sorted_boxes[:, 0])) & (
  192. abs(x_sorted_boxes[:, 0]) < x_end
  193. )
  194. x_boxes_chunk = x_sorted_boxes[x_interval_indices]
  195. x_indices_chunk = x_sorted_indices[x_interval_indices]
  196. # Sort selected boxes by y_min to prepare for Y-axis projection
  197. y_sorted_indices = x_boxes_chunk[:, 1].argsort()
  198. y_sorted_boxes_chunk = x_boxes_chunk[y_sorted_indices]
  199. y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]
  200. # Perform Y-axis projection
  201. y_projection = projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
  202. y_intervals = split_projection_profile(y_projection, 0, min_gap)
  203. if not y_intervals:
  204. continue
  205. # If Y-axis cannot be further segmented, add current indices to results
  206. if len(y_intervals[0]) == 1:
  207. res.extend(y_sorted_indices_chunk)
  208. continue
  209. # Recursively process each segment defined by Y-axis projection
  210. for y_start, y_end in zip(*y_intervals):
  211. y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
  212. y_sorted_boxes_chunk[:, 1] < y_end
  213. )
  214. recursive_xy_cut(
  215. y_sorted_boxes_chunk[y_interval_indices],
  216. y_sorted_indices_chunk[y_interval_indices],
  217. res,
  218. )
  219. def reference_insert(
  220. block: LayoutBlock,
  221. sorted_blocks: List[LayoutBlock],
  222. **kwargs,
  223. ):
  224. """
  225. Insert reference block into sorted blocks based on the distance between the block and the nearest sorted block.
  226. Args:
  227. block: The block to insert into the sorted blocks.
  228. sorted_blocks: The sorted blocks where the new block will be inserted.
  229. config: Configuration dictionary containing parameters related to the layout parsing.
  230. median_width: Median width of the document. Defaults to 0.0.
  231. Returns:
  232. sorted_blocks: The updated sorted blocks after insertion.
  233. """
  234. min_distance = float("inf")
  235. nearest_sorted_block_index = 0
  236. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  237. if sorted_block.bbox[3] <= block.bbox[1]:
  238. distance = -(sorted_block.bbox[2] * 10 + sorted_block.bbox[3])
  239. if distance < min_distance:
  240. min_distance = distance
  241. nearest_sorted_block_index = sorted_block_idx
  242. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  243. return sorted_blocks
  244. def manhattan_insert(
  245. block: LayoutBlock,
  246. sorted_blocks: List[LayoutBlock],
  247. **kwargs,
  248. ):
  249. """
  250. Insert a block into a sorted list of blocks based on the Manhattan distance between the block and the nearest sorted block.
  251. Args:
  252. block: The block to insert into the sorted blocks.
  253. sorted_blocks: The sorted blocks where the new block will be inserted.
  254. config: Configuration dictionary containing parameters related to the layout parsing.
  255. median_width: Median width of the document. Defaults to 0.0.
  256. Returns:
  257. sorted_blocks: The updated sorted blocks after insertion.
  258. """
  259. min_distance = float("inf")
  260. nearest_sorted_block_index = 0
  261. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  262. distance = _manhattan_distance(block.bbox, sorted_block.bbox)
  263. if distance < min_distance:
  264. min_distance = distance
  265. nearest_sorted_block_index = sorted_block_idx
  266. sorted_blocks.insert(nearest_sorted_block_index + 1, block)
  267. return sorted_blocks
  268. def euclidean_insert(
  269. block: LayoutRegion,
  270. sorted_blocks: List[LayoutRegion],
  271. **kwargs,
  272. ):
  273. """
  274. Insert a block into a sorted list of blocks based on the Euclidean distance between the block and the nearest sorted block.
  275. Args:
  276. block: The block to insert into the sorted blocks.
  277. sorted_blocks: The sorted blocks where the new block will be inserted.
  278. config: Configuration dictionary containing parameters related to the layout parsing.
  279. median_width: Median width of the document. Defaults to 0.0.
  280. Returns:
  281. sorted_blocks: The updated sorted blocks after insertion.
  282. """
  283. nearest_sorted_block_index = len(sorted_blocks)
  284. block_euclidean_distance = block.euclidean_distance
  285. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  286. distance = sorted_block.euclidean_distance
  287. if distance > block_euclidean_distance:
  288. nearest_sorted_block_index = sorted_block_idx
  289. break
  290. sorted_blocks.insert(nearest_sorted_block_index, block)
  291. return sorted_blocks
  292. def weighted_distance_insert(
  293. block: LayoutBlock,
  294. sorted_blocks: List[LayoutBlock],
  295. region: LayoutRegion,
  296. ):
  297. """
  298. Insert a block into a sorted list of blocks based on the weighted distance between the block and the nearest sorted block.
  299. Args:
  300. block: The block to insert into the sorted blocks.
  301. sorted_blocks: The sorted blocks where the new block will be inserted.
  302. config: Configuration dictionary containing parameters related to the layout parsing.
  303. median_width: Median width of the document. Defaults to 0.0.
  304. Returns:
  305. sorted_blocks: The updated sorted blocks after insertion.
  306. """
  307. tolerance_len = XYCUT_SETTINGS["edge_distance_compare_tolerance_len"]
  308. x1, y1, x2, y2 = block.bbox
  309. min_weighted_distance, min_edge_distance, min_up_edge_distance = (
  310. float("inf"),
  311. float("inf"),
  312. float("inf"),
  313. )
  314. nearest_sorted_block_index = 0
  315. for sorted_block_idx, sorted_block in enumerate(sorted_blocks):
  316. x1_prime, y1_prime, x2_prime, y2_prime = sorted_block.bbox
  317. # Calculate edge distance
  318. weight = _get_weights(block.order_label, block.direction)
  319. edge_distance = get_nearest_edge_distance(block.bbox, sorted_block.bbox, weight)
  320. if block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
  321. disperse = max(1, region.text_line_width)
  322. tolerance_len = max(tolerance_len, disperse)
  323. if block.label == "abstract":
  324. tolerance_len *= 2
  325. edge_distance = max(0.1, edge_distance) * 10
  326. # Calculate up edge distances
  327. up_edge_distance = y1_prime if region.direction == "horizontal" else -x2_prime
  328. left_edge_distance = x1_prime if region.direction == "horizontal" else y1_prime
  329. is_below_sorted_block = (
  330. y2_prime < y1 if region.direction == "horizontal" else x1_prime > x2
  331. )
  332. if (
  333. block.label not in BLOCK_LABEL_MAP["unordered_labels"]
  334. or block.label in BLOCK_LABEL_MAP["doc_title_labels"]
  335. or block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]
  336. or block.label in BLOCK_LABEL_MAP["vision_labels"]
  337. ) and is_below_sorted_block:
  338. up_edge_distance = -up_edge_distance
  339. left_edge_distance = -left_edge_distance
  340. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  341. up_edge_distance = min_up_edge_distance
  342. # Calculate weighted distance
  343. weighted_distance = (
  344. +edge_distance
  345. * XYCUT_SETTINGS["distance_weight_map"].get("edge_weight", 10**4)
  346. + up_edge_distance
  347. * XYCUT_SETTINGS["distance_weight_map"].get("up_edge_weight", 1)
  348. + left_edge_distance
  349. * XYCUT_SETTINGS["distance_weight_map"].get("left_edge_weight", 0.0001)
  350. )
  351. min_edge_distance = min(edge_distance, min_edge_distance)
  352. min_up_edge_distance = min(up_edge_distance, min_up_edge_distance)
  353. if weighted_distance < min_weighted_distance:
  354. nearest_sorted_block_index = sorted_block_idx
  355. min_weighted_distance = weighted_distance
  356. if abs(y1 // 2 - y1_prime // 2) > 0:
  357. sorted_distance = y1_prime
  358. block_distance = y1
  359. else:
  360. if region.direction == "horizontal":
  361. if abs(x1 // 2 - x2 // 2) > 0:
  362. sorted_distance = x1_prime
  363. block_distance = x1
  364. else:
  365. # distance with (0,0)
  366. sorted_block_center_x, sorted_block_center_y = (
  367. sorted_block.get_centroid()
  368. )
  369. block_center_x, block_center_y = block.get_centroid()
  370. sorted_distance = (
  371. sorted_block_center_x**2 + sorted_block_center_y**2
  372. )
  373. block_distance = block_center_x**2 + block_center_y**2
  374. else:
  375. if abs(x1 - x2) > 0:
  376. sorted_distance = -x2_prime
  377. block_distance = -x2
  378. else:
  379. # distance with (max,0)
  380. sorted_block_center_x, sorted_block_center_y = (
  381. sorted_block.get_centroid()
  382. )
  383. block_center_x, block_center_y = block.get_centroid()
  384. sorted_distance = (
  385. sorted_block_center_x**2 + sorted_block_center_y**2
  386. )
  387. block_distance = block_center_x**2 + block_center_y**2
  388. if block_distance > sorted_distance:
  389. nearest_sorted_block_index = sorted_block_idx + 1
  390. if (
  391. sorted_block_idx < len(sorted_blocks) - 1
  392. and block.label
  393. in BLOCK_LABEL_MAP["vision_labels"]
  394. + BLOCK_LABEL_MAP["vision_title_labels"]
  395. ):
  396. seg_start_flag, _ = get_seg_flag(
  397. sorted_blocks[sorted_block_idx + 1],
  398. sorted_blocks[sorted_block_idx],
  399. )
  400. if not seg_start_flag:
  401. nearest_sorted_block_index += 1
  402. else:
  403. if (
  404. sorted_block_idx > 0
  405. and block.label
  406. in BLOCK_LABEL_MAP["vision_labels"]
  407. + BLOCK_LABEL_MAP["vision_title_labels"]
  408. ):
  409. seg_start_flag, _ = get_seg_flag(
  410. sorted_blocks[sorted_block_idx],
  411. sorted_blocks[sorted_block_idx - 1],
  412. )
  413. if not seg_start_flag:
  414. nearest_sorted_block_index = sorted_block_idx - 1
  415. sorted_blocks.insert(nearest_sorted_block_index, block)
  416. return sorted_blocks
  417. def insert_child_blocks(
  418. block: LayoutBlock,
  419. block_idx: int,
  420. sorted_blocks: List[LayoutBlock],
  421. ) -> List[LayoutBlock]:
  422. """
  423. Insert child blocks of a block into the sorted blocks list.
  424. Args:
  425. block: The parent block whose child blocks need to be inserted.
  426. block_idx: Index at which the parent block exists in the sorted blocks list.
  427. sorted_blocks: Sorted blocks list where the child blocks are to be inserted.
  428. Returns:
  429. sorted_blocks: Updated sorted blocks list after inserting child blocks.
  430. """
  431. if block.child_blocks:
  432. sub_blocks = block.get_child_blocks()
  433. sub_blocks.append(block)
  434. sub_blocks = sort_child_blocks(sub_blocks, sub_blocks[0].direction)
  435. sorted_blocks[block_idx] = sub_blocks[0]
  436. for block in sub_blocks[1:]:
  437. block_idx += 1
  438. sorted_blocks.insert(block_idx, block)
  439. return sorted_blocks
  440. def sort_child_blocks(
  441. blocks: List[LayoutRegion], direction="horizontal"
  442. ) -> List[LayoutBlock]:
  443. """
  444. Sort child blocks based on their bounding box coordinates.
  445. Args:
  446. blocks: A list of LayoutBlock objects representing the child blocks.
  447. direction: direction of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'.
  448. Returns:
  449. sorted_blocks: A sorted list of LayoutBlock objects.
  450. """
  451. if blocks[0].label != "region":
  452. if direction == "horizontal":
  453. blocks.sort(
  454. key=lambda x: (
  455. x.bbox[1],
  456. x.bbox[0],
  457. x.get_centroid()[0] ** 2 + x.get_centroid()[1] ** 2,
  458. ), # distance with (0,0)
  459. )
  460. else:
  461. blocks.sort(
  462. key=lambda x: (
  463. -x.bbox[2],
  464. x.bbox[1],
  465. -x.get_centroid()[0] ** 2 + x.get_centroid()[1] ** 2,
  466. ), # distance with (max,0)
  467. )
  468. else:
  469. blocks.sort(key=lambda x: x.euclidean_distance)
  470. return blocks
  471. def _get_weights(label, direction="horizontal"):
  472. """Define weights based on the label and direction."""
  473. if label == "doc_title":
  474. return (
  475. [1, 0.1, 0.1, 1] if direction == "horizontal" else [0.2, 0.1, 1, 1]
  476. ) # left-down , right-left
  477. elif label in [
  478. "paragraph_title",
  479. "table_title",
  480. "abstract",
  481. "image",
  482. "seal",
  483. "chart",
  484. "figure",
  485. ]:
  486. return [1, 1, 0.1, 1] # down
  487. else:
  488. return [1, 1, 1, 0.1] # up
  489. def _manhattan_distance(
  490. point1: Tuple[float, float],
  491. point2: Tuple[float, float],
  492. weight_x: float = 1.0,
  493. weight_y: float = 1.0,
  494. ) -> float:
  495. """
  496. Calculate the weighted Manhattan distance between two points.
  497. Args:
  498. point1 (Tuple[float, float]): The first point as (x, y).
  499. point2 (Tuple[float, float]): The second point as (x, y).
  500. weight_x (float): The weight for the x-axis distance. Default is 1.0.
  501. weight_y (float): The weight for the y-axis distance. Default is 1.0.
  502. Returns:
  503. float: The weighted Manhattan distance between the two points.
  504. """
  505. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
  506. def sort_normal_blocks(
  507. blocks, text_line_height, text_line_width, region_direction
  508. ) -> List[LayoutBlock]:
  509. """Sort blocks by their position within the page
  510. Args:
  511. blocks (List[LayoutBlock]): List of blocks to be sorted.
  512. text_line_height (int): Height of each line of text.
  513. text_line_width (int): Width of each line of text.
  514. region_direction (str): Direction of the region, either "horizontal" or "vertical".
  515. Returns:
  516. List[LayoutBlock]: Sorted list of blocks.
  517. """
  518. if region_direction == "horizontal":
  519. blocks.sort(
  520. key=lambda x: (
  521. x.bbox[1] // text_line_height,
  522. x.bbox[0] // text_line_width,
  523. x.get_centroid()[0] ** 2 + x.get_centroid()[1] ** 2,
  524. ),
  525. )
  526. else:
  527. blocks.sort(
  528. key=lambda x: (
  529. -x.bbox[2] // text_line_width,
  530. x.bbox[1] // text_line_height,
  531. -x.get_centroid()[0] ** 2 + x.get_centroid()[1] ** 2,
  532. ),
  533. )
  534. return blocks
  535. def get_cut_blocks(blocks, cut_direction, cut_coordinates, mask_labels=[]):
  536. """
  537. Cut blocks based on the given cut direction and coordinates.
  538. Args:
  539. blocks (list): list of blocks to be cut.
  540. cut_direction (str): cut direction, either "horizontal" or "vertical".
  541. cut_coordinates (list): list of cut coordinates.
  542. Returns:
  543. list: a list of tuples containing the cutted blocks and their corresponding mean width。
  544. """
  545. cuted_list = []
  546. # filter out mask blocks,including header, footer, unordered and child_blocks
  547. # 0: horizontal, 1: vertical
  548. cut_aixis = 0 if cut_direction == "horizontal" else 1
  549. blocks.sort(key=lambda x: x.bbox[cut_aixis + 2])
  550. cut_coordinates.append(float("inf"))
  551. cut_coordinates = list(set(cut_coordinates))
  552. cut_coordinates.sort()
  553. cut_idx = 0
  554. for cut_coordinate in cut_coordinates:
  555. group_blocks = []
  556. block_idx = cut_idx
  557. while block_idx < len(blocks):
  558. block = blocks[block_idx]
  559. if block.bbox[cut_aixis + 2] > cut_coordinate:
  560. break
  561. elif block.order_label not in mask_labels:
  562. group_blocks.append(block)
  563. block_idx += 1
  564. cut_idx = block_idx
  565. if group_blocks:
  566. cuted_list.append(group_blocks)
  567. return cuted_list
  568. def get_blocks_by_direction_interval(
  569. blocks: List[LayoutBlock],
  570. start_index: int,
  571. end_index: int,
  572. direction: str = "horizontal",
  573. ) -> List[LayoutBlock]:
  574. """
  575. Get blocks within a specified direction interval.
  576. Args:
  577. blocks (List[LayoutBlock]): A list of blocks.
  578. start_index (int): The starting index of the direction.
  579. end_index (int): The ending index of the direction.
  580. direction (str, optional): The direction to consider. Defaults to "horizontal".
  581. Returns:
  582. List[LayoutBlock]: A list of blocks within the specified direction interval.
  583. """
  584. interval_blocks = []
  585. aixis = 0 if direction == "horizontal" else 1
  586. blocks.sort(key=lambda x: x.bbox[aixis + 2])
  587. for block in blocks:
  588. if block.bbox[aixis] >= start_index and block.bbox[aixis + 2] <= end_index:
  589. interval_blocks.append(block)
  590. return interval_blocks
  591. def get_nearest_blocks(
  592. block: LayoutBlock,
  593. ref_blocks: List[LayoutBlock],
  594. overlap_threshold,
  595. direction="horizontal",
  596. ) -> List:
  597. """
  598. Get the adjacent blocks with the same direction as the current block.
  599. Args:
  600. block (LayoutBlock): The current block.
  601. blocks (List[LayoutBlock]): A list of all blocks.
  602. ref_block_idxes (List[int]): A list of indices of reference blocks.
  603. iou_threshold (float): The IOU threshold to determine if two blocks are considered adjacent.
  604. Returns:
  605. Int: The index of the previous block with same direction.
  606. Int: The index of the following block with same direction.
  607. """
  608. prev_blocks: List[LayoutBlock] = []
  609. post_blocks: List[LayoutBlock] = []
  610. sort_index = 1 if direction == "horizontal" else 0
  611. for ref_block in ref_blocks:
  612. if ref_block.index == block.index:
  613. continue
  614. overlap_ratio = calculate_projection_overlap_ratio(
  615. block.bbox, ref_block.bbox, direction, mode="small"
  616. )
  617. if overlap_ratio > overlap_threshold:
  618. if ref_block.bbox[sort_index] <= block.bbox[sort_index]:
  619. prev_blocks.append(ref_block)
  620. else:
  621. post_blocks.append(ref_block)
  622. if prev_blocks:
  623. prev_blocks.sort(key=lambda x: x.bbox[sort_index], reverse=True)
  624. if post_blocks:
  625. post_blocks.sort(key=lambda x: x.bbox[sort_index])
  626. return prev_blocks, post_blocks
def update_doc_title_child_blocks(
    block: LayoutBlock,
    region: LayoutRegion,
) -> None:
    """
    Attach nearby normal text blocks to a document title block as child blocks.

    Candidates are the blocks listed in ``region.normal_text_block_idxes``.

    Pass 1 looks only at the nearest previous and the nearest following candidate,
    as returned by ``get_nearest_blocks`` with
    ``XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]``. A candidate is attached
    when all of these hold:
      1. it has the same ``direction`` as the title block;
      2. its label is in ``BLOCK_LABEL_MAP["text_labels"]``;
      3. its short side is less than 80% of the title's short side;
      4. its long side is either shorter than the title's long side or longer than
         150% of it;
      5. it has fewer than 3 lines;
      6. its nearest edge distance to the title is less than 2 times its
         ``text_line_height``.

    NOTE(review): condition 4 is what the code does. The previous docstring said
    "less than 150% of the parent's long side", which disagrees with the code;
    confirm which one is intended.

    Pass 2 attaches any remaining candidate with the same direction whose
    ``calculate_overlap_ratio(..., mode="small")`` with the title exceeds 0.9.

    Each attached block gets ``order_label = "doc_title_text"``. It is added via
    ``block.append_child_block`` and its index is removed from
    ``region.normal_text_block_idxes``.

    Args:
        block (LayoutBlock): The document title block.
        region (LayoutRegion): The region holding ``block_map`` and ``normal_text_block_idxes``.

    Returns:
        None
    """
    # Snapshot of candidates; the region's index list is mutated below, this list is not.
    ref_blocks = [region.block_map[idx] for idx in region.normal_text_block_idxes]
    overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
    prev_blocks, post_blocks = get_nearest_blocks(
        block, ref_blocks, overlap_threshold, block.direction
    )
    prev_block = None
    post_block = None

    # get_nearest_blocks returns each list nearest-first
    if prev_blocks:
        prev_block = prev_blocks[0]
    if post_blocks:
        post_block = post_blocks[0]

    # Pass 1: only the immediate neighbours on each side
    for ref_block in [prev_block, post_block]:
        if ref_block is None:
            continue
        with_seem_direction = ref_block.direction == block.direction

        short_side_length_condition = (
            ref_block.short_side_length < block.short_side_length * 0.8
        )

        # NOTE(review): accepts shorter-than-parent OR longer-than-1.5x; rejects [1x, 1.5x]
        long_side_length_condition = (
            ref_block.long_side_length < block.long_side_length
            or ref_block.long_side_length > 1.5 * block.long_side_length
        )

        nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox)

        if (
            with_seem_direction
            and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
            and short_side_length_condition
            and long_side_length_condition
            and ref_block.num_of_lines < 3
            and nearest_edge_distance < ref_block.text_line_height * 2
        ):
            ref_block.order_label = "doc_title_text"
            block.append_child_block(ref_block)
            region.normal_text_block_idxes.remove(ref_block.index)

    # Pass 2: any remaining candidate that lies (almost) inside the title block
    for ref_block in ref_blocks:
        # already attached in pass 1 (its index was removed from the region)
        if ref_block.order_label == "doc_title_text":
            continue
        with_seem_direction = ref_block.direction == block.direction
        overlap_ratio = calculate_overlap_ratio(
            block.bbox, ref_block.bbox, mode="small"
        )
        if overlap_ratio > 0.9 and with_seem_direction:
            ref_block.order_label = "doc_title_text"
            block.append_child_block(ref_block)
            region.normal_text_block_idxes.remove(ref_block.index)
  694. def update_paragraph_title_child_blocks(
  695. block: LayoutBlock,
  696. region: LayoutRegion,
  697. ) -> None:
  698. """
  699. Update the child blocks of a paragraph title block.
  700. The child blocks need to meet the following conditions:
  701. 1. They must be adjacent
  702. 2. They must have the same direction as the parent block.
  703. 3. The child block must be paragraph title block.
  704. Args:
  705. blocks (List[LayoutBlock]): overall blocks.
  706. block (LayoutBlock): document title block.
  707. prev_idx (int): previous block index, None if not exist.
  708. post_idx (int): post block index, None if not exist.
  709. config (dict): configurations.
  710. Returns:
  711. None
  712. """
  713. if block.order_label == "sub_paragraph_title":
  714. return
  715. ref_blocks = [
  716. region.block_map[idx]
  717. for idx in region.paragraph_title_block_idxes + region.normal_text_block_idxes
  718. ]
  719. overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
  720. prev_blocks, post_blocks = get_nearest_blocks(
  721. block, ref_blocks, overlap_threshold, block.direction
  722. )
  723. for ref_blocks in [prev_blocks, post_blocks]:
  724. for ref_block in ref_blocks:
  725. if ref_block.label not in BLOCK_LABEL_MAP["paragraph_title_labels"]:
  726. break
  727. min_text_line_height = min(
  728. block.text_line_height, ref_block.text_line_height
  729. )
  730. nearest_edge_distance = get_nearest_edge_distance(
  731. block.bbox, ref_block.bbox
  732. )
  733. with_seem_direction = ref_block.direction == block.direction
  734. with_seem_start = (
  735. abs(ref_block.start_coordinate - block.start_coordinate)
  736. < min_text_line_height * 2
  737. )
  738. if (
  739. with_seem_direction
  740. and with_seem_start
  741. and nearest_edge_distance <= min_text_line_height * 1.5
  742. ):
  743. ref_block.order_label = "sub_paragraph_title"
  744. block.append_child_block(ref_block)
  745. region.paragraph_title_block_idxes.remove(ref_block.index)
  746. def update_vision_child_blocks(
  747. block: LayoutBlock,
  748. region: LayoutRegion,
  749. ) -> None:
  750. """
  751. Update the child blocks of a paragraph title block.
  752. The child blocks need to meet the following conditions:
  753. - For Both:
  754. 1. They must be adjacent
  755. 2. The child block must be vision_title or text block.
  756. - For vision_title:
  757. 1. The distance between the child block and the parent block should be less than 1/2 of the parent's height.
  758. - For text block:
  759. 1. The distance between the child block and the parent block should be less than 15.
  760. 2. The child short_side_length should be less than the parent's short side length.
  761. 3. The child long_side_length should be less than 50% of the parent's long side length.
  762. 4. The difference between their centers is very small.
  763. Args:
  764. blocks (List[LayoutBlock]): overall blocks.
  765. block (LayoutBlock): document title block.
  766. ref_block_idxes (List[int]): A list of indices of reference blocks.
  767. prev_idx (int): previous block index, None if not exist.
  768. post_idx (int): post block index, None if not exist.
  769. config (dict): configurations.
  770. Returns:
  771. None
  772. """
  773. ref_blocks = [
  774. region.block_map[idx]
  775. for idx in region.normal_text_block_idxes + region.vision_title_block_idxes
  776. ]
  777. overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
  778. has_vision_footnote = False
  779. has_vision_title = False
  780. for direction in [block.direction, block.secondary_direction]:
  781. prev_blocks, post_blocks = get_nearest_blocks(
  782. block, ref_blocks, overlap_threshold, direction
  783. )
  784. for ref_block in prev_blocks:
  785. if (
  786. ref_block.label
  787. not in BLOCK_LABEL_MAP["text_labels"]
  788. + BLOCK_LABEL_MAP["vision_title_labels"]
  789. ):
  790. break
  791. nearest_edge_distance = get_nearest_edge_distance(
  792. block.bbox, ref_block.bbox
  793. )
  794. block_center = block.get_centroid()
  795. ref_block_center = ref_block.get_centroid()
  796. if (
  797. ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
  798. and nearest_edge_distance <= ref_block.text_line_height * 2
  799. ):
  800. has_vision_title = True
  801. ref_block.order_label = "vision_title"
  802. block.append_child_block(ref_block)
  803. region.vision_title_block_idxes.remove(ref_block.index)
  804. if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
  805. if (
  806. not has_vision_footnote
  807. and ref_block.direction == block.direction
  808. and ref_block.long_side_length < block.long_side_length
  809. and nearest_edge_distance <= ref_block.text_line_height * 2
  810. ):
  811. if (
  812. (
  813. ref_block.short_side_length < block.short_side_length
  814. and ref_block.long_side_length
  815. < 0.5 * block.long_side_length
  816. and abs(block_center[0] - ref_block_center[0]) < 10
  817. )
  818. or (
  819. block.bbox[0] - ref_block.bbox[0] < 10
  820. and ref_block.num_of_lines == 1
  821. )
  822. or (
  823. block.bbox[2] - ref_block.bbox[2] < 10
  824. and ref_block.num_of_lines == 1
  825. )
  826. ):
  827. has_vision_footnote = True
  828. ref_block.order_label = "vision_footnote"
  829. block.append_child_block(ref_block)
  830. region.normal_text_block_idxes.remove(ref_block.index)
  831. break
  832. for ref_block in post_blocks:
  833. if (
  834. has_vision_footnote
  835. and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
  836. ):
  837. break
  838. nearest_edge_distance = get_nearest_edge_distance(
  839. block.bbox, ref_block.bbox
  840. )
  841. block_center = block.get_centroid()
  842. ref_block_center = ref_block.get_centroid()
  843. if (
  844. ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]
  845. and nearest_edge_distance <= ref_block.text_line_height * 2
  846. ):
  847. has_vision_title = True
  848. ref_block.order_label = "vision_title"
  849. block.append_child_block(ref_block)
  850. region.vision_title_block_idxes.remove(ref_block.index)
  851. if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
  852. if (
  853. not has_vision_footnote
  854. and ref_block.direction == block.direction
  855. and ref_block.long_side_length < block.long_side_length
  856. and nearest_edge_distance <= ref_block.text_line_height * 2
  857. ):
  858. if (
  859. (
  860. ref_block.short_side_length < block.short_side_length
  861. and ref_block.long_side_length
  862. < 0.5 * block.long_side_length
  863. and abs(block_center[0] - ref_block_center[0]) < 10
  864. )
  865. or (
  866. block.bbox[0] - ref_block.bbox[0] < 10
  867. and ref_block.num_of_lines == 1
  868. )
  869. or (
  870. block.bbox[2] - ref_block.bbox[2] < 10
  871. and ref_block.num_of_lines == 1
  872. )
  873. ):
  874. has_vision_footnote = True
  875. ref_block.label = "vision_footnote"
  876. ref_block.order_label = "vision_footnote"
  877. block.append_child_block(ref_block)
  878. region.normal_text_block_idxes.remove(ref_block.index)
  879. break
  880. if has_vision_title:
  881. break
  882. for ref_block in ref_blocks:
  883. if ref_block.index not in region.normal_text_block_idxes:
  884. continue
  885. overlap_ratio = calculate_overlap_ratio(
  886. block.bbox, ref_block.bbox, mode="small"
  887. )
  888. if overlap_ratio > 0.9:
  889. ref_block.label = "vision_footnote"
  890. ref_block.order_label = "vision_footnote"
  891. block.append_child_block(ref_block)
  892. region.normal_text_block_idxes.remove(ref_block.index)
  893. def update_region_child_blocks(
  894. block: LayoutBlock,
  895. region: LayoutRegion,
  896. ) -> None:
  897. """Update child blocks of a region.
  898. Args:
  899. block (LayoutBlock): document title block.
  900. region (LayoutRegion): layout region.
  901. Returns:
  902. None
  903. """
  904. for ref_block in region.block_map.values():
  905. if block.index != ref_block.index:
  906. bbox_iou = calculate_overlap_ratio(block.bbox, ref_block.bbox)
  907. if (
  908. bbox_iou > 0
  909. and block.area > ref_block.area
  910. and ref_block.order_label != "sub_region"
  911. ):
  912. ref_block.order_label = "sub_region"
  913. block.append_child_block(ref_block)
  914. region.normal_text_block_idxes.remove(ref_block.index)
  915. def calculate_discontinuous_projection(
  916. boxes, direction="horizontal", return_num=False
  917. ) -> List:
  918. """
  919. Calculate the discontinuous projection of boxes along the specified direction.
  920. Args:
  921. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  922. direction (str): direction along which to perform the projection ('horizontal' or 'vertical').
  923. Returns:
  924. list: List of tuples representing the merged intervals.
  925. """
  926. boxes = np.array(boxes)
  927. if direction == "horizontal":
  928. intervals = boxes[:, [0, 2]]
  929. elif direction == "vertical":
  930. intervals = boxes[:, [1, 3]]
  931. else:
  932. raise ValueError("direction must be 'horizontal' or 'vertical'")
  933. intervals = intervals[np.argsort(intervals[:, 0])]
  934. merged_intervals = []
  935. num = 1
  936. current_start, current_end = intervals[0]
  937. num_list = []
  938. for start, end in intervals[1:]:
  939. if start <= current_end:
  940. num += 1
  941. current_end = max(current_end, end)
  942. else:
  943. num_list.append(num)
  944. merged_intervals.append((current_start, current_end))
  945. num = 1
  946. current_start, current_end = start, end
  947. num_list.append(num)
  948. merged_intervals.append((current_start, current_end))
  949. if return_num:
  950. return merged_intervals, num_list
  951. return merged_intervals
  952. def shrink_overlapping_boxes(
  953. boxes, direction="horizontal", min_threshold=0, max_threshold=0.1
  954. ) -> List:
  955. """
  956. Shrink overlapping boxes along the specified direction.
  957. Args:
  958. boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
  959. direction (str): direction along which to perform the shrinking ('horizontal' or 'vertical').
  960. min_threshold (float): Minimum threshold for shrinking. Default is 0.
  961. max_threshold (float): Maximum threshold for shrinking. Default is 0.2.
  962. Returns:
  963. list: List of tuples representing the merged intervals.
  964. """
  965. current_block = boxes[0]
  966. for block in boxes[1:]:
  967. x1, y1, x2, y2 = current_block.bbox
  968. x1_prime, y1_prime, x2_prime, y2_prime = block.bbox
  969. cut_iou = calculate_projection_overlap_ratio(
  970. current_block.bbox, block.bbox, direction=direction
  971. )
  972. match_iou = calculate_projection_overlap_ratio(
  973. current_block.bbox,
  974. block.bbox,
  975. direction="horizontal" if direction == "vertical" else "vertical",
  976. )
  977. if direction == "vertical":
  978. if (
  979. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  980. or y2 == y1_prime
  981. or abs(y2 - y1_prime) <= 3
  982. ):
  983. overlap_y_min = max(y1, y1_prime)
  984. overlap_y_max = min(y2, y2_prime)
  985. split_y = int((overlap_y_min + overlap_y_max) / 2)
  986. overlap_y_min = split_y - 1
  987. overlap_y_max = split_y + 1
  988. if y1 < y1_prime:
  989. current_block.bbox = [x1, y1, x2, overlap_y_min]
  990. block.bbox = [x1_prime, overlap_y_max, x2_prime, y2_prime]
  991. else:
  992. current_block.bbox = [x1, overlap_y_min, x2, y2]
  993. block.bbox = [x1_prime, y1_prime, x2_prime, overlap_y_max]
  994. else:
  995. if (
  996. (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold)
  997. or x2 == x1_prime
  998. or abs(x2 - x1_prime) <= 3
  999. ):
  1000. overlap_x_min = max(x1, x1_prime)
  1001. overlap_x_max = min(x2, x2_prime)
  1002. split_x = int((overlap_x_min + overlap_x_max) / 2)
  1003. overlap_x_min = split_x - 1
  1004. overlap_x_max = split_x + 1
  1005. if x1 < x1_prime:
  1006. current_block.bbox = [x1, y1, overlap_x_min, y2]
  1007. block.bbox = [overlap_x_max, y1_prime, x2_prime, y2_prime]
  1008. else:
  1009. current_block.bbox = [overlap_x_min, y1, x2, y2]
  1010. block.bbox = [x1_prime, y1_prime, overlap_x_max, y2_prime]
  1011. current_block = block
  1012. return boxes
  1013. def find_local_minima_flat_regions(arr) -> List:
  1014. """
  1015. Find all local minima regions in a flat array.
  1016. Args:
  1017. arr (list): The input array.
  1018. Returns:
  1019. list: A list of tuples containing the indices of the local minima regions.
  1020. """
  1021. n = len(arr)
  1022. if n == 0:
  1023. return []
  1024. flat_minima_regions = []
  1025. start = 0
  1026. for i in range(1, n):
  1027. if arr[i] != arr[i - 1]:
  1028. if (start == 0 or arr[start - 1] > arr[start]) and (
  1029. i == n or arr[i] > arr[start]
  1030. ):
  1031. flat_minima_regions.append((start, i - 1))
  1032. start = i
  1033. return flat_minima_regions[1:] if len(flat_minima_regions) > 1 else None