xycuts.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Tuple
  15. import numpy as np
  16. from ..result_v2 import LayoutParsingBlock
  17. from .utils import (
  18. calculate_discontinuous_projection,
  19. calculate_iou,
  20. calculate_projection_iou,
  21. get_adjacent_blocks_by_direction,
  22. get_cut_blocks,
  23. insert_child_blocks,
  24. manhattan_insert,
  25. recursive_xy_cut,
  26. recursive_yx_cut,
  27. reference_insert,
  28. shrink_overlapping_boxes,
  29. sort_blocks,
  30. update_doc_title_child_blocks,
  31. update_paragraph_title_child_blocks,
  32. update_vision_child_blocks,
  33. weighted_distance_insert,
  34. )
  def _mark_region_blocks(
      blocks: List[LayoutParsingBlock], block_idxes: List[int], region_label: str
  ) -> List[LayoutParsingBlock]:
      """Set ``region_label`` on ``blocks[idx]`` for each idx and return those blocks in ``block_idxes`` order."""
      marked_blocks = []
      for idx in block_idxes:
          blocks[idx].region_label = region_label
          marked_blocks.append(blocks[idx])
      return marked_blocks


  def pre_process(
      blocks: List[LayoutParsingBlock],
      config: Dict,
  ) -> Tuple[
      List[LayoutParsingBlock],
      List[List[LayoutParsingBlock]],
      List[LayoutParsingBlock],
      List[LayoutParsingBlock],
  ]:
      """
      Preprocess the layout for reading-order sorting.

      The function:
      1. Tags the blocks listed in ``config["header_block_idxes"]``,
         ``config["unordered_block_idxes"]`` and ``config["footer_block_idxes"]``
         with the region labels ``"header"``, ``"unordered"`` and ``"footer"``.
      2. Runs ``update_region_label`` on every remaining non-child block so that
         title and vision blocks get their region label and child blocks.
      3. Pre-cuts the page: blocks centred in the layout region vote on a cut
         direction, their edges become cut coordinates when they match a
         projection interval exactly, and every gap wider than 40 between
         consecutive projection intervals adds one more cut. The blocks are
         then grouped with ``get_cut_blocks``.

      Args:
          blocks (List[LayoutParsingBlock]): Blocks of the page. Their
              ``region_label`` attributes are modified in place.
          config (Dict): Sorting configuration. ``"all_layout_region_box"`` is
              required and is indexed as ``(x0, y0, x1, y1)``; the index lists
              above are optional.

      Returns:
          Tuple: ``(header_blocks, pre_cut_list, footer_blocks, unordered_blocks)``
          where ``pre_cut_list`` holds the block groups produced by
          ``get_cut_blocks`` (presumably one group per pre-cut region).
      """
      region_bbox = config.get("all_layout_region_box", None)
      region_x_center = (region_bbox[0] + region_bbox[2]) / 2
      region_y_center = (region_bbox[1] + region_bbox[3]) / 2

      # If an index appears in several lists the later tag wins
      # (footer over unordered over header).
      header_blocks = _mark_region_blocks(
          blocks, config.get("header_block_idxes", []), "header"
      )
      unordered_blocks = _mark_region_blocks(
          blocks, config.get("unordered_block_idxes", []), "unordered"
      )
      footer_blocks = _mark_region_blocks(
          blocks, config.get("footer_block_idxes", []), "footer"
      )

      mask_labels = ["header", "unordered", "footer"]
      child_labels = [
          "vision_footnote",
          "sub_paragraph_title",
          "doc_title_text",
          "vision_title",
      ]
      pre_cut_block_idxes = []
      for block_idx, block in enumerate(blocks):
          # Skip blocks masked by detector label *or* by the region tag set
          # above. Checking only ``block.label`` let tagged blocks (e.g.
          # "unordered") be relabeled by update_region_label and vote below.
          if block.label in mask_labels or block.region_label in mask_labels:
              continue
          if block.region_label not in child_labels:
              update_region_label(blocks, config, block_idx)

          # Direction is read only after update_region_label, which may
          # refresh it for vision blocks.
          if block.direction == "horizontal":
              region_bbox_center = region_x_center
              tolerance_len = block.long_side_length // 5
          else:
              region_bbox_center = region_y_center
              tolerance_len = block.short_side_length // 10
          block_center = (block.start_coordinate + block.end_coordinate) / 2
          if abs(block_center - region_bbox_center) <= tolerance_len:
              pre_cut_block_idxes.append(block_idx)

      cut_direction = "vertical"
      cut_coordinates = []
      discontinuous = []
      # From here on child blocks are masked too; they are handled through
      # their parent blocks rather than cut on their own.
      mask_labels = child_labels + mask_labels
      all_boxes = np.array(
          [block.bbox for block in blocks if block.region_label not in mask_labels]
      )
      # Every block may be masked (e.g. a page with only header and footer);
      # then there is nothing to project.
      has_boxes = len(all_boxes) > 0

      if pre_cut_block_idxes:
          pre_cut_blocks = [blocks[idx] for idx in pre_cut_block_idxes]
          # Strict majority of centred blocks decides the cut direction.
          horizontal_cut_num = sum(
              1
              for block in pre_cut_blocks
              if block.secondary_direction == "horizontal"
          )
          cut_direction = (
              "horizontal"
              if horizontal_cut_num > len(pre_cut_blocks) * 0.5
              else "vertical"
          )
          if has_boxes:
              discontinuous = calculate_discontinuous_projection(
                  all_boxes, direction=cut_direction
              )
          for block in pre_cut_blocks:
              if (
                  block.region_label in mask_labels
                  or block.secondary_direction != cut_direction
              ):
                  continue
              # A centred block whose span matches one projection interval
              # exactly contributes both of its edges as cut coordinates.
              block_span = (
                  block.secondary_direction_start_coordinate,
                  block.secondary_direction_end_coordinate,
              )
              if block_span in discontinuous:
                  cut_coordinates.extend(block_span)

      if not discontinuous and has_boxes:
          discontinuous = calculate_discontinuous_projection(
              all_boxes, direction=cut_direction
          )
      # Cut after every interval followed by a gap wider than 40. Pairing
      # consecutive intervals with zip() also tolerates an empty projection,
      # where indexing discontinuous[0] used to raise IndexError.
      for prev_interval, next_interval in zip(discontinuous, discontinuous[1:]):
          if next_interval[0] - prev_interval[1] > 40:
              cut_coordinates.append(prev_interval[1])

      pre_cut_list = list(
          get_cut_blocks(
              blocks, cut_direction, cut_coordinates, region_bbox, mask_labels
          )
      )
      return header_blocks, pre_cut_list, footer_blocks, unordered_blocks
  def update_region_label(
      blocks: List[LayoutParsingBlock], config: Dict[str, Any], block_idx: int
  ) -> None:
      """
      Assign a region label to ``blocks[block_idx]`` and match it with its child blocks.

      The block's ``region_label`` becomes ``"doc_title"``, ``"vision"`` or
      ``"paragraph_title"`` when its detector ``label`` appears in the
      corresponding config list. Blocks ending up with one of those three
      region labels are paired with neighbouring candidate blocks found by
      ``get_adjacent_blocks_by_direction`` and passed to the matching
      ``update_*_child_blocks`` helper. Any other block is left unchanged.

      Args:
          blocks (List[LayoutParsingBlock]): All blocks of the page. The whole
              list is passed to the child-matching helpers.
          config (Dict[str, Any]): Reads ``doc_title_labels``,
              ``paragraph_title_labels``, ``vision_labels``,
              ``child_block_match_iou_threshold``,
              ``sub_title_match_iou_threshold``, ``text_block_idxes``,
              ``paragraph_title_block_idxes`` and ``vision_title_block_idxes``.
          block_idx (int): Index of the block being processed.

      Returns:
          None
      """
      doc_title_labels = config.get("doc_title_labels", [])
      paragraph_title_labels = config.get("paragraph_title_labels", [])
      vision_labels = config.get("vision_labels", [])

      block = blocks[block_idx]
      if block.label in doc_title_labels:
          block.region_label = "doc_title"
      if block.label in vision_labels:
          block.region_label = "vision"
          # NOTE(review): per the original comment this forces vision blocks to
          # a horizontal direction; the logic lives on LayoutParsingBlock.
          block.update_direction_info()
      # Only blocks still at the default "other" become paragraph titles, so a
      # label already mapped to doc title / vision above is not overridden.
      if block.label in paragraph_title_labels and block.region_label == "other":
          block.region_label = "paragraph_title"

      # Only doc title, paragraph title and vision blocks own child blocks.
      if block.region_label not in ("vision", "doc_title", "paragraph_title"):
          return

      iou_threshold = config.get("child_block_match_iou_threshold", 0.1)
      text_block_idxes = config.get("text_block_idxes", [])
      if block.region_label == "doc_title":
          # Doc titles are matched against neighbouring text blocks.
          prev_idx, post_idx = get_adjacent_blocks_by_direction(
              blocks, block_idx, text_block_idxes, iou_threshold
          )
          update_doc_title_child_blocks(blocks, block, prev_idx, post_idx, config)
      elif block.region_label == "paragraph_title":
          # Sub titles use their own IoU threshold and may also be other
          # paragraph-title blocks.
          iou_threshold = config.get("sub_title_match_iou_threshold", 0.1)
          paragraph_title_block_idxes = config.get("paragraph_title_block_idxes", [])
          merged_block_idxes = text_block_idxes + paragraph_title_block_idxes
          prev_idx, post_idx = get_adjacent_blocks_by_direction(
              blocks, block_idx, merged_block_idxes, iou_threshold
          )
          update_paragraph_title_child_blocks(blocks, block, prev_idx, post_idx, config)
      elif block.region_label == "vision":
          # Candidates: vision titles plus text blocks (for footnotes).
          vision_title_block_idxes = config.get("vision_title_block_idxes", [])
          merged_block_idxes = text_block_idxes + vision_title_block_idxes
          # A vision block may own several titles/footnotes, so matching is
          # repeated a fixed three times. The same merged list is passed each
          # round; presumably update_vision_child_blocks removes matched
          # indices from it — TODO confirm.
          for _ in range(3):
              prev_idx, post_idx = get_adjacent_blocks_by_direction(
                  blocks, block_idx, merged_block_idxes, iou_threshold
              )
              update_vision_child_blocks(
                  blocks, block, merged_block_idxes, prev_idx, post_idx, config
              )
  def get_layout_structure(
      blocks: List[LayoutParsingBlock],
      median_width: float,
      config: dict,
      threshold: float = 0.8,
  ) -> None:
      """
      Mark blocks that span across layout columns.

      ``blocks`` is sorted in place by ``(bbox[0], width)``. A block's
      ``region_label`` is set to ``"cross_text"`` (or ``"cross_reference"`` for
      blocks whose ``label`` is ``"reference"``, in the last two rules) when:

      * it overlaps another block (IoU > 0) and is a vision block or smaller
        than that block;
      * at least two other blocks have a positive horizontal projection IoU
        with it (the scan stops at the first such block reaching its right
        edge), it is not a doc title, and exactly one of "wider than
        1.3 x median_width" and "centred within 5% of median_width of the
        layout region" holds; or
      * otherwise, its last recorded projection match (IoU > ``threshold``)
        has IoU > 0.9 with a block that is already ``"cross_text"``.

      Args:
          blocks (List[LayoutParsingBlock]): Blocks of one pre-cut region,
              sorted and relabeled in place.
          median_width (float): Median width of text blocks.
          config (dict): Reads ``"doc_title_labels"`` and
              ``"all_layout_region_box"``.
          threshold (float): Projection IoU above which a neighbour is recorded
              for the single-layout check.

      Returns:
          None: results are written to each block's ``region_label``.
      """
      # In-place sort is part of the contract: callers iterate ``blocks`` next.
      blocks.sort(
          key=lambda x: (x.bbox[0], x.width),
      )
      check_single_layout = {}
      doc_title_labels = config.get("doc_title_labels", [])
      region_box = config.get("all_layout_region_box", [0, 0, 0, 0])
      # Loop invariant: horizontal centre of the layout region.
      region_bbox_center = (region_box[0] + region_box[2]) / 2

      for block_idx, block in enumerate(blocks):
          cover_count = 0
          match_block_with_threshold_indexes = []
          for ref_idx, ref_block in enumerate(blocks):
              if block_idx == ref_idx:
                  continue
              bbox_iou = calculate_iou(block.bbox, ref_block.bbox)
              if bbox_iou > 0:
                  if block.region_label == "vision" or block.area < ref_block.area:
                      block.region_label = "cross_text"
                      break
              match_projection_iou = calculate_projection_iou(
                  block.bbox,
                  ref_block.bbox,
                  "horizontal",
              )
              if match_projection_iou > 0:
                  cover_count += 1
                  if match_projection_iou > threshold:
                      match_block_with_threshold_indexes.append(
                          (ref_idx, match_projection_iou),
                      )
                  if ref_block.bbox[2] >= block.bbox[2]:
                      break

          block_center = (block.bbox[0] + block.bbox[2]) / 2
          is_centered = abs(block_center - region_bbox_center) <= median_width * 0.05
          wider_than_median = block.width > median_width * 1.3
          if (
              cover_count >= 2
              and block.label not in doc_title_labels
              and (wider_than_median != is_centered)
          ):
              block.region_label = (
                  "cross_reference" if block.label == "reference" else "cross_text"
              )
          else:
              check_single_layout[block_idx] = match_block_with_threshold_indexes

      # Single-layout check: relabel a block whose strongest-recorded neighbour
      # (last entry) almost fully shares its projection with a cross block.
      for idx, single_layout in check_single_layout.items():
          if not single_layout:
              continue
          index, match_iou = single_layout[-1]
          if match_iou > 0.9 and blocks[index].region_label == "cross_text":
              target_block = blocks[idx]
              # Decide by the label of the block being relabeled; the original
              # code read the stale loop variable ``block`` (the last block).
              target_block.region_label = (
                  "cross_reference"
                  if target_block.label == "reference"
                  else "cross_text"
              )
  def sort_by_xycut(
      block_bboxes: List,
      direction: int = 0,
      min_gap: int = 1,
  ) -> List[int]:
      """
      Return the reading order of ``block_bboxes`` computed by a recursive XY cut.

      Args:
          block_bboxes (Union[np.ndarray, List[List[int]]]): Boxes given as
              ``[x_min, y_min, x_max, y_max]``; they are converted to an int
              array before cutting.
          direction (int): ``1`` starts with a Y-axis cut (``recursive_yx_cut``);
              any other value starts with an X-axis cut (``recursive_xy_cut``).
              Defaults to 0.
          min_gap (int): Minimum gap width treated as a separation between
              segments. Defaults to 1.

      Returns:
          List[int]: Indices into ``block_bboxes`` in sorted order.
      """
      int_boxes = np.asarray(block_bboxes).astype(int)
      cut_fn = recursive_yx_cut if direction == 1 else recursive_xy_cut
      order: List[int] = []
      # The recursive helper fills ``order`` with the sorted indices.
      cut_fn(int_boxes, list(range(len(int_boxes))), order, min_gap)
      return order
  def match_unsorted_blocks(
      sorted_blocks: List[LayoutParsingBlock],
      unsorted_blocks: List[LayoutParsingBlock],
      config: Dict,
      median_width: int,
  ) -> List[LayoutParsingBlock]:
      """
      Insert each unsorted block into ``sorted_blocks`` using a strategy chosen by its region label.

      ``unsorted_blocks`` is first ordered with ``sort_blocks``. If the first of
      them has region label ``"doc_title"`` it is placed at the very front;
      every other block is handed to the insert function mapped to its
      ``region_label`` (weighted-distance, reference, or Manhattan insertion).
      A region label missing from the mapping raises ``KeyError``.

      Args:
          sorted_blocks (List[LayoutParsingBlock]): Already ordered blocks; may
              be mutated in place.
          unsorted_blocks (List[LayoutParsingBlock]): Blocks to insert.
          config (Dict): Configuration forwarded to the insert functions.
          median_width (int): Median text-block width forwarded to
              ``sort_blocks`` and the insert functions.

      Returns:
          List[LayoutParsingBlock]: The ordered list after all insertions.
      """
      insert_by_region_label = {
          "cross_text": weighted_distance_insert,
          "paragraph_title": weighted_distance_insert,
          "doc_title": weighted_distance_insert,
          "vision_title": weighted_distance_insert,
          "vision": weighted_distance_insert,
          "cross_reference": reference_insert,
          "unordered": manhattan_insert,
          "other": manhattan_insert,
      }
      pending_blocks = sort_blocks(unsorted_blocks, median_width, reverse=False)
      for position, pending in enumerate(pending_blocks):
          region_label = pending.region_label
          if position == 0 and region_label == "doc_title":
              # A leading document title goes straight to the front.
              sorted_blocks.insert(0, pending)
          else:
              insert_fn = insert_by_region_label[region_label]
              sorted_blocks = insert_fn(pending, sorted_blocks, config, median_width)
      return sorted_blocks
  def _median_text_width(
      blocks: List[LayoutParsingBlock], text_labels: List[str]
  ) -> int:
      """Return int(median width) of blocks whose label is in ``text_labels``, or 1 when there are none."""
      text_block_width = [block.width for block in blocks if block.label in text_labels]
      if not text_block_width:
          return 1
      return int(np.median(text_block_width))


  def _sort_xy_cut_blocks(
      xy_cut_blocks: List[LayoutParsingBlock],
  ) -> List[LayoutParsingBlock]:
      """Order blocks with the recursive XY cut, choosing the first cut axis from their projection and line counts."""
      block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
      block_text_lines = [block.num_of_lines for block in xy_cut_blocks]
      discontinuous = calculate_discontinuous_projection(
          block_bboxes, direction="horizontal"
      )
      if len(discontinuous) == 1 or max(block_text_lines) == 1:
          # A single horizontal projection interval, or only single-line
          # blocks: cut along the Y axis first.
          xy_cut_blocks.sort(key=lambda x: (x.bbox[1] // 5, x.bbox[0]))
          xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "vertical")
          direction, min_gap = 1, 1
      else:
          xy_cut_blocks.sort(key=lambda x: (x.bbox[0] // 20, x.bbox[1]))
          xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "horizontal")
          direction, min_gap = 0, 20
      block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
      sorted_indexes = sort_by_xycut(block_bboxes, direction=direction, min_gap=min_gap)
      return [xy_cut_blocks[i] for i in sorted_indexes]


  def xycut_enhanced(
      blocks: List[LayoutParsingBlock], config: Dict
  ) -> List[LayoutParsingBlock]:
      """
      Compute the reading order of layout blocks with an enhanced XY cut.

      Steps:
      1. ``pre_process`` splits off header, footer and unordered blocks and
         pre-cuts the remaining blocks into regions.
      2. In each region ``get_layout_structure`` marks cross-column blocks; the
         other blocks are ordered with the recursive XY cut, and doc titles are
         inserted back with ``match_unsorted_blocks``.
      3. Deferred blocks (cross-column, cross-reference, ...) from all regions
         are inserted into the combined order with ``match_unsorted_blocks``.
      4. The result is ``headers + body + footers + unordered``. Blocks whose
         label is in ``config["visualize_index_labels"]`` get a 1-based
         ``index`` and their child blocks are inserted by
         ``insert_child_blocks``.

      Args:
          blocks (List[LayoutParsingBlock]): Input blocks; their
              ``region_label`` (and possibly ``index``) attributes are mutated.
          config (Dict): Sorting configuration; this function reads
              ``text_labels`` and ``visualize_index_labels`` and forwards the
              rest to its helpers.

      Returns:
          List[LayoutParsingBlock]: Blocks in reading order (``blocks`` itself
          when it is empty).
      """
      if len(blocks) == 0:
          return blocks

      text_labels = config.get("text_labels", [])
      header_blocks, pre_cut_list, footer_blocks, unordered_blocks = pre_process(
          blocks, config
      )
      header_blocks = sort_blocks(header_blocks)
      footer_blocks = sort_blocks(footer_blocks)
      unordered_blocks = sort_blocks(unordered_blocks)

      # Blocks the XY cut does not place inside their own region; they are
      # inserted into the combined order after all regions are sorted.
      unsorted_blocks: List[LayoutParsingBlock] = []
      sorted_blocks_by_pre_cuts: List[LayoutParsingBlock] = []
      for pre_cut_blocks in pre_cut_list:
          sorted_blocks: List[LayoutParsingBlock] = []
          doc_title_blocks: List[LayoutParsingBlock] = []
          xy_cut_blocks: List[LayoutParsingBlock] = []
          median_width = _median_text_width(pre_cut_blocks, text_labels)
          get_layout_structure(
              pre_cut_blocks,
              median_width,
              config,
          )

          # Route each block: regular blocks go to the XY cut, doc titles are
          # re-inserted per region, everything else is deferred.
          for block in pre_cut_blocks:
              if block.region_label not in [
                  "cross_text",
                  "cross_reference",
                  "doc_title",
                  "unordered",
              ]:
                  xy_cut_blocks.append(block)
              elif block.label == "doc_title":
                  doc_title_blocks.append(block)
              else:
                  unsorted_blocks.append(block)

          if xy_cut_blocks:
              sorted_blocks = _sort_xy_cut_blocks(xy_cut_blocks)
          sorted_blocks = match_unsorted_blocks(
              sorted_blocks,
              doc_title_blocks,
              config,
              median_width,
          )
          sorted_blocks_by_pre_cuts.extend(sorted_blocks)

      median_width = _median_text_width(blocks, text_labels)
      body_blocks = match_unsorted_blocks(
          sorted_blocks_by_pre_cuts,
          unsorted_blocks,
          config,
          median_width,
      )
      # Headers lead the result. Previously they were extended into a list that
      # was then rebound to match_unsorted_blocks' result, dropping them.
      final_order_res_list: List[LayoutParsingBlock] = [
          *header_blocks,
          *body_blocks,
          *footer_blocks,
          *unordered_blocks,
      ]

      index = 0
      visualize_index_labels = config.get("visualize_index_labels", [])
      for block_idx, block in enumerate(final_order_res_list):
          if block.label not in visualize_index_labels:
              continue
          # NOTE(review): enumerate() keeps walking the list it started with;
          # if insert_child_blocks returns a new list, positions may drift —
          # confirm its contract.
          final_order_res_list = insert_child_blocks(
              block, block_idx, final_order_res_list
          )
          block = final_order_res_list[block_idx]
          index += 1
          block.index = index
      return final_order_res_list