utils.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Names exported by a star-import of this module.
__all__ = [
    "get_sub_regions_ocr_res",
    "get_layout_ordering",
    "recursive_img_array2path",
    "get_show_color",
]
  20. import numpy as np
  21. import copy
  22. import cv2
  23. import uuid
  24. from pathlib import Path
  25. from typing import List
  26. from ..ocr.result import OCRResult
  27. from ...models_new.object_detection.result import DetResult
  28. def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
  29. """
  30. Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
  31. Args:
  32. src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
  33. ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
  34. Returns:
  35. list: A list of indices of source boxes that overlap with any reference box.
  36. """
  37. match_idx_list = []
  38. src_boxes_num = len(src_boxes)
  39. if src_boxes_num > 0 and len(ref_boxes) > 0:
  40. for rno in range(len(ref_boxes)):
  41. ref_box = ref_boxes[rno]
  42. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  43. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  44. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  45. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  46. pub_w = x2 - x1
  47. pub_h = y2 - y1
  48. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  49. match_idx_list.extend(match_idx)
  50. return match_idx_list
  51. def get_sub_regions_ocr_res(
  52. overall_ocr_res: OCRResult, object_boxes: List, flag_within: bool = True
  53. ) -> OCRResult:
  54. """
  55. Filters OCR results to only include text boxes within specified object boxes based on a flag.
  56. Args:
  57. overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
  58. object_boxes (list): A list of bounding boxes for the objects of interest.
  59. flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
  60. Returns:
  61. OCRResult: A filtered OCR result containing only the relevant text boxes.
  62. """
  63. sub_regions_ocr_res = {}
  64. sub_regions_ocr_res["rec_polys"] = []
  65. sub_regions_ocr_res["rec_texts"] = []
  66. sub_regions_ocr_res["rec_scores"] = []
  67. sub_regions_ocr_res["rec_boxes"] = []
  68. overall_text_boxes = overall_ocr_res["rec_boxes"]
  69. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  70. match_idx_list = list(set(match_idx_list))
  71. for box_no in range(len(overall_text_boxes)):
  72. if flag_within:
  73. if box_no in match_idx_list:
  74. flag_match = True
  75. else:
  76. flag_match = False
  77. else:
  78. if box_no not in match_idx_list:
  79. flag_match = True
  80. else:
  81. flag_match = False
  82. if flag_match:
  83. sub_regions_ocr_res["rec_polys"].append(
  84. overall_ocr_res["rec_polys"][box_no]
  85. )
  86. sub_regions_ocr_res["rec_texts"].append(
  87. overall_ocr_res["rec_texts"][box_no]
  88. )
  89. sub_regions_ocr_res["rec_scores"].append(
  90. overall_ocr_res["rec_scores"][box_no]
  91. )
  92. sub_regions_ocr_res["rec_boxes"].append(
  93. overall_ocr_res["rec_boxes"][box_no]
  94. )
  95. return sub_regions_ocr_res
  96. def calculate_iou(box1, box2):
  97. """
  98. Calculate Intersection over Union (IoU) between two bounding boxes.
  99. Args:
  100. box1, box2: Lists or tuples representing bounding boxes [x_min, y_min, x_max, y_max].
  101. Returns:
  102. float: The IoU value.
  103. """
  104. box1 = list(map(int, box1))
  105. box2 = list(map(int, box2))
  106. x1_min, y1_min, x1_max, y1_max = box1
  107. x2_min, y2_min, x2_max, y2_max = box2
  108. inter_x_min = max(x1_min, x2_min)
  109. inter_y_min = max(y1_min, y2_min)
  110. inter_x_max = min(x1_max, x2_max)
  111. inter_y_max = min(y1_max, y2_max)
  112. if inter_x_max <= inter_x_min or inter_y_max <= inter_y_min:
  113. return 0.0
  114. inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min)
  115. box1_area = (x1_max - x1_min) * (y1_max - y1_min)
  116. box2_area = (x2_max - x2_min) * (y2_max - y2_min)
  117. min_area = min(box1_area, box2_area)
  118. if min_area <= 0:
  119. return 0.0
  120. iou = inter_area / min_area
  121. return iou
  122. def _whether_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.6):
  123. _, y0_1, _, y1_1 = bbox1
  124. _, y0_2, _, y1_2 = bbox2
  125. overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
  126. min_height = min(y1_1 - y0_1, y1_2 - y0_2)
  127. return (overlap / min_height) > overlap_ratio_threshold
  128. def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_threshold=0.7):
  129. assert ocr_res["boxes"] and ocr_res["rec_texts"]
  130. # span->line->block
  131. boxes = ocr_res["boxes"]
  132. rec_text = ocr_res["rec_texts"]
  133. x_min, x_max = layout_bbox[0], layout_bbox[2]
  134. spans = list(zip(boxes, rec_text))
  135. spans.sort(key=lambda span: span[0][1])
  136. spans = [list(span) for span in spans]
  137. lines = []
  138. first_span = spans[0]
  139. current_line = [first_span]
  140. current_y0, current_y1 = first_span[0][1], first_span[0][3]
  141. for span in spans[1:]:
  142. y0, y1 = span[0][1], span[0][3]
  143. if _whether_overlaps_y_exceeds_threshold(
  144. (0, current_y0, 0, current_y1),
  145. (0, y0, 0, y1),
  146. line_height_threshold,
  147. ):
  148. current_line.append(span)
  149. current_y0 = min(current_y0, y0)
  150. current_y1 = max(current_y1, y1)
  151. else:
  152. lines.append(current_line)
  153. current_line = [span]
  154. current_y0, current_y1 = y0, y1
  155. if current_line:
  156. lines.append(current_line)
  157. for line in lines:
  158. line.sort(key=lambda span: span[0][0])
  159. first_span = line[0]
  160. end_span = line[-1]
  161. if first_span[0][0] - x_min > 20:
  162. first_span[1] = "\n" + first_span[1]
  163. if x_max - end_span[0][2] > 20:
  164. end_span[1] = end_span[1] + "\n"
  165. ocr_res["boxes"] = [span[0] for line in lines for span in line]
  166. ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
  167. return ocr_res
