utils.py 75 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  # Public names exported by ``from <this module> import *``.
  # NOTE(review): "get_layout_ordering" and "get_show_color" are not defined in
  # the visible portion of this file; presumably they are defined further
  # down the module — confirm.
  __all__ = [
      "get_sub_regions_ocr_res",
      "get_layout_ordering",
      "get_single_block_parsing_res",
      "recursive_img_array2path",
      "get_show_color",
      "sorted_layout_boxes",
  ]
  22. import numpy as np
  23. import copy
  24. import cv2
  25. import uuid
  26. from pathlib import Path
  27. from typing import Optional, Union, List, Tuple, Dict, Any
  28. from ..ocr.result import OCRResult
  29. from ...models.object_detection.result import DetResult
  30. from ..components import convert_points_to_boxes
  def get_overlap_boxes_idx(
      src_boxes: np.ndarray,
      ref_boxes: np.ndarray,
      overlap_threshold: float = 3,
  ) -> List:
      """
      Get the indices of source boxes that overlap with reference boxes.

      A source box and a reference box are considered overlapping when their
      intersection is strictly wider AND strictly taller than
      ``overlap_threshold``.

      Args:
          src_boxes (np.ndarray): Source boxes, shape (N, 4), each
              [x_min, y_min, x_max, y_max]. Nested lists are accepted and
              converted with ``np.asarray``.
          ref_boxes (np.ndarray): Reference boxes, shape (M, 4), same layout.
          overlap_threshold (float): Exclusive lower bound on the intersection
              width and height for a pair to count as overlapping.
              Defaults to 3.

      Returns:
          match_idx_list (list): Indices (plain ``int``) into ``src_boxes``.
              Results are grouped per reference box in input order, so a source
              box overlapping several reference boxes appears several times;
              callers needing unique indices must de-duplicate.
      """
      src_boxes = np.asarray(src_boxes)
      match_idx_list = []
      if len(src_boxes) == 0 or len(ref_boxes) == 0:
          return match_idx_list
      for ref_box in ref_boxes:
          # Intersection extents of ref_box with every source box at once.
          pub_w = np.minimum(ref_box[2], src_boxes[:, 2]) - np.maximum(
              ref_box[0], src_boxes[:, 0]
          )
          pub_h = np.minimum(ref_box[3], src_boxes[:, 3]) - np.maximum(
              ref_box[1], src_boxes[:, 1]
          )
          match_idx = np.where(
              (pub_w > overlap_threshold) & (pub_h > overlap_threshold)
          )[0]
          match_idx_list.extend(int(idx) for idx in match_idx)
      return match_idx_list
  def get_sub_regions_ocr_res(
      overall_ocr_res: OCRResult,
      object_boxes: List,
      flag_within: bool = True,
      return_match_idx: bool = False,
  ) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[int]]]:
      """
      Filter OCR results by overlap with a set of object boxes.

      A text box "matches" when ``get_overlap_boxes_idx`` reports it as
      overlapping at least one of ``object_boxes``.

      Args:
          overall_ocr_res (OCRResult): The original OCR result. The keys
              "rec_polys", "rec_texts", "rec_scores" and "rec_boxes" are read and
              indexed in parallel.
          object_boxes (list): Bounding boxes of the objects of interest.
          flag_within (bool): If True, keep only matching text boxes. If False,
              keep only non-matching text boxes.
          return_match_idx (bool): If True, also return the matching indices.

      Returns:
          dict: A new dict with the same four keys holding the kept entries.
              "rec_polys", "rec_scores" and "rec_boxes" are numpy arrays;
              "rec_texts" stays a list.
          If ``return_match_idx`` is True, a tuple ``(dict, match_idx_list)``
          is returned, where ``match_idx_list`` holds the unique matching
          indices in ascending order.
      """
      sub_regions_ocr_res = {
          "rec_polys": [],
          "rec_texts": [],
          "rec_scores": [],
          "rec_boxes": [],
      }
      overall_text_boxes = overall_ocr_res["rec_boxes"]
      # A text box may overlap several object boxes: keep each index once, in
      # ascending order, so the returned list is deterministic.
      match_idx_list = sorted(
          set(get_overlap_boxes_idx(overall_text_boxes, object_boxes))
      )
      matched = set(match_idx_list)  # O(1) membership checks in the loop below
      keep_matched = bool(flag_within)
      for box_no in range(len(overall_text_boxes)):
          # Keep matched boxes when flag_within is set, unmatched ones otherwise.
          if (box_no in matched) != keep_matched:
              continue
          for key, values in sub_regions_ocr_res.items():
              values.append(overall_ocr_res[key][box_no])
      for key in ("rec_polys", "rec_scores", "rec_boxes"):
          sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])
      if return_match_idx:
          return sub_regions_ocr_res, match_idx_list
      return sub_regions_ocr_res
  def sorted_layout_boxes(res, w):
      """
      Order layout blocks for reading: single-column blocks in place, and runs
      of two-column blocks as the whole left column followed by the right column.

      Each block dict receives a "layout" key ("single" or "double") in place.

      Args:
          res: List of dicts, each with a "layout_bbox" of
              [x_min, y_min, x_max, y_max].
          w: Width of the image.

      Returns:
          A new list with the blocks in reading order. When ``res`` holds
          exactly one block, that same list object is returned.
      """
      if len(res) == 1:
          res[0]["layout"] = "single"
          return res

      # Top-to-bottom, ties broken left-to-right.
      ordered = sorted(
          res, key=lambda block: (block["layout_bbox"][1], block["layout_bbox"][0])
      )
      output = []
      left_column = []
      right_column = []
      for block in ordered:
          bbox = block["layout_bbox"]
          if bbox[0] < w / 4 and bbox[2] < 3 * w / 5:
              # Starts near the left edge and ends before the middle: left column.
              block["layout"] = "double"
              left_column.append(block)
          elif bbox[0] > 2 * w / 5:
              # Starts past the middle: right column.
              block["layout"] = "double"
              right_column.append(block)
          else:
              # A full-width block flushes the pending columns before itself.
              output.extend(left_column)
              output.extend(right_column)
              block["layout"] = "single"
              output.append(block)
              left_column = []
              right_column = []

      output.extend(sorted(left_column, key=lambda block: block["layout_bbox"][1]))
      output.extend(sorted(right_column, key=lambda block: block["layout_bbox"][1]))
      return output
  def _calculate_overlap_area_div_minbox_area_ratio(
      bbox1: Union[list, tuple],
      bbox2: Union[list, tuple],
  ) -> float:
      """
      Return the intersection area of two boxes divided by the smaller box's area.

      All coordinates are truncated with ``int`` before computing.

      Args:
          bbox1 (list or tuple): First box as [x_min, y_min, x_max, y_max].
          bbox2 (list or tuple): Second box as [x_min, y_min, x_max, y_max].

      Returns:
          float: intersection / min(area1, area2), or 0.0 when the boxes do not
          intersect or the smaller area is not positive.
      """
      a = [int(v) for v in bbox1]
      b = [int(v) for v in bbox2]
      inter_w = min(a[2], b[2]) - max(a[0], b[0])
      inter_h = min(a[3], b[3]) - max(a[1], b[1])
      if inter_w <= 0 or inter_h <= 0:
          return 0.0
      area_a = (a[2] - a[0]) * (a[3] - a[1])
      area_b = (b[2] - b[0]) * (b[3] - b[1])
      smaller_area = min(area_a, area_b)
      return (inter_w * inter_h) / smaller_area if smaller_area > 0 else 0.0
  def _whether_y_overlap_exceeds_threshold(
      bbox1: Union[list, tuple],
      bbox2: Union[list, tuple],
      overlap_ratio_threshold: float = 0.6,
  ) -> bool:
      """
      Determine whether the vertical overlap between two boxes exceeds a threshold.

      Args:
          bbox1 (list or tuple): The first bounding box (left, top, right, bottom).
          bbox2 (list or tuple): The second bounding box (left, top, right, bottom).
          overlap_ratio_threshold (float): Ratio the overlap must strictly exceed.
              Defaults to 0.6.

      Returns:
          bool: True if the vertical overlap divided by the smaller of the two
          heights exceeds ``overlap_ratio_threshold``. False otherwise, including
          when either box has zero or negative height.
      """
      _, y1_0, _, y1_1 = bbox1
      _, y2_0, _, y2_1 = bbox2
      overlap = max(0, min(y1_1, y2_1) - max(y1_0, y2_0))
      min_height = min(y1_1 - y1_0, y2_1 - y2_0)
      # A degenerate (zero/negative height) box has no vertical extent to
      # compare against; treat it as non-overlapping instead of dividing by 0.
      if min_height <= 0:
          return False
      return (overlap / min_height) > overlap_ratio_threshold
  def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = False):
      """
      Add a leading and/or trailing newline to a span's text, in place.

      Args:
          span (list): Mutable sequence whose element at index 1 is the text.
          prepend (bool): If True, put "\\n" in front of the text.
          append (bool): If True, put "\\n" after the text.

      Returns:
          None. ``span`` is only written to when at least one flag is set.
      """
      if not (prepend or append):
          return None
      text = span[1]
      if prepend:
          text = "\n" + text
      if append:
          text = text + "\n"
      span[1] = text
  def _format_line(
      line: List[List[Union[List[int], str]]],
      layout_min: int,
      layout_max: int,
      is_reference: bool = False,
  ) -> None:
      """
      Mark paragraph boundaries on a line of spans by adding newlines to the
      text of its first and last spans.

      Regular lines get a leading newline when the first span starts more than
      10 units right of ``layout_min``. They get a trailing newline when the
      last span ends more than 10 units left of ``layout_max``.

      Reference lines get a leading newline when the first span starts less
      than 5 units from ``layout_min``. They get a trailing newline when the
      last span ends more than 20 units left of ``layout_max``.

      Args:
          line (list): Spans ``[bbox, text]`` ordered left to right; only bbox
              x-coordinates (index 0 and 2) are read.
          layout_min (int): Left x-bound to measure the first span against.
          layout_max (int): Right x-bound to measure the last span against.
          is_reference (bool): Selects the reference-line rules above.

      Returns:
          None: the spans in ``line`` are modified in place.
      """
      first_span, last_span = line[0], line[-1]
      lead_gap = first_span[0][0] - layout_min
      if is_reference:
          add_leading, trailing_limit = lead_gap < 5, 20
      else:
          add_leading, trailing_limit = lead_gap > 10, 10
      _adjust_span_text(first_span, prepend=add_leading)
      trail_gap = layout_max - last_span[0][2]
      _adjust_span_text(last_span, append=trail_gap > trailing_limit)
  def _sort_ocr_res_by_y_projection(
      label: Any,
      layout_bbox: Tuple[int, int, int, int],
      ocr_res: Dict[str, List[Any]],
      line_height_iou_threshold: float = 0.7,
  ) -> Dict[str, List[Any]]:
      """
      Sort OCR spans into reading order by grouping them into lines.

      Spans are sorted by top y. Consecutive spans whose vertical overlap
      exceeds ``line_height_iou_threshold`` join the same line, and each line is
      ordered left to right. ``_format_line`` then adds newlines to the
      first/last span text of each line, and every text gets a trailing space.

      Args:
          label (Any): Layout label of the block. When it equals "reference",
              line formatting is measured against the horizontal extent of the
              spans themselves (reference rules) instead of ``layout_bbox``.
          layout_bbox (Tuple[int, int, int, int]): (left, top, right, bottom) of
              the layout block; only left and right are used.
          ocr_res (Dict[str, List[Any]]): Dict with:
              - "boxes": span boxes, each [left, top, right, bottom].
              - "rec_texts": the recognized text of each box.
              This dict is updated in place and returned.
          line_height_iou_threshold (float): Vertical-overlap ratio above which
              two spans are placed on the same line. Defaults to 0.7.

      Returns:
          Dict[str, List[Any]]: ``ocr_res`` with "boxes" and "rec_texts" replaced
          by the reordered, formatted values.

      Raises:
          ValueError: If "boxes" or "rec_texts" is empty.
      """
      boxes = ocr_res["boxes"]
      rec_texts = ocr_res["rec_texts"]
      # len() rather than truthiness so numpy-array inputs are also accepted.
      if len(boxes) == 0 or len(rec_texts) == 0:
          raise ValueError("OCR results must contain 'boxes' and 'rec_texts'")

      is_reference = label == "reference"
      if is_reference:
          # References are laid out relative to the text itself, not the block.
          line_min = min(box[0] for box in boxes)
          line_max = max(box[2] for box in boxes)
      else:
          line_min, _, line_max, _ = layout_bbox

      # Top-to-bottom order; spans become lists so their text can be edited.
      spans = sorted(
          ([box, text] for box, text in zip(boxes, rec_texts)),
          key=lambda span: span[0][1],
      )

      lines = []
      current_line = [spans[0]]
      current_y0, current_y1 = spans[0][0][1], spans[0][0][3]
      for span in spans[1:]:
          y0, y1 = span[0][1], span[0][3]
          if _whether_y_overlap_exceeds_threshold(
              (0, current_y0, 0, current_y1),
              (0, y0, 0, y1),
              line_height_iou_threshold,
          ):
              # Same line: grow the line's vertical extent.
              current_line.append(span)
              current_y0 = min(current_y0, y0)
              current_y1 = max(current_y1, y1)
          else:
              lines.append(current_line)
              current_line = [span]
              current_y0, current_y1 = y0, y1
      lines.append(current_line)

      for line in lines:
          line.sort(key=lambda span: span[0][0])
          _format_line(line, line_min, line_max, is_reference=is_reference)

      # Flatten lines back into parallel box / text lists.
      ocr_res["boxes"] = [span[0] for line in lines for span in line]
      ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
      return ocr_res
  def get_single_block_parsing_res(
      overall_ocr_res: OCRResult,
      layout_det_res: DetResult,
      table_res_list: list,
      seal_res_list: list,
  ) -> list:
      """
      Build one structured parsing result per detected layout block.

      Args:
          overall_ocr_res (OCRResult): The overall OCR result. Keys read here:
              - "doc_preprocessor_res"["output_img"]: image that chart/image/seal
                crops are cut from.
              - "rec_boxes": text box coordinates [x_min, y_min, x_max, y_max].
              - "rec_texts": recognized text for each entry of "rec_boxes".
          layout_det_res (DetResult): Layout detection result whose "boxes"
              entry is a list of dicts with "coordinate" and "label".
          table_res_list (list): Table recognition results. Each item provides
              "cell_box_list" (its first cell box is matched against the table
              layout box) and "pred_html". NOTE: a matched item is removed from
              this list in place, so one result is never used twice.
          seal_res_list (list): Not used by this function.

      Returns:
          list: One dict per emitted block with keys:
              - "label": the layout label.
              - ``label`` itself: the table HTML for "table"; ``{"img": crop}``
                for "chart", "image" and "seal"; otherwise the concatenated
                block text ("$" removed for "formula").
              - "layout_bbox": the layout box coordinates.
              - "seg_start_flag": False when the first sorted text span starts
                less than 10 px from the block's left edge, else True.
              - "seg_end_flag": False when the last sorted text span ends less
                than 10 px from the block's right edge, else True.
          A "table" block that matches no table result produces no entry.
      """
      single_block_layout_parsing_res = []
      input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
      # Loop-invariant OCR data, read once.
      overall_text_boxes = overall_ocr_res["rec_boxes"]
      overall_rec_texts = overall_ocr_res["rec_texts"]

      for box_info in layout_det_res["boxes"]:
          layout_bbox = box_info["coordinate"]
          label = box_info["label"]
          seg_start_flag = True
          seg_end_flag = True

          if label == "table":
              for i, table_res in enumerate(table_res_list):
                  cell_box_list = table_res["cell_box_list"]
                  # A table result without cells cannot be matched (indexing
                  # [0] would raise IndexError), so skip it.
                  if len(cell_box_list) == 0:
                      continue
                  if (
                      _calculate_overlap_area_div_minbox_area_ratio(
                          layout_bbox, cell_box_list[0]
                      )
                      > 0.5
                  ):
                      single_block_layout_parsing_res.append(
                          {
                              "label": label,
                              f"{label}": table_res["pred_html"],
                              "layout_bbox": layout_bbox,
                              "seg_start_flag": seg_start_flag,
                              "seg_end_flag": seg_end_flag,
                          },
                      )
                      # Consume the match; safe while iterating because we
                      # break immediately afterwards.
                      del table_res_list[i]
                      break
              continue

          # Collect OCR spans that lie mostly (>50% of the smaller box) inside
          # this layout block.
          rec_res = {"boxes": [], "rec_texts": [], "flag": False}
          for box_no, text_box in enumerate(overall_text_boxes):
              if (
                  _calculate_overlap_area_div_minbox_area_ratio(layout_bbox, text_box)
                  > 0.5
              ):
                  rec_res["boxes"].append(text_box)
                  rec_res["rec_texts"].append(overall_rec_texts[box_no])
                  rec_res["flag"] = True

          if rec_res["flag"]:
              rec_res = _sort_ocr_res_by_y_projection(label, layout_bbox, rec_res, 0.7)
              rec_res_first_bbox = rec_res["boxes"][0]
              rec_res_end_bbox = rec_res["boxes"][-1]
              if rec_res_first_bbox[0] - layout_bbox[0] < 10:
                  seg_start_flag = False
              if layout_bbox[2] - rec_res_end_bbox[2] < 10:
                  seg_end_flag = False
              if label == "formula":
                  rec_res["rec_texts"] = [
                      rec_res_text.replace("$", "")
                      for rec_res_text in rec_res["rec_texts"]
                  ]

          if label in ["chart", "image", "seal"]:
              single_block_layout_parsing_res.append(
                  {
                      "label": label,
                      f"{label}": {
                          "img": input_img[
                              int(layout_bbox[1]) : int(layout_bbox[3]),
                              int(layout_bbox[0]) : int(layout_bbox[2]),
                          ],
                      },
                      "layout_bbox": layout_bbox,
                      "seg_start_flag": seg_start_flag,
                      "seg_end_flag": seg_end_flag,
                  },
              )
          else:
              single_block_layout_parsing_res.append(
                  {
                      "label": label,
                      f"{label}": "".join(rec_res["rec_texts"]),
                      "layout_bbox": layout_bbox,
                      "seg_start_flag": seg_start_flag,
                      "seg_end_flag": seg_end_flag,
                  },
              )
      return single_block_layout_parsing_res
  def _projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
      """
      Build a 1D coverage histogram of boxes along one axis.

      Each box adds 1 to every position in its half-open interval
      ``[min, max)`` on the chosen axis.

      Args:
          boxes: (N, 4) array of [x_min, y_min, x_max, y_max]. Coordinates are
              truncated to int and negative values are clipped to 0.
          axis: 0 to project onto the x-axis, 1 onto the y-axis.

      Returns:
          A 1D int array of length ``max`` coordinate on that axis; empty when
          ``boxes`` is empty.
      """
      assert axis in [0, 1]
      # (min, max) columns for the axis. Clip negatives: as slice bounds they
      # would wrap around to the end of the histogram and miscount.
      intervals = np.clip(np.asarray(boxes)[:, axis::2].astype(int), 0, None)
      if intervals.size == 0:
          return np.zeros(0, dtype=int)
      projection = np.zeros(int(intervals.max()), dtype=int)
      for start, end in intervals:
          projection[start:end] += 1
      return projection
  def _split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
      """
      Split a projection profile into segments separated by gaps.

      Args:
          arr_values: 1D projection profile.
          min_value: Positions with a value strictly greater than this are
              significant.
          min_gap: Two consecutive significant positions start a new segment
              when their index difference is strictly greater than this.

      Returns:
          ``(segment_starts, segment_ends)`` arrays describing half-open
          segments ``[start, end)``, or None when no position is significant.
      """
      significant_indices = np.where(arr_values > min_value)[0]
      if not len(significant_indices):
          return None
      index_diffs = np.diff(significant_indices)
      gap_indices = np.where(index_diffs > min_gap)[0]
      segment_starts = np.insert(
          significant_indices[gap_indices + 1],
          0,
          significant_indices[0],
      )
      # Every segment's last significant index + 1, so all ends are exclusive.
      # Previously only the final segment was +1; interior ends were inclusive,
      # which dropped boxes starting on a segment's last index and could yield
      # empty segments.
      segment_ends = (
          np.append(significant_indices[gap_indices], significant_indices[-1]) + 1
      )
      return segment_starts, segment_ends
  def _recursive_yx_cut(
      boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  ):
      """
      Recursively segment boxes, Y-axis first then X-axis, appending their
      original indices to ``res`` in reading order.

      Args:
          boxes: (N, 4) array of [x_min, y_min, x_max, y_max] with integer
              coordinates.
          indices: Original positions of ``boxes``; same length as ``boxes``.
          res: Output list; indices are appended in place.
          min_gap (int): Minimum gap (exclusive) between X-axis segments for
              them to be split. Y-axis splits always use a gap of 1. Passed on
              to recursive calls. Defaults to 1.

      Returns:
          None: ``res`` is modified in place.
      """
      assert len(boxes) == len(
          indices
      ), "The length of boxes and indices must be the same."
      # Sort by y_min for the Y-axis projection.
      y_order = boxes[:, 1].argsort()
      y_sorted_boxes = boxes[y_order]
      y_sorted_indices = np.array(indices)[y_order]
      y_projection = _projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
      y_intervals = _split_projection_profile(y_projection, 0, 1)
      if not y_intervals:
          return
      for y_start, y_end in zip(*y_intervals):
          # Boxes whose top edge falls in this horizontal band.
          in_band = (y_start <= y_sorted_boxes[:, 1]) & (y_sorted_boxes[:, 1] < y_end)
          y_boxes_chunk = y_sorted_boxes[in_band]
          y_indices_chunk = y_sorted_indices[in_band]
          # Sort by x_min for the X-axis projection.
          x_order = y_boxes_chunk[:, 0].argsort()
          x_sorted_boxes_chunk = y_boxes_chunk[x_order]
          x_sorted_indices_chunk = y_indices_chunk[x_order]
          x_projection = _projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
          x_intervals = _split_projection_profile(x_projection, 0, min_gap)
          if not x_intervals:
              continue
          # No further X split possible: emit the band left to right.
          if len(x_intervals[0]) == 1:
              res.extend(x_sorted_indices_chunk)
              continue
          for x_start, x_end in zip(*x_intervals):
              in_column = (x_start <= x_sorted_boxes_chunk[:, 0]) & (
                  x_sorted_boxes_chunk[:, 0] < x_end
              )
              # Propagate min_gap so it applies at every recursion level.
              _recursive_yx_cut(
                  x_sorted_boxes_chunk[in_column],
                  x_sorted_indices_chunk[in_column],
                  res,
                  min_gap,
              )
  def _recursive_xy_cut(
      boxes: np.ndarray, indices: List[int], res: List[int], min_gap: int = 1
  ):
      """
      Recursively segment boxes, X-axis first then Y-axis, appending their
      original indices to ``res`` in reading order.

      Args:
          boxes: (N, 4) array of [x_min, y_min, x_max, y_max] with integer
              coordinates.
          indices: Original positions of ``boxes``; same length as ``boxes``.
          res: Output list; indices are appended in place.
          min_gap (int): Minimum gap (exclusive) between Y-axis segments for
              them to be split. X-axis splits always use a gap of 1. Passed on
              to recursive calls. Defaults to 1.

      Returns:
          None: ``res`` is modified in place.
      """
      assert len(boxes) == len(
          indices
      ), "The length of boxes and indices must be the same."
      # Sort by x_min for the X-axis projection.
      x_order = boxes[:, 0].argsort()
      x_sorted_boxes = boxes[x_order]
      x_sorted_indices = np.array(indices)[x_order]
      x_projection = _projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
      x_intervals = _split_projection_profile(x_projection, 0, 1)
      if not x_intervals:
          return
      for x_start, x_end in zip(*x_intervals):
          # Boxes whose left edge falls in this vertical strip.
          in_column = (x_start <= x_sorted_boxes[:, 0]) & (x_sorted_boxes[:, 0] < x_end)
          x_boxes_chunk = x_sorted_boxes[in_column]
          x_indices_chunk = x_sorted_indices[in_column]
          # Sort by y_min for the Y-axis projection.
          y_order = x_boxes_chunk[:, 1].argsort()
          y_sorted_boxes_chunk = x_boxes_chunk[y_order]
          y_sorted_indices_chunk = x_indices_chunk[y_order]
          y_projection = _projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
          y_intervals = _split_projection_profile(y_projection, 0, min_gap)
          if not y_intervals:
              continue
          # No further Y split possible: emit the strip top to bottom.
          if len(y_intervals[0]) == 1:
              res.extend(y_sorted_indices_chunk)
              continue
          for y_start, y_end in zip(*y_intervals):
              in_band = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
                  y_sorted_boxes_chunk[:, 1] < y_end
              )
              # Propagate min_gap so it applies at every recursion level.
              _recursive_xy_cut(
                  y_sorted_boxes_chunk[in_band],
                  y_sorted_indices_chunk[in_band],
                  res,
                  min_gap,
              )
  def sort_by_xycut(
      block_bboxes: Union[np.ndarray, List[List[int]]],
      direction: int = 0,
      min_gap: int = 1,
  ) -> List[int]:
      """
      Order bounding boxes with the recursive XY-cut algorithm.

      Args:
          block_bboxes (Union[np.ndarray, List[List[int]]]): Boxes as
              [x_min, y_min, x_max, y_max]; coordinates are cast to int.
          direction (int): 1 to cut along the Y axis first, anything else to
              cut along the X axis first. Defaults to 0.
          min_gap (int): Minimum gap width to consider a separation between
              segments (applied to the secondary axis of each cut).
              Defaults to 1.

      Returns:
          List[int]: Indices into ``block_bboxes`` in sorted order; empty when
          no boxes are given.
      """
      block_bboxes = np.asarray(block_bboxes).astype(int)
      # Nothing to cut; the recursive helpers index columns and would fail.
      if block_bboxes.size == 0:
          return []
      res = []
      cut = _recursive_yx_cut if direction == 1 else _recursive_xy_cut
      cut(block_bboxes, np.arange(len(block_bboxes)), res, min_gap)
      return res
  610. def _img_array2path(data: np.ndarray, save_path: Union[str, Path]) -> str:
  611. """
  612. Save an image array to disk and return the relative file path.
  613. Args:
  614. data (np.ndarray): An image represented as a numpy array with 3 dimensions (H, W, C).
  615. save_path (Union[str, Path]): The base path where images should be saved.
  616. Returns:
  617. str: The relative path of the saved image file.
  618. Raises:
  619. ValueError: If the input data is not a valid image array.
  620. """
  621. if isinstance(data, np.ndarray) and data.ndim == 3:
  622. # Generate a unique filename using UUID
  623. img_name = f"image_{uuid.uuid4().hex}.png"
  624. img_path = Path(save_path) / "imgs" / img_name
  625. img_path.parent.mkdir(
  626. parents=True, exist_ok=True
  627. ) # Ensure the directory exists
  628. # Save the image using OpenCV
  629. success = cv2.imwrite(str(img_path), data)
  630. if not success:
  631. raise IOError(f"Failed to save image to {img_path}")
  632. return f"imgs/{img_name}"
  633. else:
  634. raise ValueError(
  635. "Input data must be a 3-dimensional numpy array representing an image."
  636. )
  637. def recursive_img_array2path(
  638. data: Union[Dict[str, Any], List[Any]],
  639. save_path: Union[str, Path],
  640. labels: List[str] = [],
  641. ) -> None:
  642. """
  643. Recursively process a dictionary or list to save image arrays to disk
  644. and replace them with file paths.
  645. Args:
  646. data (Union[Dict[str, Any], List[Any]]): The data structure that may contain image arrays.
  647. save_path (Union[str, Path]): The base path where images should be saved.
  648. labels (List[str]): List of keys to check for image arrays in dictionaries.
  649. Returns:
  650. None: This function modifies the input data structure in place.
  651. """
  652. if isinstance(data, dict):
  653. for k, v in data.items():
  654. if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
  655. data[k] = _img_array2path(v, save_path)
  656. else:
  657. recursive_img_array2path(v, save_path, labels)
  658. elif isinstance(data, list):
  659. for item in data:
  660. recursive_img_array2path(item, save_path, labels)
  661. def _get_minbox_if_overlap_by_ratio(
  662. bbox1: Union[List[int], Tuple[int, int, int, int]],
  663. bbox2: Union[List[int], Tuple[int, int, int, int]],
  664. ratio: float,
  665. smaller: bool = True,
  666. ) -> Optional[Union[List[int], Tuple[int, int, int, int]]]:
  667. """
  668. Determine if the overlap area between two bounding boxes exceeds a given ratio
  669. and return the smaller (or larger) bounding box based on the `smaller` flag.
  670. Args:
  671. bbox1 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  672. bbox2 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  673. ratio (float): The overlap ratio threshold.
  674. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  675. Returns:
  676. Optional[Union[List[int], Tuple[int, int, int, int]]]:
  677. The selected bounding box or None if the overlap ratio is not exceeded.
  678. """
  679. # Calculate the areas of both bounding boxes
  680. area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  681. area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  682. # Calculate the overlap ratio using a helper function
  683. overlap_ratio = _calculate_overlap_area_div_minbox_area_ratio(bbox1, bbox2)
  684. # Check if the overlap ratio exceeds the threshold
  685. if overlap_ratio > ratio:
  686. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  687. return 1
  688. else:
  689. return 2
  690. return None
  691. def _remove_overlap_blocks(
  692. blocks: List[Dict[str, List[int]]], threshold: float = 0.65, smaller: bool = True
  693. ) -> Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  694. """
  695. Remove overlapping blocks based on a specified overlap ratio threshold.
  696. Args:
  697. blocks (List[Dict[str, List[int]]]): List of block dictionaries, each containing a 'layout_bbox' key.
  698. threshold (float): Ratio threshold to determine significant overlap.
  699. smaller (bool): If True, the smaller block in overlap is removed.
  700. Returns:
  701. Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  702. A tuple containing the updated list of blocks and a list of dropped blocks.
  703. """
  704. dropped_blocks = []
  705. dropped_indexes = set()
  706. # Iterate over each pair of blocks to find overlaps
  707. for i, block1 in enumerate(blocks):
  708. for j in range(i + 1, len(blocks)):
  709. block2 = blocks[j]
  710. # Skip blocks that are already marked for removal
  711. if i in dropped_indexes or j in dropped_indexes:
  712. continue
  713. # Check for overlap and determine which block to remove
  714. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  715. block1["layout_bbox"],
  716. block2["layout_bbox"],
  717. threshold,
  718. smaller=smaller,
  719. )
  720. if overlap_box_index is not None:
  721. # Determine which block to remove based on overlap_box_index
  722. if overlap_box_index == 1:
  723. drop_index = i
  724. else:
  725. drop_index = j
  726. dropped_indexes.add(drop_index)
  727. # Remove marked blocks from the original list
  728. for index in sorted(dropped_indexes, reverse=True):
  729. dropped_blocks.append(blocks[index])
  730. del blocks[index]
  731. return blocks, dropped_blocks
  732. def _get_text_median_width(blocks: List[Dict[str, any]]) -> float:
  733. """
  734. Calculate the median width of blocks labeled as "text".
  735. Args:
  736. blocks (List[Dict[str, any]]): List of block dictionaries, each containing a 'layout_bbox' and 'label'.
  737. Returns:
  738. float: The median width of text blocks, or infinity if no text blocks are found.
  739. """
  740. widths = [
  741. block["layout_bbox"][2] - block["layout_bbox"][0]
  742. for block in blocks
  743. if block.get("label") == "text"
  744. ]
  745. return np.median(widths) if widths else float("inf")
  746. def _get_layout_property(
  747. blocks: List[Dict[str, any]],
  748. median_width: float,
  749. no_mask_labels: List[str],
  750. threshold: float = 0.8,
  751. ) -> Tuple[List[Dict[str, any]], bool]:
  752. """
  753. Determine the layout (single or double column) of text blocks.
  754. Args:
  755. blocks (List[Dict[str, any]]): List of block dictionaries containing 'label' and 'layout_bbox'.
  756. median_width (float): Median width of text blocks.
  757. no_mask_labels (List[str]): Labels of blocks to be considered for layout analysis.
  758. threshold (float): Threshold for determining layout overlap.
  759. Returns:
  760. Tuple[List[Dict[str, any]], bool]: Updated list of blocks with layout information and a boolean
  761. indicating if the double layout area is greater than the single layout area.
  762. """
  763. blocks.sort(
  764. key=lambda x: (
  765. x["layout_bbox"][0],
  766. (x["layout_bbox"][2] - x["layout_bbox"][0]),
  767. ),
  768. )
  769. check_single_layout = {}
  770. page_min_x, page_max_x = float("inf"), 0
  771. double_label_area = 0
  772. single_label_area = 0
  773. for i, block in enumerate(blocks):
  774. page_min_x = min(page_min_x, block["layout_bbox"][0])
  775. page_max_x = max(page_max_x, block["layout_bbox"][2])
  776. page_width = page_max_x - page_min_x
  777. for i, block in enumerate(blocks):
  778. if block["label"] not in no_mask_labels:
  779. continue
  780. x_min_i, _, x_max_i, _ = block["layout_bbox"]
  781. layout_length = x_max_i - x_min_i
  782. cover_count, cover_with_threshold_count = 0, 0
  783. match_block_with_threshold_indexes = []
  784. for j, other_block in enumerate(blocks):
  785. if i == j or other_block["label"] not in no_mask_labels:
  786. continue
  787. x_min_j, _, x_max_j, _ = other_block["layout_bbox"]
  788. x_match_min, x_match_max = max(
  789. x_min_i,
  790. x_min_j,
  791. ), min(x_max_i, x_max_j)
  792. match_block_iou = (x_match_max - x_match_min) / (x_max_j - x_min_j)
  793. if match_block_iou > 0:
  794. cover_count += 1
  795. if match_block_iou > threshold:
  796. cover_with_threshold_count += 1
  797. match_block_with_threshold_indexes.append(
  798. (j, match_block_iou),
  799. )
  800. x_min_i = x_match_max
  801. if x_min_i >= x_max_i:
  802. break
  803. if (
  804. layout_length > median_width * 1.3
  805. and (cover_with_threshold_count >= 2 or cover_count >= 2)
  806. ) or layout_length > 0.6 * page_width:
  807. # if layout_length > median_width * 1.3 and (cover_with_threshold_count >= 2):
  808. block["layout"] = "double"
  809. double_label_area += (block["layout_bbox"][2] - block["layout_bbox"][0]) * (
  810. block["layout_bbox"][3] - block["layout_bbox"][1]
  811. )
  812. else:
  813. block["layout"] = "single"
  814. check_single_layout[i] = match_block_with_threshold_indexes
  815. # Check single-layout block
  816. for i, single_layout in check_single_layout.items():
  817. if single_layout:
  818. index, match_iou = single_layout[-1]
  819. if match_iou > 0.9 and blocks[index]["layout"] == "double":
  820. blocks[i]["layout"] = "double"
  821. double_label_area += (
  822. blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
  823. ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
  824. else:
  825. single_label_area += (
  826. blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
  827. ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
  828. return blocks, (double_label_area > single_label_area)
  829. def _get_bbox_direction(input_bbox: List[float], ratio: float = 1.0) -> bool:
  830. """
  831. Determine if a bounding box is horizontal or vertical.
  832. Args:
  833. input_bbox (List[float]): Bounding box [x_min, y_min, x_max, y_max].
  834. ratio (float): Ratio for determining orientation. Default is 1.0.
  835. Returns:
  836. bool: True if the bounding box is considered horizontal, False if vertical.
  837. """
  838. width = input_bbox[2] - input_bbox[0]
  839. height = input_bbox[3] - input_bbox[1]
  840. return width * ratio >= height
  841. def _get_projection_iou(
  842. input_bbox: List[float], match_bbox: List[float], is_horizontal: bool = True
  843. ) -> float:
  844. """
  845. Calculate the IoU of lines between two bounding boxes.
  846. Args:
  847. input_bbox (List[float]): First bounding box [x_min, y_min, x_max, y_max].
  848. match_bbox (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
  849. is_horizontal (bool): Whether to compare horizontally or vertically.
  850. Returns:
  851. float: Line IoU. Returns 0 if there is no overlap.
  852. """
  853. if is_horizontal:
  854. x_match_min = max(input_bbox[0], match_bbox[0])
  855. x_match_max = min(input_bbox[2], match_bbox[2])
  856. overlap = max(0, x_match_max - x_match_min)
  857. input_width = input_bbox[2] - input_bbox[0]
  858. else:
  859. y_match_min = max(input_bbox[1], match_bbox[1])
  860. y_match_max = min(input_bbox[3], match_bbox[3])
  861. overlap = max(0, y_match_max - y_match_min)
  862. input_width = input_bbox[3] - input_bbox[1]
  863. return overlap / input_width if input_width > 0 else 0.0
  864. def _get_sub_category(
  865. blocks: List[Dict[str, Any]], title_labels: List[str]
  866. ) -> List[Dict[str, Any]]:
  867. """
  868. Determine the layout of title and text blocks.
Attach nearby neighbour blocks to title, paragraph-title and vision blocks.
Every block gets default "title_text", "sub_title" and "vision_footnote" lists
and a "sub_label" (initially its "label"). For each block labelled in
title_labels, "paragraph_title", or one of "image"/"table"/"chart"/"figure",
the closest sufficiently-overlapping neighbour is looked up on each side: above
and below for horizontal blocks, left and right for vertical ones. A matching
neighbour is then recorded as follows:
- a small "text" neighbour within edge distance 15 of a vision block gets
  sub_label "vision_footnote", and its bbox goes into "vision_footnote";
- a small "text" neighbour of a title block gets sub_label "title_text", and
  (position_code, bbox) goes into "title_text" (1=above, 2=below, 3=left,
  4=right);
- a "paragraph_title" neighbour of a paragraph_title block has its bbox added
  to "sub_title" (its sub_label is not changed).
  869. Args:
  870. blocks (List[Dict[str, Any]]): List of block dictionaries.
Each needs 'label' and 'layout_bbox' ([x_min, y_min, x_max, y_max]).
  871. title_labels (List[str]): List of labels considered as titles.
  872. Returns:
  873. List[Dict[str, Any]]: Updated list of blocks with title-text layout information.
The same list object is returned; blocks are modified in place.
  874. """
  875. sub_title_labels = ["paragraph_title"]
# NOTE(review): unlike get_layout_ordering's vision_labels, "seal" is not included here.
  876. vision_labels = ["image", "table", "chart", "figure"]
  877. for i, block1 in enumerate(blocks):
  878. block1.setdefault("title_text", [])
  879. block1.setdefault("sub_title", [])
  880. block1.setdefault("vision_footnote", [])
  881. block1.setdefault("sub_label", block1["label"])
# Only title, paragraph-title and vision blocks look for attached neighbours.
  882. if (
  883. block1["label"] not in title_labels
  884. and block1["label"] not in sub_title_labels
  885. and block1["label"] not in vision_labels
  886. ):
  887. continue
  888. bbox1 = block1["layout_bbox"]
  889. x1, y1, x2, y2 = bbox1
# Horizontal blocks search above/below; vertical blocks search left/right.
  890. is_horizontal_1 = _get_bbox_direction(block1["layout_bbox"])
  891. left_up_title_text_distance = float("inf")
  892. left_up_title_text_index = -1
  893. left_up_title_text_direction = None
  894. right_down_title_text_distance = float("inf")
  895. right_down_title_text_index = -1
  896. right_down_title_text_direction = None
  897. for j, block2 in enumerate(blocks):
  898. if i == j:
  899. continue
  900. bbox2 = block2["layout_bbox"]
  901. x1_prime, y1_prime, x2_prime, y2_prime = bbox2
  902. is_horizontal_2 = _get_bbox_direction(bbox2)
# Fraction of block2's extent (on block1's axis) that block1 covers.
  903. match_block_iou = _get_projection_iou(
  904. bbox2,
  905. bbox1,
  906. is_horizontal_1,
  907. )
# Gap between block1 and block2 along the search axis, bucketed into 5-px
# steps (+2 for rounding), plus block2's position / 5000 as a tie-breaker.
  908. def distance_(is_horizontal, is_left_up):
  909. if is_horizontal:
  910. if is_left_up:
  911. return (y1 - y2_prime + 2) // 5 + x1_prime / 5000
  912. else:
  913. return (y1_prime - y2 + 2) // 5 + x1_prime / 5000
  914. else:
  915. if is_left_up:
  916. return (x1 - x2_prime + 2) // 5 + y1_prime / 5000
  917. else:
  918. return (x1_prime - x2 + 2) // 5 + y1_prime / 5000
  919. block_iou_threshold = 0.1
# paragraph_title blocks use an area-overlap ratio (presumably overlap area /
# smaller box area, per the helper's name) with a stricter 0.7 threshold.
  920. if block1["label"] in sub_title_labels:
  921. match_block_iou = _calculate_overlap_area_div_minbox_area_ratio(
  922. bbox2,
  923. bbox1,
  924. )
  925. block_iou_threshold = 0.7
  926. if is_horizontal_1:
  927. if match_block_iou >= block_iou_threshold:
  928. left_up_distance = distance_(True, True)
  929. right_down_distance = distance_(True, False)
  930. if (
  931. y2_prime <= y1
  932. and left_up_distance <= left_up_title_text_distance
  933. ):
  934. left_up_title_text_distance = left_up_distance
  935. left_up_title_text_index = j
  936. left_up_title_text_direction = is_horizontal_2
  937. elif (
  938. y1_prime > y2
  939. and right_down_distance < right_down_title_text_distance
  940. ):
  941. right_down_title_text_distance = right_down_distance
  942. right_down_title_text_index = j
  943. right_down_title_text_direction = is_horizontal_2
  944. else:
  945. if match_block_iou >= block_iou_threshold:
  946. left_up_distance = distance_(False, True)
  947. right_down_distance = distance_(False, False)
  948. if (
  949. x2_prime <= x1
  950. and left_up_distance <= left_up_title_text_distance
  951. ):
  952. left_up_title_text_distance = left_up_distance
  953. left_up_title_text_index = j
  954. left_up_title_text_direction = is_horizontal_2
  955. elif (
  956. x1_prime > x2
  957. and right_down_distance < right_down_title_text_distance
  958. ):
  959. right_down_title_text_distance = right_down_distance
  960. right_down_title_text_index = j
  961. right_down_title_text_direction = is_horizontal_2
  962. height = bbox1[3] - bbox1[1]
  963. width = bbox1[2] - bbox1[0]
# Max size of a title_text neighbour relative to block1
# (index 0: height for horizontal blocks, index 1: width for vertical blocks).
  964. title_text_weight = [0.8, 0.8]
  965. title_text, sub_title, vision_footnote = [], [], []
  966. def get_sub_category_(
  967. title_text_direction,
  968. title_text_index,
  969. label,
  970. is_left_up=True,
  971. ):
# Position codes stored in title_text: 1 = above, 2 = below, 3 = left, 4 = right.
  972. direction_ = [1, 3] if is_left_up else [2, 4]
# The neighbour must exist, share block1's orientation, and be text or paragraph_title.
  973. if (
  974. title_text_direction == is_horizontal_1
  975. and title_text_index != -1
  976. and (label == "text" or label == "paragraph_title")
  977. ):
  978. bbox2 = blocks[title_text_index]["layout_bbox"]
  979. if is_horizontal_1:
  980. height1 = bbox2[3] - bbox2[1]
  981. width1 = bbox2[2] - bbox2[0]
  982. if label == "text":
  983. if (
  984. _nearest_edge_distance(bbox1, bbox2)[0] <= 15
  985. and block1["label"] in vision_labels
  986. and width1 < width
  987. and height1 < 0.5 * height
  988. ):
  989. blocks[title_text_index]["sub_label"] = "vision_footnote"
  990. vision_footnote.append(bbox2)
  991. elif (
  992. height1 < height * title_text_weight[0]
  993. and (width1 < width or width1 > 1.5 * width)
  994. and block1["label"] in title_labels
  995. ):
  996. blocks[title_text_index]["sub_label"] = "title_text"
  997. title_text.append((direction_[0], bbox2))
  998. elif (
  999. label == "paragraph_title"
  1000. and block1["label"] in sub_title_labels
  1001. ):
  1002. sub_title.append(bbox2)
  1003. else:
  1004. height1 = bbox2[3] - bbox2[1]
  1005. width1 = bbox2[2] - bbox2[0]
  1006. if label == "text":
  1007. if (
  1008. _nearest_edge_distance(bbox1, bbox2)[0] <= 15
  1009. and block1["label"] in vision_labels
  1010. and height1 < height
  1011. and width1 < 0.5 * width
  1012. ):
  1013. blocks[title_text_index]["sub_label"] = "vision_footnote"
  1014. vision_footnote.append(bbox2)
  1015. elif (
  1016. width1 < width * title_text_weight[1]
  1017. and block1["label"] in title_labels
  1018. ):
  1019. blocks[title_text_index]["sub_label"] = "title_text"
  1020. title_text.append((direction_[1], bbox2))
  1021. elif (
  1022. label == "paragraph_title"
  1023. and block1["label"] in sub_title_labels
  1024. ):
  1025. sub_title.append(bbox2)
# If one side is clearly closer (distance gap > 1/5 of block1's height or
# width), only that side is used; otherwise both sides are tried.
# NOTE: when no neighbour was found the index is -1, so blocks[-1]["label"] is
# read, but get_sub_category_ ignores it because it rejects index -1.
  1026. if (
  1027. is_horizontal_1
  1028. and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
  1029. > height
  1030. ) or (
  1031. not is_horizontal_1
  1032. and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
  1033. > width
  1034. ):
  1035. if left_up_title_text_distance < right_down_title_text_distance:
  1036. get_sub_category_(
  1037. left_up_title_text_direction,
  1038. left_up_title_text_index,
  1039. blocks[left_up_title_text_index]["label"],
  1040. True,
  1041. )
  1042. else:
  1043. get_sub_category_(
  1044. right_down_title_text_direction,
  1045. right_down_title_text_index,
  1046. blocks[right_down_title_text_index]["label"],
  1047. False,
  1048. )
  1049. else:
  1050. get_sub_category_(
  1051. left_up_title_text_direction,
  1052. left_up_title_text_index,
  1053. blocks[left_up_title_text_index]["label"],
  1054. True,
  1055. )
  1056. get_sub_category_(
  1057. right_down_title_text_direction,
  1058. right_down_title_text_index,
  1059. blocks[right_down_title_text_index]["label"],
  1060. False,
  1061. )
# Only fill a list the block does not already have (setdefault above leaves it empty).
  1062. if block1["label"] in title_labels:
  1063. if blocks[i].get("title_text") == []:
  1064. blocks[i]["title_text"] = title_text
  1065. if block1["label"] in sub_title_labels:
  1066. if blocks[i].get("sub_title") == []:
  1067. blocks[i]["sub_title"] = sub_title
  1068. if block1["label"] in vision_labels:
  1069. if blocks[i].get("vision_footnote") == []:
  1070. blocks[i]["vision_footnote"] = vision_footnote
  1071. return blocks
  1072. def get_layout_ordering(
  1073. data: Dict[str, Any],
  1074. no_mask_labels: List[str] = [],
  1075. already_sorted: bool = False,
  1076. ) -> Dict[str, Any]:
  1077. """
  1078. Process layout parsing results to remove overlapping bounding boxes
  1079. and assign an ordering index based on their positions.
  1080. Modifies:
  1081. The 'data' list by adding an 'index' to each block.
More precisely: data["sub_blocks"] is pruned, reordered and extended in place.
Overlapping blocks are removed (threshold 0.5, the smaller one is dropped), and
each remaining block gets 1-based "index" and "sub_index" keys, plus the
"title_text"/"sub_title"/"vision_footnote"/"sub_label"/"layout" keys added by
the helpers.
  1082. Args:
  1083. data (List[Dict[str, Any]]): List of block dictionaries with 'layout_bbox' and 'label'.
Despite the line above, data is indexed as data["sub_blocks"], so it must be a
mapping whose "sub_blocks" entry is that list of block dicts.
  1084. no_mask_labels (List[str]): Labels for which overlapping removal is not performed.
Correction: overlap removal applies to every label. Blocks whose sub_label is in
no_mask_labels are the ones classified single/double by _get_layout_property:
"single" blocks are ordered by XY-cut, and "double" ones are merged in afterwards.
  1085. already_sorted (bool): Assumes data is already sorted by position if True.
If True, data is returned immediately without any processing.
Returns:
The same data object that was passed in.
  1086. """
# NOTE(review): no_mask_labels has a mutable default ([]); it is not mutated in
# the visible code of this function.
  1087. if already_sorted:
  1088. return data
  1089. title_text_labels = ["doc_title"]
  1090. title_labels = ["doc_title", "paragraph_title"]
  1091. vision_labels = ["image", "table", "seal", "chart", "figure"]
  1092. vision_title_labels = ["table_title", "chart_title", "figure_title"]
# parsing_result is the same list object as data["sub_blocks"]; the steps below
# mutate it in place.
  1093. parsing_result = data["sub_blocks"]
  1094. parsing_result, _ = _remove_overlap_blocks(
  1095. parsing_result,
  1096. threshold=0.5,
  1097. smaller=True,
  1098. )
  1099. parsing_result = _get_sub_category(parsing_result, title_text_labels)
  1100. doc_flag = False
  1101. median_width = _get_text_median_width(parsing_result)
  1102. parsing_result, projection_direction = _get_layout_property(
  1103. parsing_result,
  1104. median_width,
  1105. no_mask_labels=no_mask_labels,
  1106. threshold=0.3,
  1107. )
  1108. # Convert bounding boxes to int and bucket blocks by sub_label (bucketed blocks are re-inserted below)
  1109. (
  1110. double_text_blocks,
  1111. title_text_blocks,
  1112. title_blocks,
  1113. vision_blocks,
  1114. vision_title_blocks,
  1115. vision_footnote_blocks,
  1116. other_blocks,
  1117. ) = ([], [], [], [], [], [], [])
  1118. drop_indexes = []
  1119. for index, block in enumerate(parsing_result):
  1120. label = block["sub_label"]
  1121. block["layout_bbox"] = list(map(int, block["layout_bbox"]))
  1122. if label == "doc_title":
  1123. doc_flag = True
# no_mask_labels blocks with "single" layout stay in parsing_result.
  1124. if label in no_mask_labels:
  1125. if block["layout"] == "double":
  1126. double_text_blocks.append(block)
  1127. drop_indexes.append(index)
  1128. elif label == "title_text":
  1129. title_text_blocks.append(block)
  1130. drop_indexes.append(index)
  1131. elif label == "vision_footnote":
  1132. vision_footnote_blocks.append(block)
  1133. drop_indexes.append(index)
  1134. elif label in vision_title_labels:
  1135. vision_title_blocks.append(block)
  1136. drop_indexes.append(index)
  1137. elif label in title_labels:
  1138. title_blocks.append(block)
  1139. drop_indexes.append(index)
  1140. elif label in vision_labels:
  1141. vision_blocks.append(block)
  1142. drop_indexes.append(index)
  1143. else:
  1144. other_blocks.append(block)
  1145. drop_indexes.append(index)
  1146. for index in sorted(drop_indexes, reverse=True):
  1147. del parsing_result[index]
# Order the remaining blocks with XY-cut. When double-layout text dominates, the
# title and double blocks are folded in first and a Y-first cut is used.
  1148. if len(parsing_result) > 0:
  1149. # single text label
  1150. if len(double_text_blocks) > len(parsing_result) or projection_direction:
  1151. parsing_result.extend(title_blocks + double_text_blocks)
  1152. title_blocks = []
  1153. double_text_blocks = []
  1154. block_bboxes = [block["layout_bbox"] for block in parsing_result]
  1155. block_bboxes.sort(
  1156. key=lambda x: (
  1157. x[0] // max(20, median_width),
  1158. x[1],
  1159. ),
  1160. )
  1161. block_bboxes = np.array(block_bboxes)
  1162. sorted_indices = sort_by_xycut(
  1163. block_bboxes,
  1164. direction=1,
  1165. min_gap=1,
  1166. )
  1167. else:
  1168. block_bboxes = [block["layout_bbox"] for block in parsing_result]
  1169. block_bboxes.sort(key=lambda x: (x[0] // 20, x[1]))
  1170. block_bboxes = np.array(block_bboxes)
  1171. sorted_indices = sort_by_xycut(
  1172. block_bboxes,
  1173. direction=0,
  1174. min_gap=20,
  1175. )
# NOTE(review): list.index returns the first match, so blocks with identical
# bboxes receive the same index.
  1176. sorted_boxes = block_bboxes[sorted_indices].tolist()
  1177. for block in parsing_result:
  1178. block["index"] = sorted_boxes.index(block["layout_bbox"]) + 1
  1179. block["sub_index"] = sorted_boxes.index(block["layout_bbox"]) + 1
# Give each input block the index (or sub_index if is_add_index is False) of the
# closest block already in parsing_result under distance_type, then append it.
# Appended blocks become match candidates for the blocks processed after them.
  1180. def nearest_match_(input_blocks, distance_type="manhattan", is_add_index=True):
  1181. for block in input_blocks:
  1182. bbox = block["layout_bbox"]
  1183. min_distance = float("inf")
  1184. min_distance_config = [
  1185. [float("inf"), float("inf")],
  1186. float("inf"),
  1187. float("inf"),
  1188. ] # for double text
# Stays 0 if no candidate has a finite distance (e.g. parsing_result is empty).
  1189. nearest_gt_index = 0
  1190. for match_block in parsing_result:
  1191. match_bbox = match_block["layout_bbox"]
  1192. if distance_type == "nearest_iou_edge_distance":
  1193. distance, min_distance_config = _nearest_iou_edge_distance(
  1194. bbox,
  1195. match_bbox,
  1196. block["sub_label"],
  1197. vision_labels=vision_labels,
  1198. no_mask_labels=no_mask_labels,
  1199. median_width=median_width,
  1200. title_labels=title_labels,
  1201. title_text=block["title_text"],
  1202. sub_title=block["sub_title"],
  1203. min_distance_config=min_distance_config,
  1204. tolerance_len=10,
  1205. )
# Only title blocks that recorded title_text are candidates; the distance is
# discounted by how much this block overlaps their first/last title_text bbox.
  1206. elif distance_type == "title_text":
  1207. if (
  1208. match_block["label"] in title_labels + ["abstract"]
  1209. and match_block["title_text"] != []
  1210. ):
  1211. iou_left_up = _calculate_overlap_area_div_minbox_area_ratio(
  1212. bbox,
  1213. match_block["title_text"][0][1],
  1214. )
  1215. iou_right_down = _calculate_overlap_area_div_minbox_area_ratio(
  1216. bbox,
  1217. match_block["title_text"][-1][1],
  1218. )
  1219. iou = 1 - max(iou_left_up, iou_right_down)
  1220. distance = _manhattan_distance(bbox, match_bbox) * iou
  1221. else:
  1222. distance = float("inf")
  1223. elif distance_type == "manhattan":
  1224. distance = _manhattan_distance(bbox, match_bbox)
  1225. elif distance_type == "vision_footnote":
  1226. if (
  1227. match_block["label"] in vision_labels
  1228. and match_block["vision_footnote"] != []
  1229. ):
  1230. iou_left_up = _calculate_overlap_area_div_minbox_area_ratio(
  1231. bbox,
  1232. match_block["vision_footnote"][0],
  1233. )
  1234. iou_right_down = _calculate_overlap_area_div_minbox_area_ratio(
  1235. bbox,
  1236. match_block["vision_footnote"][-1],
  1237. )
  1238. iou = 1 - max(iou_left_up, iou_right_down)
  1239. distance = _manhattan_distance(bbox, match_bbox) * iou
  1240. else:
  1241. distance = float("inf")
# NOTE: "vision_body" is not requested by any nearest_match_ call in this function.
  1242. elif distance_type == "vision_body":
  1243. if (
  1244. match_block["label"] in vision_title_labels
  1245. and block["vision_footnote"] != []
  1246. ):
  1247. iou_left_up = _calculate_overlap_area_div_minbox_area_ratio(
  1248. match_bbox,
  1249. block["vision_footnote"][0],
  1250. )
  1251. iou_right_down = _calculate_overlap_area_div_minbox_area_ratio(
  1252. match_bbox,
  1253. block["vision_footnote"][-1],
  1254. )
  1255. iou = 1 - max(iou_left_up, iou_right_down)
  1256. distance = _manhattan_distance(bbox, match_bbox) * iou
  1257. else:
  1258. distance = float("inf")
  1259. else:
  1260. raise NotImplementedError
  1261. if distance < min_distance:
  1262. min_distance = distance
  1263. if is_add_index:
  1264. nearest_gt_index = match_block.get("index", 999)
  1265. else:
  1266. nearest_gt_index = match_block.get("sub_index", 999)
  1267. if is_add_index:
  1268. block["index"] = nearest_gt_index
  1269. else:
  1270. block["sub_index"] = nearest_gt_index
  1271. parsing_result.append(block)
  1272. # double text label
  1273. double_text_blocks.sort(
  1274. key=lambda x: (
  1275. x["layout_bbox"][1] // 10,
  1276. x["layout_bbox"][0] // median_width,
  1277. x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
  1278. ),
  1279. )
  1280. nearest_match_(
  1281. double_text_blocks,
  1282. distance_type="nearest_iou_edge_distance",
  1283. )
  1284. parsing_result.sort(
  1285. key=lambda x: (x["index"], x["layout_bbox"][1], x["layout_bbox"][0]),
  1286. )
  1287. for idx, block in enumerate(parsing_result):
  1288. block["index"] = idx + 1
  1289. block["sub_index"] = idx + 1
  1290. # title label
  1291. title_blocks.sort(
  1292. key=lambda x: (
  1293. x["layout_bbox"][1] // 10,
  1294. x["layout_bbox"][0] // median_width,
  1295. x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
  1296. ),
  1297. )
  1298. nearest_match_(title_blocks, distance_type="nearest_iou_edge_distance")
# Force the top-most doc_title (smallest y, then x) to index 1. Ties at index 1
# are broken by label priority, so it sorts first.
  1299. if doc_flag:
  1300. text_sort_labels = ["doc_title"]
  1301. text_label_priority = {
  1302. label: priority for priority, label in enumerate(text_sort_labels)
  1303. }
  1304. doc_titles = []
  1305. for i, block in enumerate(parsing_result):
  1306. if block["label"] == "doc_title":
  1307. doc_titles.append(
  1308. (i, block["layout_bbox"][1], block["layout_bbox"][0]),
  1309. )
  1310. doc_titles.sort(key=lambda x: (x[1], x[2]))
  1311. first_doc_title_index = doc_titles[0][0]
  1312. parsing_result[first_doc_title_index]["index"] = 1
  1313. parsing_result.sort(
  1314. key=lambda x: (
  1315. x["index"],
  1316. text_label_priority.get(x["label"], 9999),
  1317. x["layout_bbox"][1],
  1318. x["layout_bbox"][0],
  1319. ),
  1320. )
  1321. else:
  1322. parsing_result.sort(
  1323. key=lambda x: (
  1324. x["index"],
  1325. x["layout_bbox"][1],
  1326. x["layout_bbox"][0],
  1327. ),
  1328. )
  1329. for idx, block in enumerate(parsing_result):
  1330. block["index"] = idx + 1
  1331. block["sub_index"] = idx + 1
  1332. # title-text label
  1333. nearest_match_(title_text_blocks, distance_type="title_text")
  1334. text_sort_labels = ["doc_title", "paragraph_title", "title_text"]
  1335. text_label_priority = {
  1336. label: priority for priority, label in enumerate(text_sort_labels)
  1337. }
  1338. parsing_result.sort(
  1339. key=lambda x: (
  1340. x["index"],
  1341. text_label_priority.get(x["sub_label"], 9999),
  1342. x["layout_bbox"][1],
  1343. x["layout_bbox"][0],
  1344. ),
  1345. )
  1346. for idx, block in enumerate(parsing_result):
  1347. block["index"] = idx + 1
  1348. block["sub_index"] = idx + 1
  1349. # vision title labels (table_title, chart_title, figure_title); only sub_index is assigned from here on
  1350. nearest_match_(
  1351. vision_title_blocks,
  1352. distance_type="nearest_iou_edge_distance",
  1353. is_add_index=False,
  1354. )
  1355. parsing_result.sort(
  1356. key=lambda x: (
  1357. x["sub_index"],
  1358. x["layout_bbox"][1],
  1359. x["layout_bbox"][0],
  1360. ),
  1361. )
  1362. for idx, block in enumerate(parsing_result):
  1363. block["sub_index"] = idx + 1
  1364. # image,figure,chart,seal label
  1365. nearest_match_(
  1366. vision_blocks,
  1367. distance_type="nearest_iou_edge_distance",
  1368. is_add_index=False,
  1369. )
  1370. parsing_result.sort(
  1371. key=lambda x: (
  1372. x["sub_index"],
  1373. x["layout_bbox"][1],
  1374. x["layout_bbox"][0],
  1375. ),
  1376. )
  1377. for idx, block in enumerate(parsing_result):
  1378. block["sub_index"] = idx + 1
  1379. # vision footnote label
  1380. nearest_match_(
  1381. vision_footnote_blocks,
  1382. distance_type="vision_footnote",
  1383. is_add_index=False,
  1384. )
# Footnotes sort after any other block sharing the same sub_index.
  1385. text_label_priority = {"vision_footnote": 9999}
  1386. parsing_result.sort(
  1387. key=lambda x: (
  1388. x["sub_index"],
  1389. text_label_priority.get(x["sub_label"], 0),
  1390. x["layout_bbox"][1],
  1391. x["layout_bbox"][0],
  1392. ),
  1393. )
  1394. for idx, block in enumerate(parsing_result):
  1395. block["sub_index"] = idx + 1
  1396. # header、footnote、header_image... label
  1397. nearest_match_(other_blocks, distance_type="manhattan", is_add_index=False)
  1398. return data
  1399. def _manhattan_distance(
  1400. point1: Tuple[float, float],
  1401. point2: Tuple[float, float],
  1402. weight_x: float = 1.0,
  1403. weight_y: float = 1.0,
  1404. ) -> float:
  1405. """
  1406. Calculate the weighted Manhattan distance between two points.
  1407. Args:
  1408. point1 (Tuple[float, float]): The first point as (x, y).
  1409. point2 (Tuple[float, float]): The second point as (x, y).
  1410. weight_x (float): The weight for the x-axis distance. Default is 1.0.
  1411. weight_y (float): The weight for the y-axis distance. Default is 1.0.
  1412. Returns:
  1413. float: The weighted Manhattan distance between the two points.
  1414. """
  1415. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
  1416. def _calculate_horizontal_distance(
  1417. input_bbox: List[int],
  1418. match_bbox: List[int],
  1419. height: int,
  1420. disperse: int,
  1421. title_text: List[Tuple[int, List[int]]],
  1422. ) -> float:
  1423. """
  1424. Calculate the horizontal distance between two bounding boxes, considering title text adjustments.
  1425. Args:
  1426. input_bbox (List[int]): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1427. match_bbox (List[int]): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1428. height (int): The height of the input bounding box used for normalization.
  1429. disperse (int): The dispersion factor used to normalize the horizontal distance.
  1430. title_text (List[Tuple[int, List[int]]]): A list of tuples containing title text information and their bounding box coordinates.
  1431. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1432. Returns:
  1433. float: The calculated horizontal distance taking into account the title text adjustments.
  1434. """
  1435. x1, y1, x2, y2 = input_bbox
  1436. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1437. # Determine vertical distance adjustment based on title text
  1438. if y2 < y1_prime:
  1439. if title_text and title_text[-1][0] == 2:
  1440. y2 += title_text[-1][1][3] - title_text[-1][1][1]
  1441. vertical_adjustment = (y1_prime - y2) * 0.5
  1442. else:
  1443. if title_text and title_text[0][0] == 1:
  1444. y1 -= title_text[0][1][3] - title_text[0][1][1]
  1445. vertical_adjustment = y1 - y2_prime
  1446. # Calculate horizontal distance with adjustments
  1447. horizontal_distance = (
  1448. abs(x2_prime - x1) // disperse
  1449. + vertical_adjustment // height
  1450. + vertical_adjustment / 5000
  1451. )
  1452. return horizontal_distance
  1453. def _calculate_vertical_distance(
  1454. input_bbox: List[int],
  1455. match_bbox: List[int],
  1456. width: int,
  1457. disperse: int,
  1458. title_text: List[Tuple[int, List[int]]],
  1459. ) -> float:
  1460. """
  1461. Calculate the vertical distance between two bounding boxes, considering title text adjustments.
  1462. Args:
  1463. input_bbox (List[int]): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1464. match_bbox (List[int]): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1465. width (int): The width of the input bounding box used for normalization.
  1466. disperse (int): The dispersion factor used to normalize the vertical distance.
  1467. title_text (List[Tuple[int, List[int]]]): A list of tuples containing title text information and their bounding box coordinates.
  1468. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1469. Returns:
  1470. float: The calculated vertical distance taking into account the title text adjustments.
  1471. """
  1472. x1, y1, x2, y2 = input_bbox
  1473. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1474. # Determine horizontal distance adjustment based on title text
  1475. if x1 > x2_prime:
  1476. if title_text and title_text[0][0] == 3:
  1477. x1 -= title_text[0][1][2] - title_text[0][1][0]
  1478. horizontal_adjustment = (x1 - x2_prime) * 0.5
  1479. else:
  1480. if title_text and title_text[-1][0] == 4:
  1481. x2 += title_text[-1][1][2] - title_text[-1][1][0]
  1482. horizontal_adjustment = x1_prime - x2
  1483. # Calculate vertical distance with adjustments
  1484. vertical_distance = (
  1485. abs(y2_prime - y1) // disperse
  1486. + horizontal_adjustment // width
  1487. + horizontal_adjustment / 5000
  1488. )
  1489. return vertical_distance
def _nearest_edge_distance(
    input_bbox: List[int],
    match_bbox: List[int],
    weight: List[float] = [1.0, 1.0, 1.0, 1.0],
    label: str = "text",
    no_mask_labels: List[str] = [],
    min_edge_distances_config: List[float] = [],
    tolerance_len: float = 10.0,
) -> Tuple[float, List[float]]:
    """
    Calculate the weighted nearest-edge distance between two bounding boxes.

    The gap between the boxes is measured separately on the x and y axes, and
    each gap is scaled by the directional weight for the side on which
    ``input_bbox`` lies relative to ``match_bbox``. Overlapping boxes
    short-circuit to distance 0 unless ``label`` is in ``no_mask_labels``.

    NOTE: the list defaults are only read or rebound in this function, never
    mutated in place.

    Args:
        input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
        match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
        weight (list, optional): Directional weights indexed by where input_bbox lies relative
            to match_bbox: [left of, right of, above, below]. Defaults to [1.0, 1.0, 1.0, 1.0].
        label (str, optional): The label/type of the input object (e.g., 'text'). Defaults to 'text'.
        no_mask_labels (list, optional): Labels that are NOT short-circuited to 0 when the boxes
            overlap. For these labels, the "input above match" distance is also inflated
            (floored at 0.1, then multiplied by 100). Defaults to an empty list.
        min_edge_distances_config (list, optional): Running minimum gaps
            [min_edge_distance_x, min_edge_distance_y]. An empty list is treated as
            [float('inf'), float('inf')]. Defaults to [].
        tolerance_len (float, optional): A gap within this many units of the running minimum on
            its axis is replaced by that minimum before weighting. Defaults to 10.

    Returns:
        Tuple[float, List[float]]: A tuple containing:
            - The weighted distance: the sum of the x and y terms when the boxes are separated
              on both axes, otherwise the smaller of the two. An axis without a gap contributes
              inf, or 0 when the boxes overlap (only reachable for no_mask_labels).
            - The raw, unweighted gaps [x_gap, y_gap], using the same inf / 0 conventions.
    """
    # Overlap ratio from a helper defined elsewhere in this module (by its name,
    # overlap area divided by the smaller box's area; not verified here).
    match_bbox_iou = _calculate_overlap_area_div_minbox_area_ratio(
        input_bbox,
        match_bbox,
    )
    # Overlapping boxes count as touching, except for no_mask_labels, which fall
    # through to the per-axis computation below.
    if match_bbox_iou > 0 and label not in no_mask_labels:
        return 0, [0, 0]
    if not min_edge_distances_config:
        min_edge_distances_config = [float("inf"), float("inf")]
    min_edge_distance_x, min_edge_distance_y = min_edge_distances_config
    x1, y1, x2, y2 = input_bbox
    x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
    # Number of axes (0, 1 or 2) on which the boxes are separated by a gap.
    direction_num = 0
    distance_x = float("inf")
    distance_y = float("inf")
    # Raw gaps indexed like `weight`: [input left of, right of, above, below] match.
    distance = [float("inf")] * 4
    # input_bbox is to the left of match_bbox
    if x2 < x1_prime:
        direction_num += 1
        distance[0] = x1_prime - x2
        # Gaps within tolerance_len of the running minimum use that minimum, so
        # near-equal candidates score the same on this axis.
        if abs(distance[0] - min_edge_distance_x) <= tolerance_len:
            distance_x = min_edge_distance_x * weight[0]
        else:
            distance_x = distance[0] * weight[0]
    # input_bbox is to the right of match_bbox
    elif x1 > x2_prime:
        direction_num += 1
        distance[1] = x1 - x2_prime
        if abs(distance[1] - min_edge_distance_x) <= tolerance_len:
            distance_x = min_edge_distance_x * weight[1]
        else:
            distance_x = distance[1] * weight[1]
    # x projections overlap and the boxes overlap (no_mask_labels only): zero x gap.
    elif match_bbox_iou > 0:
        distance[0] = 0
        distance_x = 0
    # input_bbox is above match_bbox
    if y2 < y1_prime:
        direction_num += 1
        distance[2] = y1_prime - y2
        if abs(distance[2] - min_edge_distance_y) <= tolerance_len:
            distance_y = min_edge_distance_y * weight[2]
        else:
            distance_y = distance[2] * weight[2]
        # Heavily penalize no_mask_labels sitting above the candidate.
        if label in no_mask_labels:
            distance_y = max(0.1, distance_y) * 100
    # input_bbox is below match_bbox
    elif y1 > y2_prime:
        direction_num += 1
        distance[3] = y1 - y2_prime
        if abs(distance[3] - min_edge_distance_y) <= tolerance_len:
            distance_y = min_edge_distance_y * weight[3]
        else:
            distance_y = distance[3] * weight[3]
    # y projections overlap and the boxes overlap (no_mask_labels only): zero y gap.
    elif match_bbox_iou > 0:
        distance[2] = 0
        distance_y = 0
    # Diagonal separation (gap on both axes) adds the terms; otherwise keep the
    # smaller one (the axis without a gap is inf unless the boxes overlap).
    if direction_num == 2:
        return (distance_x + distance_y), [
            min(distance[0], distance[1]),
            min(distance[2], distance[3]),
        ]
    else:
        return min(distance_x, distance_y), [
            min(distance[0], distance[1]),
            min(distance[2], distance[3]),
        ]
  1580. def _get_weights(label, horizontal):
  1581. """Define weights based on the label and orientation."""
  1582. if label == "doc_title":
  1583. return (
  1584. [1, 0.1, 0.1, 1] if horizontal else [0.2, 0.1, 1, 1]
  1585. ) # left-down , right-left
  1586. elif label in [
  1587. "paragraph_title",
  1588. "abstract",
  1589. "figure_title",
  1590. "chart_title",
  1591. "image",
  1592. "seal",
  1593. "chart",
  1594. "figure",
  1595. ]:
  1596. return [1, 1, 0.1, 1] # down
  1597. else:
  1598. return [1, 1, 1, 0.1] # up
  1599. def _nearest_iou_edge_distance(
  1600. input_bbox: List[int],
  1601. match_bbox: List[int],
  1602. label: str,
  1603. vision_labels: List[str],
  1604. no_mask_labels: List[str],
  1605. median_width: int = -1,
  1606. title_labels: List[str] = [],
  1607. title_text: List[Tuple[int, List[int]]] = [],
  1608. sub_title: List[List[int]] = [],
  1609. min_distance_config: List[float] = [],
  1610. tolerance_len: float = 10.0,
  1611. ) -> Tuple[float, List[float]]:
  1612. """
  1613. Calculate the nearest IOU edge distance between two bounding boxes, considering label types, title adjustments, and minimum distance configurations.
  1614. This function computes the edge distance between two bounding boxes while considering their overlap (IOU) and various adjustments based on label types,
  1615. title text, and subtitle information. It also applies minimum distance configurations and tolerance adjustments.
  1616. Args:
  1617. input_bbox (List[int]): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1618. match_bbox (List[int]): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1619. label (str): The label/type of the object in the bounding box (e.g., 'image', 'text', etc.).
  1620. vision_labels (List[str]): List of labels for vision-related objects (e.g., images, icons).
  1621. no_mask_labels (List[str]): Labels for which no masking is applied when calculating edge distances.
  1622. median_width (int, optional): The median width for title dispersion calculation. Defaults to -1.
  1623. title_labels (List[str], optional): Labels that indicate the object is a title. Defaults to an empty list.
  1624. title_text (List[Tuple[int, List[int]]], optional): Text content associated with title labels, in the format [(position_indicator, [x1, y1, x2, y2]), ...].
  1625. sub_title (List[List[int]], optional): List of subtitle bounding boxes to adjust the input_bbox. Defaults to an empty list.
  1626. min_distance_config (List[float], optional): Configuration for minimum distances [min_edge_distances_config, up_edge_distances_config, total_distance].
  1627. tolerance_len (float, optional): The tolerance length for adjusting edge distances. Defaults to 10.0.
  1628. Returns:
  1629. Tuple[float, List[float]]: A tuple containing:
  1630. - The calculated distance considering IOU and adjustments.
  1631. - The updated minimum distance configuration.
  1632. """
  1633. x1, y1, x2, y2 = input_bbox
  1634. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1635. min_edge_distances_config, up_edge_distances_config, total_distance = (
  1636. min_distance_config
  1637. )
  1638. iou_distance = 0
  1639. if label in vision_labels:
  1640. horizontal1 = horizontal2 = True
  1641. else:
  1642. horizontal1 = _get_bbox_direction(input_bbox)
  1643. horizontal2 = _get_bbox_direction(match_bbox, 3)
  1644. if (
  1645. horizontal1 != horizontal2
  1646. or _get_projection_iou(input_bbox, match_bbox, horizontal1) < 0.01
  1647. ):
  1648. iou_distance = 1
  1649. elif label == "doc_title" or (label in title_labels and title_text):
  1650. # Calculate distance for titles
  1651. disperse = max(1, median_width)
  1652. width = x2 - x1
  1653. height = y2 - y1
  1654. if horizontal1:
  1655. return (
  1656. _calculate_horizontal_distance(
  1657. input_bbox,
  1658. match_bbox,
  1659. height,
  1660. disperse,
  1661. title_text,
  1662. ),
  1663. min_distance_config,
  1664. )
  1665. else:
  1666. return (
  1667. _calculate_vertical_distance(
  1668. input_bbox,
  1669. match_bbox,
  1670. width,
  1671. disperse,
  1672. title_text,
  1673. ),
  1674. min_distance_config,
  1675. )
  1676. # Adjust input_bbox based on sub_title
  1677. if sub_title:
  1678. for sub in sub_title:
  1679. x1_, y1_, x2_, y2_ = sub
  1680. x1, y1, x2, y2 = (
  1681. min(x1, x1_),
  1682. min(
  1683. y1,
  1684. y1_,
  1685. ),
  1686. max(x2, x2_),
  1687. max(y2, y2_),
  1688. )
  1689. input_bbox = [x1, y1, x2, y2]
  1690. # Calculate edge distance
  1691. weight = _get_weights(label, horizontal1)
  1692. if label == "abstract":
  1693. tolerance_len *= 3
  1694. edge_distance, edge_distance_config = _nearest_edge_distance(
  1695. input_bbox,
  1696. match_bbox,
  1697. weight,
  1698. label=label,
  1699. no_mask_labels=no_mask_labels,
  1700. min_edge_distances_config=min_edge_distances_config,
  1701. tolerance_len=tolerance_len,
  1702. )
  1703. # Weights for combining distances
  1704. iou_edge_weight = [10**6, 10**3, 1, 0.001]
  1705. # Calculate up and left edge distances
  1706. up_edge_distance = y1_prime
  1707. left_edge_distance = x1_prime
  1708. if (
  1709. label in no_mask_labels or label == "paragraph_title" or label in vision_labels
  1710. ) and y1 > y2_prime:
  1711. up_edge_distance = -y2_prime
  1712. left_edge_distance = -x2_prime
  1713. min_up_edge_distance = up_edge_distances_config
  1714. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  1715. up_edge_distance = min_up_edge_distance
  1716. # Calculate total distance
  1717. distance = (
  1718. iou_distance * iou_edge_weight[0]
  1719. + edge_distance * iou_edge_weight[1]
  1720. + up_edge_distance * iou_edge_weight[2]
  1721. + left_edge_distance * iou_edge_weight[3]
  1722. )
  1723. # Update minimum distance configuration if a smaller distance is found
  1724. if total_distance > distance:
  1725. edge_distance_config = [
  1726. min(min_edge_distances_config[0], edge_distance_config[0]),
  1727. min(min_edge_distances_config[1], edge_distance_config[1]),
  1728. ]
  1729. min_distance_config = [
  1730. edge_distance_config,
  1731. min(up_edge_distance, up_edge_distances_config),
  1732. distance,
  1733. ]
  1734. return distance, min_distance_config
  1735. def get_show_color(label: str) -> Tuple:
  1736. label_colors = {
  1737. # Medium Blue (from 'titles_list')
  1738. "paragraph_title": (102, 102, 255, 100),
  1739. "doc_title": (255, 248, 220, 100), # Cornsilk
  1740. # Light Yellow (from 'tables_caption_list')
  1741. "table_title": (255, 255, 102, 100),
  1742. # Sky Blue (from 'imgs_caption_list')
  1743. "figure_title": (102, 178, 255, 100),
  1744. "chart_title": (221, 160, 221, 100), # Plum
  1745. "vision_footnote": (144, 238, 144, 100), # Light Green
  1746. # Deep Purple (from 'texts_list')
  1747. "text": (153, 0, 76, 100),
  1748. # Bright Green (from 'interequations_list')
  1749. "formula": (0, 255, 0, 100),
  1750. "abstract": (255, 239, 213, 100), # Papaya Whip
  1751. # Medium Green (from 'lists_list' and 'indexs_list')
  1752. "content": (40, 169, 92, 100),
  1753. # Neutral Gray (from 'dropped_bbox_list')
  1754. "seal": (158, 158, 158, 100),
  1755. # Olive Yellow (from 'tables_body_list')
  1756. "table": (204, 204, 0, 100),
  1757. # Bright Green (from 'imgs_body_list')
  1758. "image": (153, 255, 51, 100),
  1759. # Bright Green (from 'imgs_body_list')
  1760. "figure": (153, 255, 51, 100),
  1761. "chart": (216, 191, 216, 100), # Thistle
  1762. # Pale Yellow-Green (from 'tables_footnote_list')
  1763. "reference": (229, 255, 204, 100),
  1764. "algorithm": (255, 250, 240, 100), # Floral White
  1765. }
  1766. default_color = (158, 158, 158, 100)
  1767. return label_colors.get(label, default_color)