xycuts.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Tuple
  15. import numpy as np
  16. from ..result_v2 import LayoutParsingBlock
  17. from ..utils import calculate_overlap_ratio, calculate_projection_overlap_ratio
  18. from .utils import (
  19. calculate_discontinuous_projection,
  20. get_adjacent_blocks_by_orientation,
  21. get_cut_blocks,
  22. get_nearest_edge_distance,
  23. insert_child_blocks,
  24. manhattan_insert,
  25. recursive_xy_cut,
  26. recursive_yx_cut,
  27. reference_insert,
  28. shrink_overlapping_boxes,
  29. sort_blocks,
  30. update_doc_title_child_blocks,
  31. update_paragraph_title_child_blocks,
  32. update_vision_child_blocks,
  33. weighted_distance_insert,
  34. )
def pre_process(
    blocks: List[LayoutParsingBlock],
    config: Dict,
) -> Tuple[
    List[LayoutParsingBlock],
    List[List[LayoutParsingBlock]],
    List[LayoutParsingBlock],
    List[LayoutParsingBlock],
]:
    """
    Preprocess the layout for sorting purposes.

    This function performs three tasks:
    1. Tags the header, unordered and footer blocks (by the indexes given in
       ``config``) and collects them into separate lists.
    2. Calls ``update_region_label`` on the remaining blocks so that they get
       their order label and are matched with their children.
    3. Pre-cuts the remaining layout into groups along the chosen cut orientation.

    Args:
        blocks (List[LayoutParsingBlock]): A list of LayoutParsingBlock objects representing the layout.
            The ``order_label`` of the blocks is modified in place.
        config (Dict): Configuration parameters. Keys read here: ``region_bbox``,
            ``header_block_idxes``, ``unordered_block_idxes``, ``footer_block_idxes``.
            The dict is also forwarded to ``update_region_label``.

    Returns:
        Tuple: ``(header_blocks, pre_cut_list, footer_blocks, unordered_blocks)``.
            ``pre_cut_list`` holds the groups returned by ``get_cut_blocks``; it is
            empty when no block is left after masking.
    """
    # NOTE(review): region_bbox is indexed right away, so a missing
    # "region_bbox" key (None) fails here with a TypeError.
    region_bbox = config.get("region_bbox", None)
    region_x_center = (region_bbox[0] + region_bbox[2]) / 2
    region_y_center = (region_bbox[1] + region_bbox[3]) / 2

    # Tag and collect the blocks that are excluded from the cut-based ordering.
    header_block_idxes = config.get("header_block_idxes", [])
    header_blocks = []
    for idx in header_block_idxes:
        blocks[idx].order_label = "header"
        header_blocks.append(blocks[idx])

    unordered_block_idxes = config.get("unordered_block_idxes", [])
    unordered_blocks = []
    for idx in unordered_block_idxes:
        blocks[idx].order_label = "unordered"
        unordered_blocks.append(blocks[idx])

    footer_block_idxes = config.get("footer_block_idxes", [])
    footer_blocks = []
    for idx in footer_block_idxes:
        blocks[idx].order_label = "footer"
        footer_blocks.append(blocks[idx])

    mask_labels = ["header", "unordered", "footer"]
    child_labels = [
        "vision_footnote",
        "sub_paragraph_title",
        "doc_title_text",
        "vision_title",
    ]
    # Indexes of blocks whose centre lies close to the region centre; they are
    # the candidates that may define a pre-cut below.
    pre_cut_block_idxes = []
    for block_idx, block in enumerate(blocks):
        # NOTE(review): this compares the raw `label`, not the `order_label`
        # assigned above - confirm that this is intended.
        if block.label in mask_labels:
            continue

        # Blocks already claimed as a child of another block keep their label.
        if block.order_label not in child_labels:
            update_region_label(blocks, config, block_idx)

        # Pick the region centre and tolerance according to the block orientation.
        block_orientation = block.orientation
        if block_orientation == "horizontal":
            region_bbox_center = region_x_center
            tolerance_len = block.long_side_length // 5
        else:
            region_bbox_center = region_y_center
            tolerance_len = block.short_side_length // 10

        block_center = (block.start_coordinate + block.end_coordinate) / 2
        center_offset = abs(block_center - region_bbox_center)
        is_centered = center_offset <= tolerance_len

        if is_centered:
            pre_cut_block_idxes.append(block_idx)

    pre_cut_list = []
    cut_orientation = "vertical"
    cut_coordinates = []
    discontinuous = []
    # From here on child blocks are masked out as well.
    mask_labels = child_labels + mask_labels
    all_boxes = np.array(
        [block.bbox for block in blocks if block.order_label not in mask_labels]
    )
    if len(all_boxes) == 0:
        return header_blocks, pre_cut_list, footer_blocks, unordered_blocks
    if pre_cut_block_idxes:
        # Majority vote of the centred blocks decides the cut orientation.
        horizontal_cut_num = 0
        for block_idx in pre_cut_block_idxes:
            block = blocks[block_idx]
            horizontal_cut_num += (
                1 if block.secondary_orientation == "horizontal" else 0
            )
        cut_orientation = (
            "horizontal"
            if horizontal_cut_num > len(pre_cut_block_idxes) * 0.5
            else "vertical"
        )
        discontinuous, num_list = calculate_discontinuous_projection(
            all_boxes, orientation=cut_orientation, return_num=True
        )
        for idx in pre_cut_block_idxes:
            block = blocks[idx]
            if (
                block.order_label not in mask_labels
                and block.secondary_orientation == cut_orientation
            ):
                # A centred block only defines a cut when its extent is exactly
                # one of the projection intervals and the matching num_list
                # entry is 1.
                if (
                    block.secondary_orientation_start_coordinate,
                    block.secondary_orientation_end_coordinate,
                ) in discontinuous:
                    # NOTE(review): `idx` (the loop variable) is reused for the
                    # interval index; harmless because `block` is already bound.
                    idx = discontinuous.index(
                        (
                            block.secondary_orientation_start_coordinate,
                            block.secondary_orientation_end_coordinate,
                        )
                    )
                    if num_list[idx] == 1:
                        cut_coordinates.append(
                            block.secondary_orientation_start_coordinate
                        )
                        cut_coordinates.append(
                            block.secondary_orientation_end_coordinate
                        )
    if not discontinuous:
        discontinuous = calculate_discontinuous_projection(
            all_boxes, orientation=cut_orientation
        )
    # Walk over consecutive projection intervals and cut at large gaps.
    current_interval = discontinuous[0]
    for interval in discontinuous[1:]:
        gap_len = interval[0] - current_interval[1]
        # 60 and 40 are hard-coded gap thresholds (presumably in bbox
        # coordinate units - TODO confirm).
        if gap_len >= 60:
            cut_coordinates.append(current_interval[1])
        elif gap_len > 40:
            # Medium gap: inspect the blocks near the gap strip before cutting.
            x1, _, x2, __ = region_bbox
            y1 = current_interval[1]
            y2 = interval[0]
            bbox = [x1, y1, x2, y2]
            ref_interval = interval[0] - current_interval[1]
            ref_bboxes = []
            for block in blocks:
                if get_nearest_edge_distance(bbox, block.bbox) < ref_interval * 2:
                    ref_bboxes.append(block.bbox)
            # NOTE(review): this rebinds `discontinuous`; the enclosing loop is
            # unaffected because it iterates over a slice taken earlier.
            discontinuous = calculate_discontinuous_projection(
                ref_bboxes, orientation="horizontal"
            )
            # Cut unless the nearby blocks project onto exactly two intervals.
            if len(discontinuous) != 2:
                cut_coordinates.append(current_interval[1])
        current_interval = interval
    cut_list = get_cut_blocks(
        blocks, cut_orientation, cut_coordinates, region_bbox, mask_labels
    )
    pre_cut_list.extend(cut_list)

    return header_blocks, pre_cut_list, footer_blocks, unordered_blocks
  171. def update_region_label(
  172. blocks: List[LayoutParsingBlock], config: Dict[str, Any], block_idx: int
  173. ) -> None:
  174. """
  175. Update the region label of a block based on its label and match the block with its children.
  176. Args:
  177. blocks (List[LayoutParsingBlock]): The list of blocks to process.
  178. config (Dict[str, Any]): The configuration dictionary containing the necessary information.
  179. block_idx (int): The index of the current block being processed.
  180. Returns:
  181. None
  182. """
  183. # special title block labels
  184. doc_title_labels = config.get("doc_title_labels", [])
  185. paragraph_title_labels = config.get("paragraph_title_labels", [])
  186. vision_labels = config.get("vision_labels", [])
  187. block = blocks[block_idx]
  188. if block.label in doc_title_labels:
  189. block.order_label = "doc_title"
  190. # Force the orientation of vision type to be horizontal
  191. if block.label in vision_labels:
  192. block.order_label = "vision"
  193. block.num_of_lines = 1
  194. block.update_orientation_info()
  195. # some paragraph title block may be labeled as sub_title, so we need to check if block.order_label is "other"(default).
  196. if block.label in paragraph_title_labels and block.order_label == "other":
  197. block.order_label = "paragraph_title"
  198. # only vision and doc title block can have child block
  199. if block.order_label not in ["vision", "doc_title", "paragraph_title"]:
  200. return
  201. iou_threshold = config.get("child_block_match_iou_threshold", 0.1)
  202. # match doc title text block
  203. if block.order_label == "doc_title":
  204. text_block_idxes = config.get("text_block_idxes", [])
  205. prev_idx, post_idx = get_adjacent_blocks_by_orientation(
  206. blocks, block_idx, text_block_idxes, iou_threshold
  207. )
  208. update_doc_title_child_blocks(blocks, block, prev_idx, post_idx, config)
  209. # match sub title block
  210. elif block.order_label == "paragraph_title":
  211. iou_threshold = config.get("sub_title_match_iou_threshold", 0.1)
  212. paragraph_title_block_idxes = config.get("paragraph_title_block_idxes", [])
  213. text_block_idxes = config.get("text_block_idxes", [])
  214. megred_block_idxes = text_block_idxes + paragraph_title_block_idxes
  215. prev_idx, post_idx = get_adjacent_blocks_by_orientation(
  216. blocks, block_idx, megred_block_idxes, iou_threshold
  217. )
  218. update_paragraph_title_child_blocks(blocks, block, prev_idx, post_idx, config)
  219. # match vision title block
  220. elif block.order_label == "vision":
  221. # for matching vision title block
  222. vision_title_block_idxes = config.get("vision_title_block_idxes", [])
  223. # for matching vision footnote block
  224. text_block_idxes = config.get("text_block_idxes", [])
  225. megred_block_idxes = text_block_idxes + vision_title_block_idxes
  226. # Some vision title block may be matched with multiple vision title block, so we need to try multiple times
  227. for i in range(3):
  228. prev_idx, post_idx = get_adjacent_blocks_by_orientation(
  229. blocks, block_idx, megred_block_idxes, iou_threshold
  230. )
  231. update_vision_child_blocks(
  232. blocks, block, megred_block_idxes, prev_idx, post_idx, config
  233. )