def get_structure_res(
    overall_ocr_res: OCRResult,
    layout_det_res: DetResult,
    table_res_list,
) -> List[dict]:
    """
    Extract structured information from OCR and layout detection results.

    NOTE(review): the nesting below was reconstructed from a listing that had
    lost its indentation; confirm the if/else structure against the original
    file before relying on it.

    Args:
        overall_ocr_res (OCRResult): The overall OCR results. The entries read here are:
            - "doc_preprocessor_res"["output_img"]: The image that chart/image regions are cropped from.
            - "rec_boxes": The text boxes, compared against each layout box.
            - "rec_texts": The recognized text corresponding to "rec_boxes".
        layout_det_res (DetResult): The layout detection results. The structure is expected to have:
            - "boxes": A list of dictionaries with keys "coordinate" for box coordinates and "label" for the type of content.
        table_res_list (list): A list of table recognition results, where each item is a dictionary containing:
            - "cell_box_list": Its first entry is compared against the layout box.
            - "pred_html": The predicted HTML representation of the table.
            Entries matched to a table layout box are deleted from this list in place.

    Returns:
        list: A list of structured boxes where each item is a dictionary containing:
            - "label": The label of the content (e.g., 'table', 'chart', 'image').
            - The label as a key with either table HTML, a dict holding the cropped "img", or the joined text.
            - "layout_bbox": The coordinates of the layout box.
            - "seg_start_flag" / "seg_end_flag": Booleans, cleared when the first / last text box lies within 20 units of the left / right edge of the layout box.
    """
    structure_boxes = []
    input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
    for box_info in layout_det_res["boxes"]:
        layout_bbox = box_info["coordinate"]
        label = box_info["label"]
        rec_res = {"boxes": [], "rec_texts": [], "flag": False}
        seg_start_flag = True
        seg_end_flag = True
        if label == "table":
            # Use the first table result that overlaps this layout box; a
            # table layout box with no matching result produces no output.
            for i, table_res in enumerate(table_res_list):
                # NOTE(review): cell_box_list[0] is used as the table's box - confirm.
                if calculate_iou(layout_bbox, table_res["cell_box_list"][0]) > 0.5:
                    structure_boxes.append(
                        {
                            "label": label,
                            f"{label}": table_res["pred_html"],
                            "layout_bbox": layout_bbox,
                            "seg_start_flag": seg_start_flag,
                            "seg_end_flag": seg_end_flag,
                        },
                    )
                    # Consume the match so it cannot be reused; safe because
                    # the loop exits immediately afterwards.
                    del table_res_list[i]
                    break
        else:
            # Gather every text box that mostly lies inside this layout box.
            overall_text_boxes = overall_ocr_res["rec_boxes"]
            for box_no in range(len(overall_text_boxes)):
                if calculate_iou(layout_bbox, overall_text_boxes[box_no]) > 0.5:
                    rec_res["boxes"].append(overall_text_boxes[box_no])
                    rec_res["rec_texts"].append(
                        overall_ocr_res["rec_texts"][box_no],
                    )
                    rec_res["flag"] = True
            if rec_res["flag"]:
                # Put the collected spans into reading order.
                rec_res = _sort_box_by_y_projection(layout_bbox, rec_res, 0.7)
                rec_res_first_bbox = rec_res["boxes"][0]
                rec_res_end_bbox = rec_res["boxes"][-1]
                # Text reaching the left / right edge clears the matching flag.
                if rec_res_first_bbox[0] - layout_bbox[0] < 20:
                    seg_start_flag = False
                if layout_bbox[2] - rec_res_end_bbox[2] < 20:
                    seg_end_flag = False
                if label == "formula":
                    # Strip the "$" characters from formula text.
                    rec_res["rec_texts"] = [
                        rec_res_text.replace("$", "")
                        for rec_res_text in rec_res["rec_texts"]
                    ]
            if label in ["chart", "image"]:
                # Store the image crop of the layout box instead of text.
                structure_boxes.append(
                    {
                        "label": label,
                        f"{label}": {
                            "img": input_img[
                                int(layout_bbox[1]) : int(layout_bbox[3]),
                                int(layout_bbox[0]) : int(layout_bbox[2]),
                            ],
                        },
                        "layout_bbox": layout_bbox,
                        "seg_start_flag": seg_start_flag,
                        "seg_end_flag": seg_end_flag,
                    },
                )
            else:
                structure_boxes.append(
                    {
                        "label": label,
                        f"{label}": "".join(rec_res["rec_texts"]),
                        "layout_bbox": layout_bbox,
                        "seg_start_flag": seg_start_flag,
                        "seg_end_flag": seg_end_flag,
                    },
                )
    return structure_boxes
  261. def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
  262. """
  263. Generate a 1D projection histogram from bounding boxes along a specified axis.
  264. Args:
  265. boxes: A (N, 4) array of bounding boxes defined by [x_min, y_min, x_max, y_max].
  266. axis: Axis for projection; 0 for horizontal (x-axis), 1 for vertical (y-axis).
  267. Returns:
  268. A 1D numpy array representing the projection histogram based on bounding box intervals.
  269. """
  270. assert axis in [0, 1]
  271. max_length = np.max(boxes[:, axis::2])
  272. projection = np.zeros(max_length, dtype=int)
  273. # Increment projection histogram over the interval defined by each bounding box
  274. for start, end in boxes[:, axis::2]:
  275. projection[start:end] += 1
  276. return projection
  277. def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
  278. """
  279. Split the projection profile into segments based on specified thresholds.
  280. Args:
  281. arr_values: 1D array representing the projection profile.
  282. min_value: Minimum value threshold to consider a profile segment significant.
  283. min_gap: Minimum gap width to consider a separation between segments.
  284. Returns:
  285. A tuple of start and end indices for each segment that meets the criteria.
  286. """
  287. # Identify indices where the projection exceeds the minimum value
  288. significant_indices = np.where(arr_values > min_value)[0]
  289. if not len(significant_indices):
  290. return
  291. # Calculate gaps between significant indices
  292. index_diffs = significant_indices[1:] - significant_indices[:-1]
  293. gap_indices = np.where(index_diffs > min_gap)[0]
  294. # Determine start and end indices of segments
  295. segment_starts = np.insert(
  296. significant_indices[gap_indices + 1],
  297. 0,
  298. significant_indices[0],
  299. )
  300. segment_ends = np.append(
  301. significant_indices[gap_indices],
  302. significant_indices[-1] + 1,
  303. )
  304. return segment_starts, segment_ends
  305. def recursive_yx_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_gap=1):
  306. """
  307. Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.
  308. Args:
  309. boxes: A (N, 4) array representing bounding boxes.
  310. indices: List of indices indicating the original position of boxes.
  311. res: List to store indices of the final segmented bounding boxes.
  312. """
  313. assert len(boxes) == len(indices)
  314. # Sort by y_min for Y-axis projection
  315. y_sorted_indices = boxes[:, 1].argsort()
  316. y_sorted_boxes = boxes[y_sorted_indices]
  317. y_sorted_indices = np.array(indices)[y_sorted_indices]
  318. # Perform Y-axis projection
  319. y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
  320. y_intervals = split_projection_profile(y_projection, 0, 1)
  321. if not y_intervals:
  322. return
  323. # Process each segment defined by Y-axis projection
  324. for y_start, y_end in zip(*y_intervals):
  325. # Select boxes within the current y interval
  326. y_interval_indices = (y_start <= y_sorted_boxes[:, 1]) & (
  327. y_sorted_boxes[:, 1] < y_end
  328. )
  329. y_boxes_chunk = y_sorted_boxes[y_interval_indices]
  330. y_indices_chunk = y_sorted_indices[y_interval_indices]
  331. # Sort by x_min for X-axis projection
  332. x_sorted_indices = y_boxes_chunk[:, 0].argsort()
  333. x_sorted_boxes_chunk = y_boxes_chunk[x_sorted_indices]
  334. x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]
  335. # Perform X-axis projection
  336. x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
  337. x_intervals = split_projection_profile(x_projection, 0, min_gap)
  338. if not x_intervals:
  339. continue
  340. # If X-axis cannot be further segmented, add current indices to results
  341. if len(x_intervals[0]) == 1:
  342. res.extend(x_sorted_indices_chunk)
  343. continue
  344. # Recursively process each segment defined by X-axis projection
  345. for x_start, x_end in zip(*x_intervals):
  346. x_interval_indices = (x_start <= x_sorted_boxes_chunk[:, 0]) & (
  347. x_sorted_boxes_chunk[:, 0] < x_end
  348. )
  349. recursive_yx_cut(
  350. x_sorted_boxes_chunk[x_interval_indices],
  351. x_sorted_indices_chunk[x_interval_indices],
  352. res,
  353. )
  354. def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int], min_gap=1):
  355. """
  356. Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.
  357. Args:
  358. boxes: A (N, 4) array representing bounding boxes with [x_min, y_min, x_max, y_max].
  359. indices: A list of indices representing the position of boxes in the original data.
  360. res: A list to store indices of bounding boxes that meet the criteria.
  361. """
  362. # Ensure boxes and indices have the same length
  363. assert len(boxes) == len(indices)
  364. # Sort by x_min to prepare for X-axis projection
  365. x_sorted_indices = boxes[:, 0].argsort()
  366. x_sorted_boxes = boxes[x_sorted_indices]
  367. x_sorted_indices = np.array(indices)[x_sorted_indices]
  368. # Perform X-axis projection
  369. x_projection = projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
  370. x_intervals = split_projection_profile(x_projection, 0, 1)
  371. if not x_intervals:
  372. return
  373. # Process each segment defined by X-axis projection
  374. for x_start, x_end in zip(*x_intervals):
  375. # Select boxes within the current x interval
  376. x_interval_indices = (x_start <= x_sorted_boxes[:, 0]) & (
  377. x_sorted_boxes[:, 0] < x_end
  378. )
  379. x_boxes_chunk = x_sorted_boxes[x_interval_indices]
  380. x_indices_chunk = x_sorted_indices[x_interval_indices]
  381. # Sort selected boxes by y_min to prepare for Y-axis projection
  382. y_sorted_indices = x_boxes_chunk[:, 1].argsort()
  383. y_sorted_boxes_chunk = x_boxes_chunk[y_sorted_indices]
  384. y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]
  385. # Perform Y-axis projection
  386. y_projection = projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
  387. y_intervals = split_projection_profile(y_projection, 0, min_gap)
  388. if not y_intervals:
  389. continue
  390. # If Y-axis cannot be further segmented, add current indices to results
  391. if len(y_intervals[0]) == 1:
  392. res.extend(y_sorted_indices_chunk)
  393. continue
  394. # Recursively process each segment defined by Y-axis projection
  395. for y_start, y_end in zip(*y_intervals):
  396. y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
  397. y_sorted_boxes_chunk[:, 1] < y_end
  398. )
  399. recursive_xy_cut(
  400. y_sorted_boxes_chunk[y_interval_indices],
  401. y_sorted_indices_chunk[y_interval_indices],
  402. res,
  403. )
  404. def sort_by_xycut(block_bboxes, direction=0, min_gap=1):
  405. block_bboxes = np.asarray(block_bboxes).astype(int)
  406. res = []
  407. if direction == 1:
  408. recursive_yx_cut(
  409. block_bboxes,
  410. np.arange(
  411. len(block_bboxes),
  412. ),
  413. res,
  414. min_gap,
  415. )
  416. else:
  417. recursive_xy_cut(
  418. block_bboxes,
  419. np.arange(
  420. len(block_bboxes),
  421. ),
  422. res,
  423. min_gap,
  424. )
  425. return res
  426. def _img_array2path(data, save_path):
  427. """
  428. Save an image array to disk and return the file path.
  429. Args:
  430. data (np.ndarray): An image represented as a numpy array.
  431. save_path (str or Path): The base path where images should be saved.
  432. Returns:
  433. str: The relative path of the saved image file.
  434. """
  435. if isinstance(data, np.ndarray) and data.ndim == 3:
  436. # Generate a unique filename using UUID
  437. img_name = f"image_{uuid.uuid4().hex}.png"
  438. img_path = Path(save_path) / "imgs" / img_name
  439. img_path.parent.mkdir(
  440. parents=True,
  441. exist_ok=True,
  442. ) # Ensure the directory exists
  443. cv2.imwrite(str(img_path), data)
  444. return f"imgs/{img_name}"
  445. else:
  446. return ValueError
  447. def recursive_img_array2path(data, save_path, labels=[]):
  448. """
  449. Process a dictionary or list to save image arrays to disk and replace them with file paths.
  450. Args:
  451. data (dict or list): The data structure that may contain image arrays.
  452. save_path (str or Path): The base path where images should be saved.
  453. """
  454. if isinstance(data, dict):
  455. for k, v in data.items():
  456. if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
  457. data[k] = _img_array2path(v, save_path)
  458. else:
  459. recursive_img_array2path(v, save_path, labels)
  460. elif isinstance(data, list):
  461. for item in data:
  462. recursive_img_array2path(item, save_path, labels)
  463. def _calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
  464. """
  465. Calculate the ratio of the overlap area between bbox1 and bbox2
  466. to the area of the smaller bounding box.
  467. Args:
  468. bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  469. bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  470. Returns:
  471. float: The ratio of the overlap area to the area of the smaller bounding box.
  472. """
  473. x_left = max(bbox1[0], bbox2[0])
  474. y_top = max(bbox1[1], bbox2[1])
  475. x_right = min(bbox1[2], bbox2[2])
  476. y_bottom = min(bbox1[3], bbox2[3])
  477. if x_right <= x_left or y_bottom <= y_top:
  478. return 0.0
  479. # Calculate the area of the overlap
  480. intersection_area = (x_right - x_left) * (y_bottom - y_top)
  481. # Calculate the areas of both bounding boxes
  482. area_bbox1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  483. area_bbox2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  484. # Determine the minimum non-zero box area
  485. min_box_area = min(area_bbox1, area_bbox2)
  486. # Avoid division by zero in case of zero-area boxes
  487. if min_box_area == 0:
  488. return 0.0
  489. return intersection_area / min_box_area
  490. def _get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio, smaller=True):
  491. """
  492. Determine if the overlap area between two bounding boxes exceeds a given ratio
  493. and return the smaller (or larger) bounding box based on the `smaller` flag.
  494. Args:
  495. bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  496. bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  497. ratio (float): The overlap ratio threshold.
  498. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  499. Returns:
  500. list or tuple: The selected bounding box or None if the overlap ratio is not exceeded.
  501. """
  502. # Calculate the areas of both bounding boxes
  503. area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  504. area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  505. # Calculate the overlap ratio using a helper function
  506. overlap_ratio = _calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
  507. # Check if the overlap ratio exceeds the threshold
  508. if overlap_ratio > ratio:
  509. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  510. return 1
  511. else:
  512. return 2
  513. return None
  514. def _remove_overlap_blocks(blocks, threshold=0.65, smaller=True):
  515. """
  516. Remove overlapping blocks based on a specified overlap ratio threshold.
  517. Args:
  518. blocks (list): List of block dictionaries, each containing a 'layout_bbox' key.
  519. threshold (float): Ratio threshold to determine significant overlap.
  520. smaller (bool): If True, the smaller block in overlap is removed.
  521. Returns:
  522. tuple: A tuple containing the updated list of blocks and a list of dropped blocks.
  523. """
  524. dropped_blocks = []
  525. dropped_indexes = []
  526. # Iterate over each pair of blocks to find overlaps
  527. for i in range(len(blocks)):
  528. block1 = blocks[i]
  529. for j in range(i + 1, len(blocks)):
  530. block2 = blocks[j]
  531. # Skip blocks that are already marked for removal
  532. if i in dropped_indexes or j in dropped_indexes:
  533. continue
  534. # Check for overlap and determine which block to remove
  535. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  536. block1["layout_bbox"],
  537. block2["layout_bbox"],
  538. threshold,
  539. smaller=smaller,
  540. )
  541. if overlap_box_index is not None:
  542. if overlap_box_index == 1:
  543. block_to_remove = block1
  544. drop_index = i
  545. else:
  546. block_to_remove = block2
  547. drop_index = j
  548. if drop_index not in dropped_indexes:
  549. dropped_indexes.append(drop_index)
  550. dropped_blocks.append(block_to_remove)
  551. dropped_indexes.sort()
  552. for i in reversed(dropped_indexes):
  553. del blocks[i]
  554. return blocks, dropped_blocks
  555. def _text_median_width(blocks):
  556. widths = [
  557. block["layout_bbox"][2] - block["layout_bbox"][0]
  558. for block in blocks
  559. if block["label"] in ["text"]
  560. ]
  561. return np.median(widths) if widths else float("inf")
