xycuts.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, List, Tuple
  15. import numpy as np
  16. from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion
  17. from ..setting import BLOCK_LABEL_MAP
  18. from ..utils import calculate_overlap_ratio, calculate_projection_overlap_ratio
  19. from .utils import (
  20. calculate_discontinuous_projection,
  21. get_cut_blocks,
  22. get_nearest_edge_distance,
  23. insert_child_blocks,
  24. is_projection_consistent,
  25. manhattan_insert,
  26. recursive_xy_cut,
  27. recursive_yx_cut,
  28. reference_insert,
  29. shrink_overlapping_boxes,
  30. sort_normal_blocks,
  31. update_doc_title_child_blocks,
  32. update_paragraph_title_child_blocks,
  33. update_vision_child_blocks,
  34. weighted_distance_insert,
  35. )
  36. def pre_process(
  37. region: LayoutParsingRegion,
  38. ) -> List:
  39. """
  40. Preprocess the layout for sorting purposes.
  41. This function performs two main tasks:
  42. 1. Pre-cuts the layout to ensure the document is correctly partitioned and sorted.
  43. 2. Match the blocks with their children.
  44. Args:
  45. region: LayoutParsingRegion, the layout region to be pre-processed.
  46. Returns:
  47. List: A list of pre-cutted layout blocks list.
  48. """
  49. mask_labels = [
  50. "header",
  51. "unordered",
  52. "footer",
  53. "vision_footnote",
  54. "sub_paragraph_title",
  55. "doc_title_text",
  56. "vision_title",
  57. ]
  58. pre_cut_block_idxes = []
  59. block_map = region.block_map
  60. blocks: List[LayoutParsingBlock] = list(block_map.values())
  61. for block in blocks:
  62. if block.order_label not in mask_labels:
  63. update_region_label(block, region)
  64. block_direction = block.direction
  65. if block_direction == "horizontal":
  66. tolerance_len = block.long_side_length // 5
  67. else:
  68. tolerance_len = block.short_side_length // 10
  69. block_center = (block.start_coordinate + block.end_coordinate) / 2
  70. center_offset = abs(block_center - region.direction_center_coordinate)
  71. is_centered = center_offset <= tolerance_len
  72. if is_centered:
  73. pre_cut_block_idxes.append(block.index)
  74. pre_cut_list = []
  75. cut_direction = region.secondary_direction
  76. cut_coordinates = []
  77. discontinuous = []
  78. all_boxes = np.array(
  79. [block.bbox for block in blocks if block.order_label not in mask_labels]
  80. )
  81. if len(all_boxes) == 0:
  82. return pre_cut_list
  83. if pre_cut_block_idxes:
  84. discontinuous, num_list = calculate_discontinuous_projection(
  85. all_boxes, direction=cut_direction, return_num=True
  86. )
  87. for idx in pre_cut_block_idxes:
  88. block = block_map[idx]
  89. if (
  90. block.order_label not in mask_labels
  91. and block.secondary_direction == cut_direction
  92. ):
  93. if (
  94. block.secondary_direction_start_coordinate,
  95. block.secondary_direction_end_coordinate,
  96. ) in discontinuous:
  97. idx = discontinuous.index(
  98. (
  99. block.secondary_direction_start_coordinate,
  100. block.secondary_direction_end_coordinate,
  101. )
  102. )
  103. if num_list[idx] == 1:
  104. cut_coordinates.append(
  105. block.secondary_direction_start_coordinate
  106. )
  107. cut_coordinates.append(block.secondary_direction_end_coordinate)
  108. secondary_discontinuous = calculate_discontinuous_projection(
  109. all_boxes, direction=region.direction
  110. )
  111. if len(secondary_discontinuous) == 1:
  112. if not discontinuous:
  113. discontinuous = calculate_discontinuous_projection(
  114. all_boxes, direction=cut_direction
  115. )
  116. current_interval = discontinuous[0]
  117. for interval in discontinuous[1:]:
  118. gap_len = interval[0] - current_interval[1]
  119. if gap_len >= region.text_line_height * 5:
  120. cut_coordinates.append(current_interval[1])
  121. elif gap_len > region.text_line_height * 2:
  122. x1, _, x2, __ = region.bbox
  123. y1 = current_interval[1]
  124. y2 = interval[0]
  125. bbox = [x1, y1, x2, y2]
  126. ref_interval = interval[0] - current_interval[1]
  127. ref_bboxes = []
  128. for block in blocks:
  129. if get_nearest_edge_distance(bbox, block.bbox) < ref_interval * 2:
  130. ref_bboxes.append(block.bbox)
  131. discontinuous = calculate_discontinuous_projection(
  132. ref_bboxes, direction=region.direction
  133. )
  134. if len(discontinuous) != 2:
  135. cut_coordinates.append(current_interval[1])
  136. current_interval = interval
  137. cut_list = get_cut_blocks(
  138. blocks, cut_direction, cut_coordinates, region.bbox, mask_labels
  139. )
  140. pre_cut_list.extend(cut_list)
  141. if region.direction == "vertical":
  142. pre_cut_list = pre_cut_list[::-1]
  143. return pre_cut_list
  144. def update_region_label(
  145. block: LayoutParsingBlock,
  146. region: LayoutParsingRegion,
  147. ) -> None:
  148. """
  149. Update the region label of a block based on its label and match the block with its children.
  150. Args:
  151. blocks (List[LayoutParsingBlock]): The list of blocks to process.
  152. config (Dict[str, Any]): The configuration dictionary containing the necessary information.
  153. block_idx (int): The index of the current block being processed.
  154. Returns:
  155. None
  156. """
  157. if block.label in BLOCK_LABEL_MAP["header_labels"]:
  158. block.order_label = "header"
  159. elif block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
  160. block.order_label = "doc_title"
  161. elif (
  162. block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]
  163. and block.order_label is None
  164. ):
  165. block.order_label = "paragraph_title"
  166. elif block.label in BLOCK_LABEL_MAP["vision_labels"]:
  167. block.order_label = "vision"
  168. block.num_of_lines = 1
  169. block.update_direction_info()
  170. elif block.label in BLOCK_LABEL_MAP["footer_labels"]:
  171. block.order_label = "footer"
  172. elif block.label in BLOCK_LABEL_MAP["unordered_labels"]:
  173. block.order_label = "unordered"
  174. else:
  175. block.order_label = "normal_text"
  176. # only vision and doc title block can have child block
  177. if block.order_label not in ["vision", "doc_title", "paragraph_title"]:
  178. return
  179. # match doc title text block
  180. if block.order_label == "doc_title":
  181. update_doc_title_child_blocks(block, region)
  182. # match sub title block
  183. elif block.order_label == "paragraph_title":
  184. update_paragraph_title_child_blocks(block, region)
  185. # match vision title block and vision footnote block
  186. elif block.order_label == "vision":
  187. update_vision_child_blocks(block, region)
  188. def get_layout_structure(
  189. blocks: List[LayoutParsingBlock],
  190. region_direction: str,
  191. region_secondary_direction: str,
  192. ) -> Tuple[List[Dict[str, any]], bool]:
  193. """
  194. Determine the layout cross column of blocks.
  195. Args:
  196. blocks (List[Dict[str, any]]): List of block dictionaries containing 'label' and 'block_bbox'.
  197. Returns:
  198. Tuple[List[Dict[str, any]], bool]: Updated list of blocks with layout information and a boolean
  199. indicating if the cross layout area is greater than the single layout area.
  200. """
  201. blocks.sort(
  202. key=lambda x: (x.bbox[0], x.width),
  203. )
  204. mask_labels = ["doc_title", "cross_layout", "cross_reference"]
  205. for block_idx, block in enumerate(blocks):
  206. if block.order_label in mask_labels:
  207. continue
  208. for ref_idx, ref_block in enumerate(blocks):
  209. if block_idx == ref_idx or ref_block.order_label in mask_labels:
  210. continue
  211. bbox_iou = calculate_overlap_ratio(block.bbox, ref_block.bbox)
  212. if bbox_iou > 0:
  213. if ref_block.order_label == "vision":
  214. ref_block.order_label = "cross_layout"
  215. break
  216. if block.order_label == "vision" or block.area < ref_block.area:
  217. block.order_label = "cross_layout"
  218. break
  219. match_projection_iou = calculate_projection_overlap_ratio(
  220. block.bbox,
  221. ref_block.bbox,
  222. region_direction,
  223. )
  224. if match_projection_iou > 0:
  225. for second_ref_idx, second_ref_block in enumerate(blocks):
  226. if (
  227. second_ref_idx in [block_idx, ref_idx]
  228. or second_ref_block.order_label in mask_labels
  229. ):
  230. continue
  231. bbox_iou = calculate_overlap_ratio(
  232. block.bbox, second_ref_block.bbox
  233. )
  234. if bbox_iou > 0.1:
  235. if second_ref_block.order_label == "vision":
  236. second_ref_block.order_label = "cross_layout"
  237. break
  238. if (
  239. block.order_label == "vision"
  240. or block.area < second_ref_block.area
  241. ):
  242. block.order_label = "cross_layout"
  243. break
  244. second_match_projection_iou = calculate_projection_overlap_ratio(
  245. block.bbox,
  246. second_ref_block.bbox,
  247. region_direction,
  248. )
  249. ref_match_projection_iou = calculate_projection_overlap_ratio(
  250. ref_block.bbox,
  251. second_ref_block.bbox,
  252. region_direction,
  253. )
  254. ref_match_projection_iou_ = calculate_projection_overlap_ratio(
  255. ref_block.bbox,
  256. second_ref_block.bbox,
  257. region_secondary_direction,
  258. )
  259. if (
  260. second_match_projection_iou > 0
  261. and ref_match_projection_iou == 0
  262. and ref_match_projection_iou_ > 0
  263. ):
  264. if block.order_label == "vision" or (
  265. ref_block.order_label == "normal_text"
  266. and second_ref_block.order_label == "normal_text"
  267. ):
  268. block.order_label = (
  269. "cross_reference"
  270. if block.label == "reference"
  271. else "cross_layout"
  272. )
  273. def sort_by_xycut(
  274. block_bboxes: List,
  275. direction: str = "vertical",
  276. min_gap: int = 1,
  277. ) -> List[int]:
  278. """
  279. Sort bounding boxes using recursive XY cut method based on the specified direction.
  280. Args:
  281. block_bboxes (Union[np.ndarray, List[List[int]]]): An array or list of bounding boxes,
  282. where each box is represented as
  283. [x_min, y_min, x_max, y_max].
  284. direction (int): direction for the initial cut. Use 1 for Y-axis first and 0 for X-axis first.
  285. Defaults to 0.
  286. min_gap (int): Minimum gap width to consider a separation between segments. Defaults to 1.
  287. Returns:
  288. List[int]: A list of indices representing the order of sorted bounding boxes.
  289. """
  290. block_bboxes = np.asarray(block_bboxes).astype(int)
  291. res = []
  292. if direction == "vertical":
  293. recursive_yx_cut(
  294. block_bboxes,
  295. np.arange(len(block_bboxes)).tolist(),
  296. res,
  297. min_gap,
  298. )
  299. else:
  300. recursive_xy_cut(
  301. block_bboxes,
  302. np.arange(len(block_bboxes)).tolist(),
  303. res,
  304. min_gap,
  305. )
  306. return res
  307. def match_unsorted_blocks(
  308. sorted_blocks: List[LayoutParsingBlock],
  309. unsorted_blocks: List[LayoutParsingBlock],
  310. region: LayoutParsingRegion,
  311. ) -> List[LayoutParsingBlock]:
  312. """
  313. Match special blocks with the sorted blocks based on their region labels.
  314. Args:
  315. sorted_blocks (List[LayoutParsingBlock]): Sorted blocks to be matched.
  316. unsorted_blocks (List[LayoutParsingBlock]): Unsorted blocks to be matched.
  317. config (Dict): Configuration dictionary containing various parameters.
  318. median_width (int): Median width value used for calculations.
  319. Returns:
  320. List[LayoutParsingBlock]: The updated sorted blocks after matching special blocks.
  321. """
  322. distance_type_map = {
  323. "cross_layout": weighted_distance_insert,
  324. "paragraph_title": weighted_distance_insert,
  325. "doc_title": weighted_distance_insert,
  326. "vision_title": weighted_distance_insert,
  327. "vision": weighted_distance_insert,
  328. "cross_reference": reference_insert,
  329. "unordered": manhattan_insert,
  330. "other": manhattan_insert,
  331. }
  332. unsorted_blocks = sort_normal_blocks(
  333. unsorted_blocks,
  334. region.text_line_height,
  335. region.text_line_width,
  336. region.direction,
  337. )
  338. for idx, block in enumerate(unsorted_blocks):
  339. order_label = block.order_label
  340. if idx == 0 and order_label == "doc_title":
  341. sorted_blocks.insert(0, block)
  342. continue
  343. sorted_blocks = distance_type_map[order_label](block, sorted_blocks, region)
  344. return sorted_blocks
  345. def xycut_enhanced(
  346. region: LayoutParsingRegion,
  347. ) -> LayoutParsingRegion:
  348. """
  349. xycut_enhance function performs the following steps:
  350. 1. Preprocess the input blocks by extracting headers, footers, and pre-cut blocks.
  351. 2. Mask blocks that are crossing different blocks.
  352. 3. Perform xycut_enhanced algorithm on the remaining blocks.
  353. 4. Match unsorted blocks with the sorted blocks based on their order labels.
  354. 5. Update child blocks of the sorted blocks based on their parent blocks.
  355. 6. Return the ordered result list.
  356. Args:
  357. blocks (List[LayoutParsingBlock]): Input blocks to be processed.
  358. Returns:
  359. List[LayoutParsingBlock]: Ordered result list after processing.
  360. """
  361. if len(region.block_map) == 0:
  362. return []
  363. pre_cut_list: List[List[LayoutParsingBlock]] = pre_process(region)
  364. final_order_res_list: List[LayoutParsingBlock] = []
  365. header_blocks: List[LayoutParsingBlock] = [
  366. region.block_map[idx] for idx in region.header_block_idxes
  367. ]
  368. unordered_blocks: List[LayoutParsingBlock] = [
  369. region.block_map[idx] for idx in region.unordered_block_idxes
  370. ]
  371. footer_blocks: List[LayoutParsingBlock] = [
  372. region.block_map[idx] for idx in region.footer_block_idxes
  373. ]
  374. header_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
  375. header_blocks, region.text_line_height, region.text_line_width, region.direction
  376. )
  377. footer_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
  378. footer_blocks, region.text_line_height, region.text_line_width, region.direction
  379. )
  380. unordered_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
  381. unordered_blocks,
  382. region.text_line_height,
  383. region.text_line_width,
  384. region.direction,
  385. )
  386. final_order_res_list.extend(header_blocks)
  387. unsorted_blocks: List[LayoutParsingBlock] = []
  388. sorted_blocks_by_pre_cuts: List[LayoutParsingBlock] = []
  389. for pre_cut_blocks in pre_cut_list:
  390. sorted_blocks: List[LayoutParsingBlock] = []
  391. doc_title_blocks: List[LayoutParsingBlock] = []
  392. xy_cut_blocks: List[LayoutParsingBlock] = []
  393. get_layout_structure(
  394. pre_cut_blocks, region.direction, region.secondary_direction
  395. )
  396. # Get xy cut blocks and add other blocks in special_block_map
  397. for block in pre_cut_blocks:
  398. if block.order_label not in [
  399. "cross_layout",
  400. "cross_reference",
  401. "doc_title",
  402. "unordered",
  403. ]:
  404. xy_cut_blocks.append(block)
  405. elif block.label == "doc_title":
  406. doc_title_blocks.append(block)
  407. else:
  408. unsorted_blocks.append(block)
  409. if len(xy_cut_blocks) > 0:
  410. block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
  411. block_text_lines = [block.num_of_lines for block in xy_cut_blocks]
  412. discontinuous = calculate_discontinuous_projection(
  413. block_bboxes, direction=region.direction
  414. )
  415. if len(discontinuous) > 1:
  416. xy_cut_blocks = [block for block in xy_cut_blocks]
  417. # if len(discontinuous) == 1 or max(block_text_lines) == 1 or (not is_projection_consistent(xy_cut_blocks, discontinuous, direction=region.direction) and len(discontinuous) > 2 and max(block_text_lines) - min(block_text_lines) < 3):
  418. if len(discontinuous) == 1 or max(block_text_lines) == 1:
  419. xy_cut_blocks.sort(
  420. key=lambda x: (
  421. x.bbox[region.secondary_direction_start_index]
  422. // (region.text_line_height // 2),
  423. x.bbox[region.direction_start_index],
  424. )
  425. )
  426. xy_cut_blocks = shrink_overlapping_boxes(
  427. xy_cut_blocks, region.secondary_direction
  428. )
  429. if (
  430. len(discontinuous) == 1
  431. or max(block_text_lines) == 1
  432. or (
  433. not is_projection_consistent(
  434. xy_cut_blocks, discontinuous, direction=region.direction
  435. )
  436. and len(discontinuous) > 2
  437. and max(block_text_lines) - min(block_text_lines) < 3
  438. )
  439. ):
  440. xy_cut_blocks.sort(
  441. key=lambda x: (
  442. x.bbox[region.secondary_direction_start_index]
  443. // (region.text_line_height // 2),
  444. x.bbox[region.direction_start_index],
  445. )
  446. )
  447. xy_cut_blocks = shrink_overlapping_boxes(
  448. xy_cut_blocks, region.secondary_direction
  449. )
  450. block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
  451. sorted_indexes = sort_by_xycut(
  452. block_bboxes, direction=region.secondary_direction, min_gap=1
  453. )
  454. else:
  455. xy_cut_blocks.sort(
  456. key=lambda x: (
  457. x.bbox[region.direction_start_index]
  458. // (region.text_line_width // 2),
  459. x.bbox[region.secondary_direction_start_index],
  460. )
  461. )
  462. xy_cut_blocks = shrink_overlapping_boxes(
  463. xy_cut_blocks, region.direction
  464. )
  465. block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
  466. sorted_indexes = sort_by_xycut(
  467. block_bboxes, direction=region.direction, min_gap=1
  468. )
  469. sorted_blocks = [xy_cut_blocks[i] for i in sorted_indexes]
  470. sorted_blocks = match_unsorted_blocks(
  471. sorted_blocks,
  472. doc_title_blocks,
  473. region=region,
  474. )
  475. sorted_blocks_by_pre_cuts.extend(sorted_blocks)
  476. final_order_res_list = match_unsorted_blocks(
  477. sorted_blocks_by_pre_cuts,
  478. unsorted_blocks,
  479. region=region,
  480. )
  481. final_order_res_list.extend(footer_blocks)
  482. final_order_res_list.extend(unordered_blocks)
  483. for block_idx, block in enumerate(final_order_res_list):
  484. final_order_res_list = insert_child_blocks(
  485. block, block_idx, final_order_res_list
  486. )
  487. block = final_order_res_list[block_idx]
  488. return final_order_res_list