def get_layout_structure(
    blocks: List[LayoutParsingBlock],
) -> None:
    """
    Mark the blocks that cross other blocks/columns of the layout.

    The list is sorted in place by ``(bbox[0], width)`` and the ``order_label``
    of crossing blocks is rewritten in place to ``"cross_text"`` or
    ``"cross_reference"``. Nothing is returned.

    Args:
        blocks (List[LayoutParsingBlock]): Blocks of one pre-cut region. Attributes read
            here: ``bbox``, ``width``, ``area``, ``label`` and ``order_label``.

    Returns:
        None
    """
    blocks.sort(
        key=lambda x: (x.bbox[0], x.width),
    )

    # Blocks carrying one of these order labels are neither examined nor used
    # as a reference.
    mask_labels = ["doc_title", "cross_text", "cross_reference"]
    for block_idx, block in enumerate(blocks):
        if block.order_label in mask_labels:
            continue

        for ref_idx, ref_block in enumerate(blocks):
            if block_idx == ref_idx or ref_block.order_label in mask_labels:
                continue

            # Case 1: the two boxes overlap at all. A vision reference block is
            # demoted first; otherwise the current block is demoted when it is
            # a vision block or the smaller of the two.
            bbox_iou = calculate_overlap_ratio(block.bbox, ref_block.bbox)
            if bbox_iou > 0:
                if ref_block.order_label == "vision":
                    ref_block.order_label = "cross_text"
                    break
                if block.order_label == "vision" or block.area < ref_block.area:
                    block.order_label = "cross_text"
                    break

            # Case 2: the block shares a "horizontal" projection with the
            # reference block; look for a second reference block to decide
            # whether the current block spans both of them.
            match_projection_iou = calculate_projection_overlap_ratio(
                block.bbox,
                ref_block.bbox,
                "horizontal",
            )
            if match_projection_iou > 0:
                for second_ref_idx, second_ref_block in enumerate(blocks):
                    if (
                        second_ref_idx in [block_idx, ref_idx]
                        or second_ref_block.order_label in mask_labels
                    ):
                        continue

                    # Same demotion rule as case 1, but with a 0.1 threshold.
                    # NOTE(review): these `break`s only leave the inner loop;
                    # the outer reference loop keeps running.
                    bbox_iou = calculate_overlap_ratio(
                        block.bbox, second_ref_block.bbox
                    )
                    if bbox_iou > 0.1:
                        if second_ref_block.order_label == "vision":
                            second_ref_block.order_label = "cross_text"
                            break
                        if (
                            block.order_label == "vision"
                            or block.area < second_ref_block.area
                        ):
                            block.order_label = "cross_text"
                            break

                    second_match_projection_iou = calculate_projection_overlap_ratio(
                        block.bbox,
                        second_ref_block.bbox,
                        "horizontal",
                    )
                    ref_match_projection_iou = calculate_projection_overlap_ratio(
                        ref_block.bbox,
                        second_ref_block.bbox,
                        "horizontal",
                    )
                    ref_match_projection_iou_ = calculate_projection_overlap_ratio(
                        ref_block.bbox,
                        second_ref_block.bbox,
                        "vertical",
                    )
                    # The block overlaps both references in the "horizontal"
                    # projection, while the references do not overlap each
                    # other there but do in the "vertical" projection, and
                    # neither reference is a vision block: the block is
                    # marked as crossing.
                    if (
                        second_match_projection_iou > 0
                        and ref_match_projection_iou == 0
                        and ref_match_projection_iou_ > 0
                        and "vision"
                        not in [ref_block.order_label, second_ref_block.order_label]
                    ):
                        block.order_label = (
                            "cross_reference"
                            if block.label == "reference"
                            else "cross_text"
                        )
  315. def sort_by_xycut(
  316. block_bboxes: List,
  317. orientation: int = 0,
  318. min_gap: int = 1,
  319. ) -> List[int]:
  320. """
  321. Sort bounding boxes using recursive XY cut method based on the specified orientation.
  322. Args:
  323. block_bboxes (Union[np.ndarray, List[List[int]]]): An array or list of bounding boxes,
  324. where each box is represented as
  325. [x_min, y_min, x_max, y_max].
  326. orientation (int): orientation for the initial cut. Use 1 for Y-axis first and 0 for X-axis first.
  327. Defaults to 0.
  328. min_gap (int): Minimum gap width to consider a separation between segments. Defaults to 1.
  329. Returns:
  330. List[int]: A list of indices representing the order of sorted bounding boxes.
  331. """
  332. block_bboxes = np.asarray(block_bboxes).astype(int)
  333. res = []
  334. if orientation == 1:
  335. recursive_yx_cut(
  336. block_bboxes,
  337. np.arange(len(block_bboxes)).tolist(),
  338. res,
  339. min_gap,
  340. )
  341. else:
  342. recursive_xy_cut(
  343. block_bboxes,
  344. np.arange(len(block_bboxes)).tolist(),
  345. res,
  346. min_gap,
  347. )
  348. return res
  349. def match_unsorted_blocks(
  350. sorted_blocks: List[LayoutParsingBlock],
  351. unsorted_blocks: List[LayoutParsingBlock],
  352. config: Dict,
  353. median_width: int,
  354. ) -> List[LayoutParsingBlock]:
  355. """
  356. Match special blocks with the sorted blocks based on their region labels.
  357. Args:
  358. sorted_blocks (List[LayoutParsingBlock]): Sorted blocks to be matched.
  359. unsorted_blocks (List[LayoutParsingBlock]): Unsorted blocks to be matched.
  360. config (Dict): Configuration dictionary containing various parameters.
  361. median_width (int): Median width value used for calculations.
  362. Returns:
  363. List[LayoutParsingBlock]: The updated sorted blocks after matching special blocks.
  364. """
  365. distance_type_map = {
  366. "cross_text": weighted_distance_insert,
  367. "paragraph_title": weighted_distance_insert,
  368. "doc_title": weighted_distance_insert,
  369. "vision_title": weighted_distance_insert,
  370. "vision": weighted_distance_insert,
  371. "cross_reference": reference_insert,
  372. "unordered": manhattan_insert,
  373. "other": manhattan_insert,
  374. }
  375. unsorted_blocks = sort_blocks(unsorted_blocks, median_width, reverse=False)
  376. for idx, block in enumerate(unsorted_blocks):
  377. order_label = block.order_label
  378. if idx == 0 and order_label == "doc_title":
  379. sorted_blocks.insert(0, block)
  380. continue
  381. sorted_blocks = distance_type_map[order_label](
  382. block, sorted_blocks, config, median_width
  383. )
  384. return sorted_blocks
  385. def xycut_enhanced(
  386. blocks: List[LayoutParsingBlock], config: Dict
  387. ) -> List[LayoutParsingBlock]:
  388. """
  389. xycut_enhance function performs the following steps:
  390. 1. Preprocess the input blocks by extracting headers, footers, and pre-cut blocks.
  391. 2. Mask blocks that are crossing different blocks.
  392. 3. Perform xycut_enhanced algorithm on the remaining blocks.
  393. 4. Match unsorted blocks with the sorted blocks based on their order labels.
  394. 5. Update child blocks of the sorted blocks based on their parent blocks.
  395. 6. Return the ordered result list.
  396. Args:
  397. blocks (List[LayoutParsingBlock]): Input blocks to be processed.
  398. Returns:
  399. List[LayoutParsingBlock]: Ordered result list after processing.
  400. """
  401. if len(blocks) == 0:
  402. return blocks
  403. text_labels = config.get("text_labels", [])
  404. header_blocks, pre_cut_list, footer_blocks, unordered_blocks = pre_process(
  405. blocks, config
  406. )
  407. final_order_res_list: List[LayoutParsingBlock] = []
  408. header_blocks = sort_blocks(header_blocks)
  409. footer_blocks = sort_blocks(footer_blocks)
  410. unordered_blocks = sort_blocks(unordered_blocks)
  411. final_order_res_list.extend(header_blocks)
  412. unsorted_blocks: List[LayoutParsingBlock] = []
  413. sorted_blocks_by_pre_cuts = []
  414. for pre_cut_blocks in pre_cut_list:
  415. sorted_blocks: List[LayoutParsingBlock] = []
  416. doc_title_blocks: List[LayoutParsingBlock] = []
  417. xy_cut_blocks: List[LayoutParsingBlock] = []
  418. pre_cut_blocks: List[LayoutParsingBlock]
  419. median_width = 1
  420. text_block_width = [
  421. block.width for block in pre_cut_blocks if block.label in text_labels
  422. ]
  423. if len(text_block_width) > 0:
  424. median_width = int(np.median(text_block_width))
  425. get_layout_structure(
  426. pre_cut_blocks,
  427. )
  428. # Get xy cut blocks and add other blocks in special_block_map
  429. for block in pre_cut_blocks:
  430. if block.order_label not in [
  431. "cross_text",
  432. "cross_reference",
  433. "doc_title",
  434. "unordered",
  435. ]:
  436. xy_cut_blocks.append(block)
  437. elif block.label == "doc_title":
  438. doc_title_blocks.append(block)
  439. else:
  440. unsorted_blocks.append(block)
  441. if len(xy_cut_blocks) > 0:
  442. block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
  443. block_text_lines = [block.num_of_lines for block in xy_cut_blocks]
  444. discontinuous = calculate_discontinuous_projection(
  445. block_bboxes, orientation="horizontal"
  446. )
  447. if len(discontinuous) > 1:
  448. xy_cut_blocks = [block for block in xy_cut_blocks]
  449. if len(discontinuous) == 1 or max(block_text_lines) == 1:
  450. xy_cut_blocks.sort(key=lambda x: (x.bbox[1] // 5, x.bbox[0]))
  451. xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "vertical")
  452. block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
  453. sorted_indexes = sort_by_xycut(block_bboxes, orientation=1, min_gap=1)
  454. else:
  455. xy_cut_blocks.sort(key=lambda x: (x.bbox[0] // 20, x.bbox[1]))
  456. xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "horizontal")
  457. block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
  458. sorted_indexes = sort_by_xycut(block_bboxes, orientation=0, min_gap=20)
  459. sorted_blocks = [xy_cut_blocks[i] for i in sorted_indexes]
  460. sorted_blocks = match_unsorted_blocks(
  461. sorted_blocks,
  462. doc_title_blocks,
  463. config,
  464. median_width,
  465. )
  466. sorted_blocks_by_pre_cuts.extend(sorted_blocks)
  467. median_width = 1
  468. text_block_width = [block.width for block in blocks if block.label in text_labels]
  469. if len(text_block_width) > 0:
  470. median_width = int(np.median(text_block_width))
  471. final_order_res_list = match_unsorted_blocks(
  472. sorted_blocks_by_pre_cuts,
  473. unsorted_blocks,
  474. config,
  475. median_width,
  476. )
  477. final_order_res_list.extend(footer_blocks)
  478. final_order_res_list.extend(unordered_blocks)
  479. for block_idx, block in enumerate(final_order_res_list):
  480. final_order_res_list = insert_child_blocks(
  481. block, block_idx, final_order_res_list
  482. )
  483. block = final_order_res_list[block_idx]
  484. return final_order_res_list