def _get_layout_property(blocks, median_width, no_mask_labels, threshold=0.8):
    """
    Tag each block as "single" or "double" layout and summarise the page.

    NOTE(review): the nesting below was reconstructed from a listing that had
    lost its indentation; confirm it against the original file.

    Args:
        blocks (list): List of block dictionaries containing 'label' and 'layout_bbox'.
            The list is sorted in place, and blocks whose label is in
            `no_mask_labels` get a 'layout' key.
        median_width (float): Median width of text blocks.
        no_mask_labels: Labels of the blocks that take part in the analysis;
            all other blocks are skipped.
        threshold (float): Horizontal overlap ratio above which another block
            counts as a strong match.

    Returns:
        tuple: The sorted list of blocks, and a bool that is True when the
            accumulated area of "double" blocks exceeds the accumulated area
            of the "single" blocks that have at least one strong match.
    """
    # Order blocks by left edge, then by width.
    blocks.sort(
        key=lambda x: (
            x["layout_bbox"][0],
            (x["layout_bbox"][2] - x["layout_bbox"][0]),
        ),
    )
    # Maps the index of each "single" block to its strong matches (index, ratio).
    check_single_layout = {}
    page_min_x, page_max_x = float("inf"), 0
    # double_label_height is accumulated below but is not part of the result.
    double_label_height = 0
    double_label_area = 0
    single_label_area = 0
    # Horizontal extent covered by all blocks.
    for i, block in enumerate(blocks):
        page_min_x = min(page_min_x, block["layout_bbox"][0])
        page_max_x = max(page_max_x, block["layout_bbox"][2])
    page_width = page_max_x - page_min_x
    for i, block in enumerate(blocks):
        if block["label"] not in no_mask_labels:
            continue
        x_min_i, _, x_max_i, _ = block["layout_bbox"]
        layout_length = x_max_i - x_min_i
        cover_count, cover_with_threshold_count = 0, 0
        match_block_with_threshold_indexes = []
        for j, other_block in enumerate(blocks):
            if i == j or other_block["label"] not in no_mask_labels:
                continue
            x_min_j, _, x_max_j, _ = other_block["layout_bbox"]
            x_match_min, x_match_max = max(
                x_min_i,
                x_min_j,
            ), min(x_max_i, x_max_j)
            # Shared horizontal extent relative to the other block's width
            # (negative when the two do not overlap horizontally).
            match_block_iou = (x_match_max - x_match_min) / (x_max_j - x_min_j)
            if match_block_iou > 0:
                cover_count += 1
                if match_block_iou > threshold:
                    cover_with_threshold_count += 1
                    match_block_with_threshold_indexes.append(
                        (j, match_block_iou),
                    )
                # Advance the unmatched left edge past this overlap and stop
                # once the whole width of the block has been consumed.
                x_min_i = x_match_max
                if x_min_i >= x_max_i:
                    break
        # "double": clearly wider than the median text block and overlapping
        # at least two other blocks, or wider than 60% of the page.
        if (
            layout_length > median_width * 1.3
            and (cover_with_threshold_count >= 2 or cover_count >= 2)
        ) or layout_length > 0.6 * page_width:
            # if layout_length > median_width * 1.3 and (cover_with_threshold_count >= 2):
            block["layout"] = "double"
            double_label_height += block["layout_bbox"][3] - block["layout_bbox"][1]
            double_label_area += (block["layout_bbox"][2] - block["layout_bbox"][0]) * (
                block["layout_bbox"][3] - block["layout_bbox"][1]
            )
        else:
            block["layout"] = "single"
            check_single_layout[i] = match_block_with_threshold_indexes
    # Check single-layout block
    for i, single_layout in check_single_layout.items():
        if single_layout:
            # Only the last strong match is examined: a "single" block that
            # almost fully covers a "double" block is promoted to "double".
            index, match_iou = single_layout[-1]
            if match_iou > 0.9 and blocks[index]["layout"] == "double":
                blocks[i]["layout"] = "double"
                double_label_height += (
                    blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1]
                )
                double_label_area += (
                    blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
                ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
            else:
                single_label_area += (
                    blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
                ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
    return blocks, (double_label_area > single_label_area)
  643. def _get_bbox_direction(input_bbox, ratio=1):
  644. """
  645. Determine if a bounding box is horizontal or vertical.
  646. Args:
  647. input_bbox (list): Bounding box [x_min, y_min, x_max, y_max].
  648. ratio (float): Ratio for determining orientation.
  649. Returns:
  650. bool: True if horizontal, False if vertical.
  651. """
  652. return (input_bbox[2] - input_bbox[0]) * ratio >= (input_bbox[3] - input_bbox[1])
  653. def _get_projection_iou(input_bbox, match_bbox, is_horizontal=True):
  654. """
  655. Calculate the IoU of lines between two bounding boxes.
  656. Args:
  657. input_bbox (list): First bounding box [x_min, y_min, x_max, y_max].
  658. match_bbox (list): Second bounding box [x_min, y_min, x_max, y_max].
  659. is_horizontal (bool): Whether to compare horizontally or vertically.
  660. Returns:
  661. float: Line IoU.
  662. """
  663. if is_horizontal:
  664. x_match_min = max(input_bbox[0], match_bbox[0])
  665. x_match_max = min(input_bbox[2], match_bbox[2])
  666. return (x_match_max - x_match_min) / (input_bbox[2] - input_bbox[0])
  667. else:
  668. y_match_min = max(input_bbox[1], match_bbox[1])
  669. y_match_max = min(input_bbox[3], match_bbox[3])
  670. return (y_match_max - y_match_min) / (input_bbox[3] - input_bbox[1])
def _get_sub_category(blocks, title_labels):
    """
    Tag blocks that belong to a neighbouring title or visual block.

    Every block gets default 'title_text', 'sub_title', 'vision_footnote'
    (empty lists) and 'sub_label' (copy of 'label') keys when they are missing.
    Then, for each anchor block (a label in ``title_labels``, a
    "paragraph_title", or a visual label), the nearest qualifying neighbour
    before it (above / left) and after it (below / right) is located and,
    depending on size and label checks, recorded on the anchor and/or
    re-tagged through its 'sub_label'.

    Args:
        blocks (list): List of block dictionaries with 'label' and
            'layout_bbox' ([x_min, y_min, x_max, y_max]). Modified in place.
        title_labels (list): List of labels considered as titles.

    Returns:
        list: The same list of blocks, updated with title-text, sub-title and
        vision-footnote information.
    """

    sub_title_labels = ["paragraph_title"]
    vision_labels = ["image", "table", "chart", "figure"]

    for i, block1 in enumerate(blocks):
        # Make sure every block carries the bookkeeping keys used later on.
        if block1.get("title_text") is None:
            block1["title_text"] = []
        if block1.get("sub_title") is None:
            block1["sub_title"] = []
        if block1.get("vision_footnote") is None:
            block1["vision_footnote"] = []
        if block1.get("sub_label") is None:
            block1["sub_label"] = block1["label"]

        # Only titles, paragraph titles and visual blocks act as anchors.
        if (
            block1["label"] not in title_labels
            and block1["label"] not in sub_title_labels
            and block1["label"] not in vision_labels
        ):
            continue

        bbox1 = block1["layout_bbox"]
        x1, y1, x2, y2 = bbox1
        is_horizontal_1 = _get_bbox_direction(block1["layout_bbox"])
        # Best candidate before the anchor (above it, or to its left when the
        # anchor is vertical) and after it (below / to the right).
        left_up_title_text_distance = float("inf")
        left_up_title_text_index = -1
        left_up_title_text_direction = None
        right_down_title_text_distance = float("inf")
        right_down_title_text_index = -1
        right_down_title_text_direction = None

        for j, block2 in enumerate(blocks):
            if i == j:
                continue

            bbox2 = block2["layout_bbox"]
            x1_prime, y1_prime, x2_prime, y2_prime = bbox2
            is_horizontal_2 = _get_bbox_direction(bbox2)
            # Overlap of the candidate with the anchor along the anchor's main
            # axis, relative to the candidate's own extent.
            match_block_iou = _get_projection_iou(
                bbox2,
                bbox1,
                is_horizontal_1,
            )

            def distance_(is_horizontal, is_left_up):
                # Gap between anchor and candidate, quantised in steps of 5,
                # plus a small fraction of the candidate's x (or y) origin
                # that breaks ties between equally distant candidates.
                if is_horizontal:
                    if is_left_up:
                        return (y1 - y2_prime + 2) // 5 + x1_prime / 5000
                    else:
                        return (y1_prime - y2 + 2) // 5 + x1_prime / 5000

                else:
                    if is_left_up:
                        return (x1 - x2_prime + 2) // 5 + y1_prime / 5000
                    else:
                        return (x1_prime - x2 + 2) // 5 + y1_prime / 5000

            block_iou_threshold = 0.1
            if block1["label"] in sub_title_labels:
                # Paragraph titles use the area-overlap ratio with a stricter
                # threshold instead of the projection overlap.
                match_block_iou = _calculate_overlap_area_2_minbox_area_ratio(
                    bbox2,
                    bbox1,
                )
                block_iou_threshold = 0.7

            if is_horizontal_1:
                if match_block_iou >= block_iou_threshold:
                    left_up_distance = distance_(True, True)
                    right_down_distance = distance_(True, False)
                    # Candidate entirely above the anchor; "<=" lets a later
                    # block replace an equally distant earlier one.
                    if (
                        y2_prime <= y1
                        and left_up_distance <= left_up_title_text_distance
                    ):
                        left_up_title_text_distance = left_up_distance
                        left_up_title_text_index = j
                        left_up_title_text_direction = is_horizontal_2
                    # Candidate entirely below the anchor; "<" keeps the first
                    # of equally distant candidates.
                    elif (
                        y1_prime > y2
                        and right_down_distance < right_down_title_text_distance
                    ):
                        right_down_title_text_distance = right_down_distance
                        right_down_title_text_index = j
                        right_down_title_text_direction = is_horizontal_2
            else:
                if match_block_iou >= block_iou_threshold:
                    left_up_distance = distance_(False, True)
                    right_down_distance = distance_(False, False)
                    # Same selection as above, along the x axis.
                    if (
                        x2_prime <= x1
                        and left_up_distance <= left_up_title_text_distance
                    ):
                        left_up_title_text_distance = left_up_distance
                        left_up_title_text_index = j
                        left_up_title_text_direction = is_horizontal_2
                    elif (
                        x1_prime > x2
                        and right_down_distance < right_down_title_text_distance
                    ):
                        right_down_title_text_distance = right_down_distance
                        right_down_title_text_index = j
                        right_down_title_text_direction = is_horizontal_2

        height = bbox1[3] - bbox1[1]
        width = bbox1[2] - bbox1[0]
        # Maximum size of a title-text neighbour relative to the anchor
        # (index 0: height factor for horizontal anchors, index 1: width
        # factor for vertical anchors).
        title_text_weight = [0.8, 0.8]
        # title_text_weight = [2, 2]

        # Results collected for this anchor by get_sub_category_ below.
        title_text = []
        sub_title = []
        vision_footnote = []

        def get_sub_category_(
            title_text_direction,
            title_text_index,
            label,
            is_left_up=True,
        ):
            # Classify one selected neighbour relative to the current anchor.
            # Position codes stored with title text: 1/3 = before the anchor
            # (horizontal/vertical case), 2/4 = after it.
            direction_ = [1, 3] if is_left_up else [2, 4]
            # The neighbour must exist, share the anchor's orientation and be
            # a "text" or "paragraph_title" block.
            if (
                title_text_direction == is_horizontal_1
                and title_text_index != -1
                and (label == "text" or label == "paragraph_title")
            ):
                bbox2 = blocks[title_text_index]["layout_bbox"]
                if is_horizontal_1:
                    height1 = bbox2[3] - bbox2[1]
                    width1 = bbox2[2] - bbox2[0]
                    if label == "text":
                        # Small text hugging a visual block -> footnote.
                        if (
                            _nearest_edge_distance(bbox1, bbox2)[0] <= 15
                            and block1["label"] in vision_labels
                            and width1 < width
                            and height1 < 0.5 * height
                        ):
                            blocks[title_text_index]["sub_label"] = "vision_footnote"
                            vision_footnote.append(bbox2)
                        # Short text next to a title -> title text.
                        elif (
                            height1 < height * title_text_weight[0]
                            and (width1 < width or width1 > 1.5 * width)
                            and block1["label"] in title_labels
                        ):
                            blocks[title_text_index]["sub_label"] = "title_text"
                            title_text.append((direction_[0], bbox2))
                    # Paragraph title next to a paragraph title -> sub title.
                    elif (
                        label == "paragraph_title"
                        and block1["label"] in sub_title_labels
                    ):
                        sub_title.append(bbox2)
                else:
                    # Vertical anchor: same checks with width/height swapped.
                    height1 = bbox2[3] - bbox2[1]
                    width1 = bbox2[2] - bbox2[0]
                    if label == "text":
                        if (
                            _nearest_edge_distance(bbox1, bbox2)[0] <= 15
                            and block1["label"] in vision_labels
                            and height1 < height
                            and width1 < 0.5 * width
                        ):
                            blocks[title_text_index]["sub_label"] = "vision_footnote"
                            vision_footnote.append(bbox2)
                        elif (
                            width1 < width * title_text_weight[1]
                            and block1["label"] in title_labels
                        ):
                            blocks[title_text_index]["sub_label"] = "title_text"
                            title_text.append((direction_[1], bbox2))
                    elif (
                        label == "paragraph_title"
                        and block1["label"] in sub_title_labels
                    ):
                        sub_title.append(bbox2)

        # When one side is clearly closer (difference, rescaled by the step of
        # 5 used in distance_, larger than the anchor's own size) only that
        # side is classified; otherwise both sides are.
        # NOTE: when a side has no candidate its index is -1, so
        # blocks[-1]["label"] (the last block) is passed as the label; the
        # `title_text_index != -1` check in get_sub_category_ discards it.
        if (
            is_horizontal_1
            and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
            > height
        ) or (
            not is_horizontal_1
            and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
            > width
        ):
            if left_up_title_text_distance < right_down_title_text_distance:
                get_sub_category_(
                    left_up_title_text_direction,
                    left_up_title_text_index,
                    blocks[left_up_title_text_index]["label"],
                    True,
                )
            else:
                get_sub_category_(
                    right_down_title_text_direction,
                    right_down_title_text_index,
                    blocks[right_down_title_text_index]["label"],
                    False,
                )
        else:
            get_sub_category_(
                left_up_title_text_direction,
                left_up_title_text_index,
                blocks[left_up_title_text_index]["label"],
                True,
            )
            get_sub_category_(
                right_down_title_text_direction,
                right_down_title_text_index,
                blocks[right_down_title_text_index]["label"],
                False,
            )

        # Store the collected neighbours on the anchor, without overwriting
        # lists that were already populated.
        if block1["label"] in title_labels:
            if blocks[i].get("title_text") == []:
                blocks[i]["title_text"] = title_text

        if block1["label"] in sub_title_labels:
            if blocks[i].get("sub_title") == []:
                blocks[i]["sub_title"] = sub_title

        if block1["label"] in vision_labels:
            if blocks[i].get("vision_footnote") == []:
                blocks[i]["vision_footnote"] = vision_footnote

    return blocks
def get_layout_ordering(data, no_mask_labels=[], already_sorted=False):
    """
    Assign a reading-order 'index' / 'sub_index' to the blocks of a region.

    The blocks in ``data["sub_blocks"]`` are de-overlapped, categorised and
    split into groups. The blocks left in the base group are ordered with an
    XY-cut; the other groups are then attached one after another to their
    nearest already-ordered block, with the indices renumbered after each
    stage.

    Args:
        data (dict): Region dict; its 'sub_blocks' list holds block dicts with
            'label' and 'layout_bbox'. The block dicts are modified in place.
        no_mask_labels (list): Labels treated as body text. The default list
            is only read here, never mutated.
        already_sorted (bool): When True, ``data`` is returned untouched.

    Returns:
        dict: The ``data`` object that was passed in.
    """
    if already_sorted:
        return data

    title_text_labels = ["doc_title"]
    title_labels = ["doc_title", "paragraph_title"]
    vision_labels = ["image", "table", "seal", "chart", "figure"]
    vision_title_labels = ["table_title", "chart_title", "figure_title"]

    parsing_result = data["sub_blocks"]

    # NOTE(review): whether _remove_overlap_blocks returns a new list or edits
    # data["sub_blocks"] in place is not visible here; nothing below writes
    # parsing_result back into `data`.
    parsing_result, _ = _remove_overlap_blocks(
        parsing_result,
        threshold=0.5,
        smaller=True,
    )
    parsing_result = _get_sub_category(parsing_result, title_text_labels)

    doc_flag = False
    median_width = _text_median_width(parsing_result)
    parsing_result, projection_direction = _get_layout_property(
        parsing_result,
        median_width,
        no_mask_labels=no_mask_labels,
        threshold=0.3,
    )
    # Split the blocks into groups by sub label; bounding boxes are cast to
    # int along the way.
    (
        double_text_blocks,
        title_text_blocks,
        title_blocks,
        vision_blocks,
        vision_title_blocks,
        vision_footnote_blocks,
        other_blocks,
    ) = ([], [], [], [], [], [], [])

    drop_indexes = []

    for index, block in enumerate(parsing_result):
        label = block["sub_label"]
        block["layout_bbox"] = list(map(int, block["layout_bbox"]))

        if label == "doc_title":
            doc_flag = True

        if label in no_mask_labels:
            # Single-layout body text stays in parsing_result as the base set.
            if block["layout"] == "double":
                double_text_blocks.append(block)
                drop_indexes.append(index)
        elif label == "title_text":
            title_text_blocks.append(block)
            drop_indexes.append(index)
        elif label == "vision_footnote":
            vision_footnote_blocks.append(block)
            drop_indexes.append(index)
        elif label in vision_title_labels:
            vision_title_blocks.append(block)
            drop_indexes.append(index)
        elif label in title_labels:
            title_blocks.append(block)
            drop_indexes.append(index)
        elif label in vision_labels:
            vision_blocks.append(block)
            drop_indexes.append(index)
        else:
            other_blocks.append(block)
            drop_indexes.append(index)

    # Delete from the back so the remaining indexes stay valid.
    for index in sorted(drop_indexes, reverse=True):
        del parsing_result[index]

    if len(parsing_result) > 0:
        # single text label
        if len(double_text_blocks) > len(parsing_result) or projection_direction:
            # Double-layout content dominates: order titles and double-layout
            # text together with the base set, using direction=1.
            parsing_result.extend(title_blocks + double_text_blocks)
            title_blocks = []
            double_text_blocks = []
            block_bboxes = [block["layout_bbox"] for block in parsing_result]
            block_bboxes.sort(
                key=lambda x: (
                    x[0] // max(20, median_width),
                    x[1],
                ),
            )
            block_bboxes = np.array(block_bboxes)
            print("sort by yxcut...")
            sorted_indices = sort_by_xycut(
                block_bboxes,
                direction=1,
                min_gap=1,
            )
        else:
            block_bboxes = [block["layout_bbox"] for block in parsing_result]
            block_bboxes.sort(key=lambda x: (x[0] // 20, x[1]))
            block_bboxes = np.array(block_bboxes)
            sorted_indices = sort_by_xycut(
                block_bboxes,
                direction=0,
                min_gap=20,
            )

        sorted_boxes = block_bboxes[sorted_indices].tolist()

        # Blocks are mapped back to their rank by bbox value; list.index
        # returns the first match, so blocks with identical boxes share a rank.
        for block in parsing_result:
            block["index"] = sorted_boxes.index(block["layout_bbox"]) + 1
            block["sub_index"] = sorted_boxes.index(block["layout_bbox"]) + 1

    def nearest_match_(input_blocks, distance_type="manhattan", is_add_index=True):
        # Give each input block the 'index' (or 'sub_index' when is_add_index
        # is False) of its nearest block in parsing_result, then append it to
        # parsing_result so later blocks can match against it as well.
        for block in input_blocks:
            bbox = block["layout_bbox"]
            min_distance = float("inf")
            min_distance_config = [
                [float("inf"), float("inf")],
                float("inf"),
                float("inf"),
            ]  # for double text
            nearest_gt_index = 0
            for match_block in parsing_result:
                match_bbox = match_block["layout_bbox"]
                if distance_type == "nearest_iou_edge_distance":
                    distance, min_distance_config = _nearest_iou_edge_distance(
                        bbox,
                        match_bbox,
                        block["sub_label"],
                        vision_labels=vision_labels,
                        no_mask_labels=no_mask_labels,
                        median_width=median_width,
                        title_labels=title_labels,
                        title_text=block["title_text"],
                        sub_title=block["sub_title"],
                        min_distance_config=min_distance_config,
                        tolerance_len=10,
                    )
                elif distance_type == "title_text":
                    # Only titles/abstracts that recorded title text qualify;
                    # the distance shrinks as this block overlaps that text.
                    if (
                        match_block["label"] in title_labels + ["abstract"]
                        and match_block["title_text"] != []
                    ):
                        iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
                            bbox,
                            match_block["title_text"][0][1],
                        )
                        iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
                            bbox,
                            match_block["title_text"][-1][1],
                        )
                        iou = 1 - max(iou_left_up, iou_right_down)
                        distance = _manhattan_distance(bbox, match_bbox) * iou
                    else:
                        distance = float("inf")
                elif distance_type == "manhattan":
                    distance = _manhattan_distance(bbox, match_bbox)
                elif distance_type == "vision_footnote":
                    # Only visual blocks that recorded footnotes qualify.
                    if (
                        match_block["label"] in vision_labels
                        and match_block["vision_footnote"] != []
                    ):
                        iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
                            bbox,
                            match_block["vision_footnote"][0],
                        )
                        iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
                            bbox,
                            match_block["vision_footnote"][-1],
                        )
                        iou = 1 - max(iou_left_up, iou_right_down)
                        distance = _manhattan_distance(bbox, match_bbox) * iou
                    else:
                        distance = float("inf")
                elif distance_type == "vision_body":
                    # NOTE(review): no call with this distance_type is visible
                    # in this function.
                    if (
                        match_block["label"] in vision_title_labels
                        and block["vision_footnote"] != []
                    ):
                        iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
                            match_bbox,
                            block["vision_footnote"][0],
                        )
                        iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
                            match_bbox,
                            block["vision_footnote"][-1],
                        )
                        iou = 1 - max(iou_left_up, iou_right_down)
                        distance = _manhattan_distance(bbox, match_bbox) * iou
                    else:
                        distance = float("inf")
                else:
                    raise NotImplementedError

                if distance < min_distance:
                    min_distance = distance
                    if is_add_index:
                        nearest_gt_index = match_block.get("index", 999)
                    else:
                        nearest_gt_index = match_block.get("sub_index", 999)

            # Stays 0 when nothing matched (empty parsing_result or all
            # distances infinite).
            if is_add_index:
                block["index"] = nearest_gt_index
            else:
                block["sub_index"] = nearest_gt_index

            parsing_result.append(block)

    # double text label
    double_text_blocks.sort(
        key=lambda x: (
            x["layout_bbox"][1] // 10,
            x["layout_bbox"][0] // median_width,
            x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
        ),
    )
    nearest_match_(
        double_text_blocks,
        distance_type="nearest_iou_edge_distance",
    )
    parsing_result.sort(
        key=lambda x: (x["index"], x["layout_bbox"][1], x["layout_bbox"][0]),
    )

    # Renumber so indices are consecutive again after the insertions.
    for idx, block in enumerate(parsing_result):
        block["index"] = idx + 1
        block["sub_index"] = idx + 1

    # title label
    title_blocks.sort(
        key=lambda x: (
            x["layout_bbox"][1] // 10,
            x["layout_bbox"][0] // median_width,
            x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
        ),
    )
    nearest_match_(title_blocks, distance_type="nearest_iou_edge_distance")

    if doc_flag:
        # text_sort_labels = ["doc_title","paragraph_title","abstract"]
        text_sort_labels = ["doc_title"]
        text_label_priority = {
            label: priority for priority, label in enumerate(text_sort_labels)
        }
        doc_titles = []
        for i, block in enumerate(parsing_result):
            if block["label"] == "doc_title":
                doc_titles.append(
                    (i, block["layout_bbox"][1], block["layout_bbox"][0]),
                )
        doc_titles.sort(key=lambda x: (x[1], x[2]))
        # Force the top-most (then left-most) document title to index 1.
        # NOTE(review): doc_flag is derived from 'sub_label' while this scan
        # uses 'label'; if none matches, doc_titles[0] raises IndexError.
        first_doc_title_index = doc_titles[0][0]
        parsing_result[first_doc_title_index]["index"] = 1
        parsing_result.sort(
            key=lambda x: (
                x["index"],
                text_label_priority.get(x["label"], 9999),
                x["layout_bbox"][1],
                x["layout_bbox"][0],
            ),
        )
    else:
        parsing_result.sort(
            key=lambda x: (
                x["index"],
                x["layout_bbox"][1],
                x["layout_bbox"][0],
            ),
        )

    for idx, block in enumerate(parsing_result):
        block["index"] = idx + 1
        block["sub_index"] = idx + 1

    # title-text label
    nearest_match_(title_text_blocks, distance_type="title_text")
    text_sort_labels = ["doc_title", "paragraph_title", "title_text"]
    text_label_priority = {
        label: priority for priority, label in enumerate(text_sort_labels)
    }
    parsing_result.sort(
        key=lambda x: (
            x["index"],
            text_label_priority.get(x["sub_label"], 9999),
            x["layout_bbox"][1],
            x["layout_bbox"][0],
        ),
    )

    for idx, block in enumerate(parsing_result):
        block["index"] = idx + 1
        block["sub_index"] = idx + 1

    # From here on only 'sub_index' is updated (is_add_index=False).
    # vision title labels (table_title, chart_title, figure_title)
    nearest_match_(
        vision_title_blocks,
        distance_type="nearest_iou_edge_distance",
        is_add_index=False,
    )
    parsing_result.sort(
        key=lambda x: (
            x["sub_index"],
            x["layout_bbox"][1],
            x["layout_bbox"][0],
        ),
    )

    for idx, block in enumerate(parsing_result):
        block["sub_index"] = idx + 1

    # image, table, seal, chart, figure labels
    nearest_match_(
        vision_blocks,
        distance_type="nearest_iou_edge_distance",
        is_add_index=False,
    )
    parsing_result.sort(
        key=lambda x: (
            x["sub_index"],
            x["layout_bbox"][1],
            x["layout_bbox"][0],
        ),
    )

    for idx, block in enumerate(parsing_result):
        block["sub_index"] = idx + 1

    # vision footnote label
    nearest_match_(
        vision_footnote_blocks,
        distance_type="vision_footnote",
        is_add_index=False,
    )
    # Footnotes sort after the block that shares their sub_index.
    text_label_priority = {"vision_footnote": 9999}
    parsing_result.sort(
        key=lambda x: (
            x["sub_index"],
            text_label_priority.get(x["sub_label"], 0),
            x["layout_bbox"][1],
            x["layout_bbox"][0],
        ),
    )

    for idx, block in enumerate(parsing_result):
        block["sub_index"] = idx + 1

    # header, footnote, header_image... labels
    nearest_match_(other_blocks, distance_type="manhattan", is_add_index=False)

    return data
  1205. def _generate_input_data(parsing_result):
  1206. """
  1207. The evaluation input data is generated based on the parsing results.
  1208. :param parsing_result: A list containing the results of the layout parsing
  1209. :return: A formatted list of input data
  1210. """
  1211. input_data = [
  1212. {
  1213. "block_bbox": block["block_bbox"],
  1214. "sub_indices": [],
  1215. "sub_bboxes": [],
  1216. }
  1217. for block in parsing_result
  1218. ]
  1219. for block_index, block in enumerate(parsing_result):
  1220. sub_blocks = block["sub_blocks"]
  1221. get_layout_ordering(
  1222. block_index=block_index,
  1223. no_mask_labels=[
  1224. "text",
  1225. "formula",
  1226. "algorithm",
  1227. "reference",
  1228. "content",
  1229. "abstract",
  1230. ],
  1231. )
  1232. for sub_block in sub_blocks:
  1233. input_data[block_index]["sub_bboxes"].append(
  1234. list(map(int, sub_block["layout_bbox"])),
  1235. )
  1236. input_data[block_index]["sub_indices"].append(
  1237. int(sub_block["index"]),
  1238. )
  1239. return input_data
  1240. def _manhattan_distance(point1, point2, weight_x=1, weight_y=1):
  1241. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
  1242. def _calculate_horizontal_distance(
  1243. input_bbox,
  1244. match_bbox,
  1245. height,
  1246. disperse,
  1247. title_text,
  1248. ):
  1249. """
  1250. Calculate the horizontal distance between two bounding boxes, considering title text adjustments.
  1251. Args:
  1252. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1253. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1254. height (int): The height of the input bounding box used for normalization.
  1255. disperse (int): The dispersion factor used to normalize the horizontal distance.
  1256. title_text (list): A list of tuples containing title text information and their bounding box coordinates.
  1257. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1258. Returns:
  1259. float: The calculated horizontal distance taking into account the title text adjustments.
  1260. """
  1261. x1, y1, x2, y2 = input_bbox
  1262. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1263. if y2 < y1_prime:
  1264. if title_text and title_text[-1][0] == 2:
  1265. y2 += title_text[-1][1][3] - title_text[-1][1][1]
  1266. distance1 = (y1_prime - y2) * 0.5
  1267. else:
  1268. if title_text and title_text[0][0] == 1:
  1269. y1 -= title_text[0][1][3] - title_text[0][1][1]
  1270. distance1 = y1 - y2_prime
  1271. return (
  1272. abs(x2_prime - x1) // disperse + distance1 // height + distance1 / 5000
  1273. ) # if page max size == 5000
  1274. def _calculate_vertical_distance(input_bbox, match_bbox, width, disperse, title_text):
  1275. """
  1276. Calculate the vertical distance between two bounding boxes, considering title text adjustments.
  1277. Args:
  1278. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1279. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1280. width (int): The width of the input bounding box used for normalization.
  1281. disperse (int): The dispersion factor used to normalize the vertical distance.
  1282. title_text (list): A list of tuples containing title text information and their bounding box coordinates.
  1283. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1284. Returns:
  1285. float: The calculated vertical distance taking into account the title text adjustments.
  1286. """
  1287. x1, y1, x2, y2 = input_bbox
  1288. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1289. if x1 > x2_prime:
  1290. if title_text and title_text[0][0] == 3:
  1291. x1 -= title_text[0][1][2] - title_text[0][1][0]
  1292. distance2 = (x1 - x2_prime) * 0.5
  1293. else:
  1294. if title_text and title_text[-1][0] == 4:
  1295. x2 += title_text[-1][1][2] - title_text[-1][1][0]
  1296. distance2 = x1_prime - x2
  1297. return abs(y2_prime - y1) // disperse + distance2 // width + distance2 / 5000
  1298. def _nearest_edge_distance(
  1299. input_bbox,
  1300. match_bbox,
  1301. weight=[1, 1, 1, 1],
  1302. label="text",
  1303. no_mask_labels=[],
  1304. min_edge_distances_config=[],
  1305. tolerance_len=10,
  1306. ):
  1307. """
  1308. Calculate the nearest edge distance between two bounding boxes, considering directional weights.
  1309. Args:
  1310. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1311. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1312. weight (list, optional): Directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
  1313. label (str, optional): The label/type of the object in the bounding box (e.g., 'text'). Defaults to 'text'.
  1314. no_mask_labels (list, optional): Labels for which no masking is applied when calculating edge distances. Defaults to an empty list.
  1315. min_edge_distances_config (list, optional): Configuration for minimum edge distances [min_edge_distance_x, min_edge_distance_y].
  1316. Defaults to [float('inf'), float('inf')].
  1317. Returns:
  1318. tuple: A tuple containing:
  1319. - The calculated minimum edge distance between the bounding boxes.
  1320. - A list with the minimum edge distances in the x and y directions.
  1321. """
  1322. match_bbox_iou = _calculate_overlap_area_2_minbox_area_ratio(
  1323. input_bbox,
  1324. match_bbox,
  1325. )
  1326. if match_bbox_iou > 0 and label not in no_mask_labels:
  1327. return 0, [0, 0]
  1328. if not min_edge_distances_config:
  1329. min_edge_distances_config = [float("inf"), float("inf")]
  1330. min_edge_distance_x, min_edge_distance_y = min_edge_distances_config
  1331. x1, y1, x2, y2 = input_bbox
  1332. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1333. direction_num = 0
  1334. distance_x = float("inf")
  1335. distance_y = float("inf")
  1336. distance = [float("inf")] * 4
  1337. # input_bbox is to the left of match_bbox
  1338. if x2 < x1_prime:
  1339. direction_num += 1
  1340. distance[0] = x1_prime - x2
  1341. if abs(distance[0] - min_edge_distance_x) <= tolerance_len:
  1342. distance_x = min_edge_distance_x * weight[0]
  1343. else:
  1344. distance_x = distance[0] * weight[0]
  1345. # input_bbox is to the right of match_bbox
  1346. elif x1 > x2_prime:
  1347. direction_num += 1
  1348. distance[1] = x1 - x2_prime
  1349. if abs(distance[1] - min_edge_distance_x) <= tolerance_len:
  1350. distance_x = min_edge_distance_x * weight[1]
  1351. else:
  1352. distance_x = distance[1] * weight[1]
  1353. elif match_bbox_iou > 0:
  1354. distance[0] = 0
  1355. distance_x = 0
  1356. # input_bbox is above match_bbox
  1357. if y2 < y1_prime:
  1358. direction_num += 1
  1359. distance[2] = y1_prime - y2
  1360. if abs(distance[2] - min_edge_distance_y) <= tolerance_len:
  1361. distance_y = min_edge_distance_y * weight[2]
  1362. else:
  1363. distance_y = distance[2] * weight[2]
  1364. if label in no_mask_labels:
  1365. distance_y = max(0.1, distance_y) * 100
  1366. # input_bbox is below match_bbox
  1367. elif y1 > y2_prime:
  1368. direction_num += 1
  1369. distance[3] = y1 - y2_prime
  1370. if abs(distance[3] - min_edge_distance_y) <= tolerance_len:
  1371. distance_y = min_edge_distance_y * weight[3]
  1372. else:
  1373. distance_y = distance[3] * weight[3]
  1374. elif match_bbox_iou > 0:
  1375. distance[2] = 0
  1376. distance_y = 0
  1377. if direction_num == 2:
  1378. return (distance_x + distance_y), [
  1379. min(distance[0], distance[1]),
  1380. min(distance[2], distance[3]),
  1381. ]
  1382. else:
  1383. return min(distance_x, distance_y), [
  1384. min(distance[0], distance[1]),
  1385. min(distance[2], distance[3]),
  1386. ]
  1387. def _get_weights(label, horizontal):
  1388. """Define weights based on the label and orientation."""
  1389. if label == "doc_title":
  1390. return (
  1391. [1, 0.1, 0.1, 1] if horizontal else [0.2, 0.1, 1, 1]
  1392. ) # left-down , right-left
  1393. elif label in [
  1394. "paragraph_title",
  1395. "abstract",
  1396. "figure_title",
  1397. "chart_title",
  1398. "image",
  1399. "seal",
  1400. "chart",
  1401. "figure",
  1402. ]:
  1403. return [1, 1, 0.1, 1] # down
  1404. else:
  1405. return [1, 1, 1, 0.1] # up
  1406. def _nearest_iou_edge_distance(
  1407. input_bbox,
  1408. match_bbox,
  1409. label,
  1410. vision_labels,
  1411. no_mask_labels,
  1412. median_width=-1,
  1413. title_labels=[],
  1414. title_text=[],
  1415. sub_title=[],
  1416. min_distance_config=[],
  1417. tolerance_len=10,
  1418. ):
  1419. """
  1420. Calculate the nearest IOU edge distance between two bounding boxes.
  1421. Args:
  1422. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1423. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1424. label (str): The label/type of the object in the bounding box (e.g., 'image', 'text', etc.).
  1425. no_mask_labels (list): Labels for which no masking is applied when calculating edge distances.
  1426. median_width (int, optional): The median width for title dispersion calculation. Defaults to -1.
  1427. title_labels (list, optional): Labels that indicate the object is a title. Defaults to an empty list.
  1428. title_text (list, optional): Text content associated with title labels. Defaults to an empty list.
  1429. sub_title (list, optional): List of subtitle bounding boxes to adjust the input_bbox. Defaults to an empty list.
  1430. min_distance_config (list, optional): Configuration for minimum distances [min_edge_distances_config, up_edge_distances_config, total_distance].
  1431. Returns:
  1432. tuple: A tuple containing the calculated distance and updated minimum distance configuration.
  1433. """
  1434. x1, y1, x2, y2 = input_bbox
  1435. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1436. min_edge_distances_config, up_edge_distances_config, total_distance = (
  1437. min_distance_config
  1438. )
  1439. iou_distance = 0
  1440. if label in vision_labels:
  1441. horizontal1 = horizontal2 = True
  1442. else:
  1443. horizontal1 = _get_bbox_direction(input_bbox)
  1444. horizontal2 = _get_bbox_direction(match_bbox, 3)
  1445. if (
  1446. horizontal1 != horizontal2
  1447. or _get_projection_iou(input_bbox, match_bbox, horizontal1) < 0.01
  1448. ):
  1449. iou_distance = 1
  1450. elif label == "doc_title" or (label in title_labels and title_text):
  1451. # Calculate distance for titles
  1452. disperse = max(1, median_width)
  1453. width = x2 - x1
  1454. height = y2 - y1
  1455. if horizontal1:
  1456. return (
  1457. _calculate_horizontal_distance(
  1458. input_bbox,
  1459. match_bbox,
  1460. height,
  1461. disperse,
  1462. title_text,
  1463. ),
  1464. min_distance_config,
  1465. )
  1466. else:
  1467. return (
  1468. _calculate_vertical_distance(
  1469. input_bbox,
  1470. match_bbox,
  1471. width,
  1472. disperse,
  1473. title_text,
  1474. ),
  1475. min_distance_config,
  1476. )
  1477. # Adjust input_bbox based on sub_title
  1478. if sub_title:
  1479. for sub in sub_title:
  1480. x1_, y1_, x2_, y2_ = sub
  1481. x1, y1, x2, y2 = (
  1482. min(x1, x1_),
  1483. min(
  1484. y1,
  1485. y1_,
  1486. ),
  1487. max(x2, x2_),
  1488. max(y2, y2_),
  1489. )
  1490. input_bbox = [x1, y1, x2, y2]
  1491. # Calculate edge distance
  1492. weight = _get_weights(label, horizontal1)
  1493. if label == "abstract":
  1494. tolerance_len *= 3
  1495. edge_distance, edge_distance_config = _nearest_edge_distance(
  1496. input_bbox,
  1497. match_bbox,
  1498. weight,
  1499. label=label,
  1500. no_mask_labels=no_mask_labels,
  1501. min_edge_distances_config=min_edge_distances_config,
  1502. tolerance_len=tolerance_len,
  1503. )
  1504. # Weights for combining distances
  1505. iou_edge_weight = [10**6, 10**3, 1, 0.001]
  1506. # Calculate up and left edge distances
  1507. up_edge_distance = y1_prime
  1508. left_edge_distance = x1_prime
  1509. if (
  1510. label in no_mask_labels or label == "paragraph_title" or label in vision_labels
  1511. ) and y1 > y2_prime:
  1512. up_edge_distance = -y2_prime
  1513. left_edge_distance = -x2_prime
  1514. min_up_edge_distance = up_edge_distances_config
  1515. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  1516. up_edge_distance = min_up_edge_distance
  1517. # Calculate total distance
  1518. distance = (
  1519. iou_distance * iou_edge_weight[0]
  1520. + edge_distance * iou_edge_weight[1]
  1521. + up_edge_distance * iou_edge_weight[2]
  1522. + left_edge_distance * iou_edge_weight[3]
  1523. )
  1524. # Update minimum distance configuration if a smaller distance is found
  1525. if total_distance > distance:
  1526. edge_distance_config = [
  1527. min(min_edge_distances_config[0], edge_distance_config[0]),
  1528. min(min_edge_distances_config[1], edge_distance_config[1]),
  1529. ]
  1530. min_distance_config = [
  1531. edge_distance_config,
  1532. min(up_edge_distance, up_edge_distances_config),
  1533. distance,
  1534. ]
  1535. return distance, min_distance_config
  def get_show_color(label: str) -> tuple:
      """Return the display color assigned to a layout label.

      Args:
          label: Layout label name such as ``"text"``, ``"table"`` or
              ``"doc_title"``.

      Returns:
          A 4-tuple of ints. The first three components are RGB values
          (they match the named colors in the comments below); the fourth
          is 100 for every entry - presumably an alpha/opacity value,
          TODO confirm against the drawing code. Labels that are not in
          the table get the neutral gray ``(158, 158, 158, 100)``.
      """
      label_colors = {
          # Titles and captions
          "paragraph_title": (102, 102, 255, 100),  # Medium Blue (from 'titles_list')
          "doc_title": (255, 248, 220, 100),  # Cornsilk
          "table_title": (255, 255, 102, 100),  # Light Yellow (from 'tables_caption_list')
          "figure_title": (102, 178, 255, 100),  # Sky Blue (from 'imgs_caption_list')
          "chart_title": (221, 160, 221, 100),  # Plum
          "vision_footnote": (144, 238, 144, 100),  # Light Green
          # Text-like content
          "text": (153, 0, 76, 100),  # Deep Purple (from 'texts_list')
          "formula": (0, 255, 0, 100),  # Bright Green (from 'interequations_list')
          "abstract": (255, 239, 213, 100),  # Papaya Whip
          "content": (40, 169, 92, 100),  # Medium Green (from 'lists_list' and 'indexs_list')
          "seal": (158, 158, 158, 100),  # Neutral Gray (from 'dropped_bbox_list')
          # Tables and visual content
          "table": (204, 204, 0, 100),  # Olive Yellow (from 'tables_body_list')
          "image": (153, 255, 51, 100),  # Bright Green (from 'imgs_body_list')
          "figure": (153, 255, 51, 100),  # Same color as "image" (from 'imgs_body_list')
          "chart": (216, 191, 216, 100),  # Thistle
          # Miscellaneous
          "reference": (229, 255, 204, 100),  # Pale Yellow-Green (from 'tables_footnote_list')
          "algorithm": (255, 250, 240, 100),  # Floral White
      }
      # Unknown labels fall back to the same neutral gray used for "seal".
      default_color = (158, 158, 158, 100)
      return label_colors.get(label, default_color)