utils.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. __all__ = [
  15. "get_sub_regions_ocr_res",
  16. "get_layout_ordering",
  17. "recursive_img_array2path",
  18. "get_show_color",
  19. ]
  20. import numpy as np
  21. import copy
  22. import cv2
  23. import uuid
  24. from pathlib import Path
  25. from ..ocr.result import OCRResult
  26. from ...models_new.object_detection.result import DetResult
  27. def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> list:
  28. """
  29. Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
  30. Args:
  31. src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
  32. ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
  33. Returns:
  34. list: A list of indices of source boxes that overlap with any reference box.
  35. """
  36. match_idx_list = []
  37. src_boxes_num = len(src_boxes)
  38. if src_boxes_num > 0 and len(ref_boxes) > 0:
  39. for rno in range(len(ref_boxes)):
  40. ref_box = ref_boxes[rno]
  41. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  42. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  43. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  44. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  45. pub_w = x2 - x1
  46. pub_h = y2 - y1
  47. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  48. match_idx_list.extend(match_idx)
  49. return match_idx_list
  50. def get_sub_regions_ocr_res(
  51. overall_ocr_res: OCRResult, object_boxes: list, flag_within: bool = True
  52. ) -> OCRResult:
  53. """
  54. Filters OCR results to only include text boxes within specified object boxes based on a flag.
  55. Args:
  56. overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
  57. object_boxes (list): A list of bounding boxes for the objects of interest.
  58. flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
  59. Returns:
  60. OCRResult: A filtered OCR result containing only the relevant text boxes.
  61. """
  62. sub_regions_ocr_res = {}
  63. sub_regions_ocr_res["rec_polys"] = []
  64. sub_regions_ocr_res["rec_texts"] = []
  65. sub_regions_ocr_res["rec_scores"] = []
  66. sub_regions_ocr_res["rec_boxes"] = []
  67. overall_text_boxes = overall_ocr_res["rec_boxes"]
  68. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  69. match_idx_list = list(set(match_idx_list))
  70. for box_no in range(len(overall_text_boxes)):
  71. if flag_within:
  72. if box_no in match_idx_list:
  73. flag_match = True
  74. else:
  75. flag_match = False
  76. else:
  77. if box_no not in match_idx_list:
  78. flag_match = True
  79. else:
  80. flag_match = False
  81. if flag_match:
  82. sub_regions_ocr_res["rec_polys"].append(
  83. overall_ocr_res["rec_polys"][box_no]
  84. )
  85. sub_regions_ocr_res["rec_texts"].append(
  86. overall_ocr_res["rec_texts"][box_no]
  87. )
  88. sub_regions_ocr_res["rec_scores"].append(
  89. overall_ocr_res["rec_scores"][box_no]
  90. )
  91. sub_regions_ocr_res["rec_boxes"].append(
  92. overall_ocr_res["rec_boxes"][box_no]
  93. )
  94. return sub_regions_ocr_res
  95. def calculate_iou(box1, box2):
  96. """
  97. Calculate Intersection over Union (IoU) between two bounding boxes.
  98. Args:
  99. box1, box2: Lists or tuples representing bounding boxes [x_min, y_min, x_max, y_max].
  100. Returns:
  101. float: The IoU value.
  102. """
  103. box1 = list(map(int, box1))
  104. box2 = list(map(int, box2))
  105. x1_min, y1_min, x1_max, y1_max = box1
  106. x2_min, y2_min, x2_max, y2_max = box2
  107. inter_x_min = max(x1_min, x2_min)
  108. inter_y_min = max(y1_min, y2_min)
  109. inter_x_max = min(x1_max, x2_max)
  110. inter_y_max = min(y1_max, y2_max)
  111. if inter_x_max <= inter_x_min or inter_y_max <= inter_y_min:
  112. return 0.0
  113. inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min)
  114. box1_area = (x1_max - x1_min) * (y1_max - y1_min)
  115. box2_area = (x2_max - x2_min) * (y2_max - y2_min)
  116. min_area = min(box1_area, box2_area)
  117. if min_area <= 0:
  118. return 0.0
  119. iou = inter_area / min_area
  120. return iou
  121. def _whether_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.6):
  122. _, y0_1, _, y1_1 = bbox1
  123. _, y0_2, _, y1_2 = bbox2
  124. overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
  125. min_height = min(y1_1 - y0_1, y1_2 - y0_2)
  126. return (overlap / min_height) > overlap_ratio_threshold
  127. def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_threshold=0.7):
  128. assert ocr_res["boxes"] and ocr_res["rec_texts"]
  129. # span->line->block
  130. boxes = ocr_res["boxes"]
  131. rec_text = ocr_res["rec_texts"]
  132. x_min, x_max = layout_bbox[0], layout_bbox[2]
  133. spans = list(zip(boxes, rec_text))
  134. spans.sort(key=lambda span: span[0][1])
  135. spans = [list(span) for span in spans]
  136. lines = []
  137. first_span = spans[0]
  138. current_line = [first_span]
  139. current_y0, current_y1 = first_span[0][1], first_span[0][3]
  140. for span in spans[1:]:
  141. y0, y1 = span[0][1], span[0][3]
  142. if _whether_overlaps_y_exceeds_threshold(
  143. (0, current_y0, 0, current_y1),
  144. (0, y0, 0, y1),
  145. line_height_threshold,
  146. ):
  147. current_line.append(span)
  148. current_y0 = min(current_y0, y0)
  149. current_y1 = max(current_y1, y1)
  150. else:
  151. lines.append(current_line)
  152. current_line = [span]
  153. current_y0, current_y1 = y0, y1
  154. if current_line:
  155. lines.append(current_line)
  156. for line in lines:
  157. line.sort(key=lambda span: span[0][0])
  158. first_span = line[0]
  159. end_span = line[-1]
  160. if first_span[0][0] - x_min > 20:
  161. first_span[1] = "\n" + first_span[1]
  162. if x_max - end_span[0][2] > 20:
  163. end_span[1] = end_span[1] + "\n"
  164. ocr_res["boxes"] = [span[0] for line in lines for span in line]
  165. ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
  166. return ocr_res
  167. def get_structure_res(
  168. overall_ocr_res: OCRResult,
  169. layout_det_res: DetResult,
  170. table_res_list,
  171. ) -> OCRResult:
  172. """
  173. Extract structured information from OCR and layout detection results.
  174. Args:
  175. overall_ocr_res (OCRResult): An object containing the overall OCR results, including detected text boxes and recognized text. The structure is expected to have:
  176. - "input_img": The image on which OCR was performed.
  177. - "dt_boxes": A list of detected text box coordinates.
  178. - "rec_texts": A list of recognized text corresponding to the detected boxes.
  179. layout_det_res (DetResult): An object containing the layout detection results, including detected layout boxes and their labels. The structure is expected to have:
  180. - "boxes": A list of dictionaries with keys "coordinate" for box coordinates and "label" for the type of content.
  181. table_res_list (list): A list of table detection results, where each item is a dictionary containing:
  182. - "layout_bbox": The bounding box of the table layout.
  183. - "pred_html": The predicted HTML representation of the table.
  184. Returns:
  185. list: A list of structured boxes where each item is a dictionary containing:
  186. - "label": The label of the content (e.g., 'table', 'chart', 'image').
  187. - The label as a key with either table HTML or image data and text.
  188. - "layout_bbox": The coordinates of the layout box.
  189. """
  190. structure_boxes = []
  191. input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
  192. for box_info in layout_det_res["boxes"]:
  193. layout_bbox = box_info["coordinate"]
  194. label = box_info["label"]
  195. rec_res = {"boxes": [], "rec_texts": [], "flag": False}
  196. seg_start_flag = True
  197. seg_end_flag = True
  198. if label == "table":
  199. for i, table_res in enumerate(table_res_list):
  200. if calculate_iou(layout_bbox, table_res["cell_box_list"][0]) > 0.5:
  201. structure_boxes.append(
  202. {
  203. "label": label,
  204. f"{label}": table_res["pred_html"],
  205. "layout_bbox": layout_bbox,
  206. "seg_start_flag": seg_start_flag,
  207. "seg_end_flag": seg_end_flag,
  208. },
  209. )
  210. del table_res_list[i]
  211. break
  212. else:
  213. overall_text_boxes = overall_ocr_res["rec_boxes"]
  214. for box_no in range(len(overall_text_boxes)):
  215. if calculate_iou(layout_bbox, overall_text_boxes[box_no]) > 0.5:
  216. rec_res["boxes"].append(overall_text_boxes[box_no])
  217. rec_res["rec_texts"].append(
  218. overall_ocr_res["rec_texts"][box_no],
  219. )
  220. rec_res["flag"] = True
  221. if rec_res["flag"]:
  222. rec_res = _sort_box_by_y_projection(layout_bbox, rec_res, 0.7)
  223. rec_res_first_bbox = rec_res["boxes"][0]
  224. rec_res_end_bbox = rec_res["boxes"][-1]
  225. if rec_res_first_bbox[0] - layout_bbox[0] < 20:
  226. seg_start_flag = False
  227. if layout_bbox[2] - rec_res_end_bbox[2] < 20:
  228. seg_end_flag = False
  229. if label == "formula":
  230. rec_res["rec_texts"] = [
  231. rec_res_text.replace("$", "")
  232. for rec_res_text in rec_res["rec_texts"]
  233. ]
  234. if label in ["chart", "image"]:
  235. structure_boxes.append(
  236. {
  237. "label": label,
  238. f"{label}": {
  239. "img": input_img[
  240. int(layout_bbox[1]) : int(layout_bbox[3]),
  241. int(layout_bbox[0]) : int(layout_bbox[2]),
  242. ],
  243. },
  244. "layout_bbox": layout_bbox,
  245. "seg_start_flag": seg_start_flag,
  246. "seg_end_flag": seg_end_flag,
  247. },
  248. )
  249. else:
  250. structure_boxes.append(
  251. {
  252. "label": label,
  253. f"{label}": "".join(rec_res["rec_texts"]),
  254. "layout_bbox": layout_bbox,
  255. "seg_start_flag": seg_start_flag,
  256. "seg_end_flag": seg_end_flag,
  257. },
  258. )
  259. return structure_boxes
  260. def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
  261. """
  262. Generate a 1D projection histogram from bounding boxes along a specified axis.
  263. Args:
  264. boxes: A (N, 4) array of bounding boxes defined by [x_min, y_min, x_max, y_max].
  265. axis: Axis for projection; 0 for horizontal (x-axis), 1 for vertical (y-axis).
  266. Returns:
  267. A 1D numpy array representing the projection histogram based on bounding box intervals.
  268. """
  269. assert axis in [0, 1]
  270. max_length = np.max(boxes[:, axis::2])
  271. projection = np.zeros(max_length, dtype=int)
  272. # Increment projection histogram over the interval defined by each bounding box
  273. for start, end in boxes[:, axis::2]:
  274. projection[start:end] += 1
  275. return projection
  276. def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
  277. """
  278. Split the projection profile into segments based on specified thresholds.
  279. Args:
  280. arr_values: 1D array representing the projection profile.
  281. min_value: Minimum value threshold to consider a profile segment significant.
  282. min_gap: Minimum gap width to consider a separation between segments.
  283. Returns:
  284. A tuple of start and end indices for each segment that meets the criteria.
  285. """
  286. # Identify indices where the projection exceeds the minimum value
  287. significant_indices = np.where(arr_values > min_value)[0]
  288. if not len(significant_indices):
  289. return
  290. # Calculate gaps between significant indices
  291. index_diffs = significant_indices[1:] - significant_indices[:-1]
  292. gap_indices = np.where(index_diffs > min_gap)[0]
  293. # Determine start and end indices of segments
  294. segment_starts = np.insert(
  295. significant_indices[gap_indices + 1],
  296. 0,
  297. significant_indices[0],
  298. )
  299. segment_ends = np.append(
  300. significant_indices[gap_indices],
  301. significant_indices[-1] + 1,
  302. )
  303. return segment_starts, segment_ends
  304. def recursive_yx_cut(boxes: np.ndarray, indices: list[int], res: list[int], min_gap=1):
  305. """
  306. Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.
  307. Args:
  308. boxes: A (N, 4) array representing bounding boxes.
  309. indices: List of indices indicating the original position of boxes.
  310. res: List to store indices of the final segmented bounding boxes.
  311. """
  312. assert len(boxes) == len(indices)
  313. # Sort by y_min for Y-axis projection
  314. y_sorted_indices = boxes[:, 1].argsort()
  315. y_sorted_boxes = boxes[y_sorted_indices]
  316. y_sorted_indices = np.array(indices)[y_sorted_indices]
  317. # Perform Y-axis projection
  318. y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
  319. y_intervals = split_projection_profile(y_projection, 0, 1)
  320. if not y_intervals:
  321. return
  322. # Process each segment defined by Y-axis projection
  323. for y_start, y_end in zip(*y_intervals):
  324. # Select boxes within the current y interval
  325. y_interval_indices = (y_start <= y_sorted_boxes[:, 1]) & (
  326. y_sorted_boxes[:, 1] < y_end
  327. )
  328. y_boxes_chunk = y_sorted_boxes[y_interval_indices]
  329. y_indices_chunk = y_sorted_indices[y_interval_indices]
  330. # Sort by x_min for X-axis projection
  331. x_sorted_indices = y_boxes_chunk[:, 0].argsort()
  332. x_sorted_boxes_chunk = y_boxes_chunk[x_sorted_indices]
  333. x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]
  334. # Perform X-axis projection
  335. x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
  336. x_intervals = split_projection_profile(x_projection, 0, min_gap)
  337. if not x_intervals:
  338. continue
  339. # If X-axis cannot be further segmented, add current indices to results
  340. if len(x_intervals[0]) == 1:
  341. res.extend(x_sorted_indices_chunk)
  342. continue
  343. # Recursively process each segment defined by X-axis projection
  344. for x_start, x_end in zip(*x_intervals):
  345. x_interval_indices = (x_start <= x_sorted_boxes_chunk[:, 0]) & (
  346. x_sorted_boxes_chunk[:, 0] < x_end
  347. )
  348. recursive_yx_cut(
  349. x_sorted_boxes_chunk[x_interval_indices],
  350. x_sorted_indices_chunk[x_interval_indices],
  351. res,
  352. )
  353. def recursive_xy_cut(boxes: np.ndarray, indices: list[int], res: list[int], min_gap=1):
  354. """
  355. Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.
  356. Args:
  357. boxes: A (N, 4) array representing bounding boxes with [x_min, y_min, x_max, y_max].
  358. indices: A list of indices representing the position of boxes in the original data.
  359. res: A list to store indices of bounding boxes that meet the criteria.
  360. """
  361. # Ensure boxes and indices have the same length
  362. assert len(boxes) == len(indices)
  363. # Sort by x_min to prepare for X-axis projection
  364. x_sorted_indices = boxes[:, 0].argsort()
  365. x_sorted_boxes = boxes[x_sorted_indices]
  366. x_sorted_indices = np.array(indices)[x_sorted_indices]
  367. # Perform X-axis projection
  368. x_projection = projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
  369. x_intervals = split_projection_profile(x_projection, 0, 1)
  370. if not x_intervals:
  371. return
  372. # Process each segment defined by X-axis projection
  373. for x_start, x_end in zip(*x_intervals):
  374. # Select boxes within the current x interval
  375. x_interval_indices = (x_start <= x_sorted_boxes[:, 0]) & (
  376. x_sorted_boxes[:, 0] < x_end
  377. )
  378. x_boxes_chunk = x_sorted_boxes[x_interval_indices]
  379. x_indices_chunk = x_sorted_indices[x_interval_indices]
  380. # Sort selected boxes by y_min to prepare for Y-axis projection
  381. y_sorted_indices = x_boxes_chunk[:, 1].argsort()
  382. y_sorted_boxes_chunk = x_boxes_chunk[y_sorted_indices]
  383. y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]
  384. # Perform Y-axis projection
  385. y_projection = projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
  386. y_intervals = split_projection_profile(y_projection, 0, min_gap)
  387. if not y_intervals:
  388. continue
  389. # If Y-axis cannot be further segmented, add current indices to results
  390. if len(y_intervals[0]) == 1:
  391. res.extend(y_sorted_indices_chunk)
  392. continue
  393. # Recursively process each segment defined by Y-axis projection
  394. for y_start, y_end in zip(*y_intervals):
  395. y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
  396. y_sorted_boxes_chunk[:, 1] < y_end
  397. )
  398. recursive_xy_cut(
  399. y_sorted_boxes_chunk[y_interval_indices],
  400. y_sorted_indices_chunk[y_interval_indices],
  401. res,
  402. )
  403. def sort_by_xycut(block_bboxes, direction=0, min_gap=1):
  404. block_bboxes = np.asarray(block_bboxes).astype(int)
  405. res = []
  406. if direction == 1:
  407. recursive_yx_cut(
  408. block_bboxes,
  409. np.arange(
  410. len(block_bboxes),
  411. ),
  412. res,
  413. min_gap,
  414. )
  415. else:
  416. recursive_xy_cut(
  417. block_bboxes,
  418. np.arange(
  419. len(block_bboxes),
  420. ),
  421. res,
  422. min_gap,
  423. )
  424. return res
  425. def _img_array2path(data, save_path):
  426. """
  427. Save an image array to disk and return the file path.
  428. Args:
  429. data (np.ndarray): An image represented as a numpy array.
  430. save_path (str or Path): The base path where images should be saved.
  431. Returns:
  432. str: The relative path of the saved image file.
  433. """
  434. if isinstance(data, np.ndarray) and data.ndim == 3:
  435. # Generate a unique filename using UUID
  436. img_name = f"image_{uuid.uuid4().hex}.png"
  437. img_path = Path(save_path) / "imgs" / img_name
  438. img_path.parent.mkdir(
  439. parents=True,
  440. exist_ok=True,
  441. ) # Ensure the directory exists
  442. cv2.imwrite(str(img_path), data)
  443. return f"imgs/{img_name}"
  444. else:
  445. return ValueError
  446. def recursive_img_array2path(data, save_path, labels=[]):
  447. """
  448. Process a dictionary or list to save image arrays to disk and replace them with file paths.
  449. Args:
  450. data (dict or list): The data structure that may contain image arrays.
  451. save_path (str or Path): The base path where images should be saved.
  452. """
  453. if isinstance(data, dict):
  454. for k, v in data.items():
  455. if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
  456. data[k] = _img_array2path(v, save_path)
  457. else:
  458. recursive_img_array2path(v, save_path, labels)
  459. elif isinstance(data, list):
  460. for item in data:
  461. recursive_img_array2path(item, save_path, labels)
  462. def _calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
  463. """
  464. Calculate the ratio of the overlap area between bbox1 and bbox2
  465. to the area of the smaller bounding box.
  466. Args:
  467. bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  468. bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  469. Returns:
  470. float: The ratio of the overlap area to the area of the smaller bounding box.
  471. """
  472. x_left = max(bbox1[0], bbox2[0])
  473. y_top = max(bbox1[1], bbox2[1])
  474. x_right = min(bbox1[2], bbox2[2])
  475. y_bottom = min(bbox1[3], bbox2[3])
  476. if x_right <= x_left or y_bottom <= y_top:
  477. return 0.0
  478. # Calculate the area of the overlap
  479. intersection_area = (x_right - x_left) * (y_bottom - y_top)
  480. # Calculate the areas of both bounding boxes
  481. area_bbox1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  482. area_bbox2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  483. # Determine the minimum non-zero box area
  484. min_box_area = min(area_bbox1, area_bbox2)
  485. # Avoid division by zero in case of zero-area boxes
  486. if min_box_area == 0:
  487. return 0.0
  488. return intersection_area / min_box_area
  489. def _get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio, smaller=True):
  490. """
  491. Determine if the overlap area between two bounding boxes exceeds a given ratio
  492. and return the smaller (or larger) bounding box based on the `smaller` flag.
  493. Args:
  494. bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  495. bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  496. ratio (float): The overlap ratio threshold.
  497. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  498. Returns:
  499. list or tuple: The selected bounding box or None if the overlap ratio is not exceeded.
  500. """
  501. # Calculate the areas of both bounding boxes
  502. area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  503. area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  504. # Calculate the overlap ratio using a helper function
  505. overlap_ratio = _calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
  506. # Check if the overlap ratio exceeds the threshold
  507. if overlap_ratio > ratio:
  508. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  509. return 1
  510. else:
  511. return 2
  512. return None
  513. def _remove_overlap_blocks(blocks, threshold=0.65, smaller=True):
  514. """
  515. Remove overlapping blocks based on a specified overlap ratio threshold.
  516. Args:
  517. blocks (list): List of block dictionaries, each containing a 'layout_bbox' key.
  518. threshold (float): Ratio threshold to determine significant overlap.
  519. smaller (bool): If True, the smaller block in overlap is removed.
  520. Returns:
  521. tuple: A tuple containing the updated list of blocks and a list of dropped blocks.
  522. """
  523. dropped_blocks = []
  524. dropped_indexes = []
  525. # Iterate over each pair of blocks to find overlaps
  526. for i in range(len(blocks)):
  527. block1 = blocks[i]
  528. for j in range(i + 1, len(blocks)):
  529. block2 = blocks[j]
  530. # Skip blocks that are already marked for removal
  531. if i in dropped_indexes or j in dropped_indexes:
  532. continue
  533. # Check for overlap and determine which block to remove
  534. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  535. block1["layout_bbox"],
  536. block2["layout_bbox"],
  537. threshold,
  538. smaller=smaller,
  539. )
  540. if overlap_box_index is not None:
  541. if overlap_box_index == 1:
  542. block_to_remove = block1
  543. drop_index = i
  544. else:
  545. block_to_remove = block2
  546. drop_index = j
  547. if drop_index not in dropped_indexes:
  548. dropped_indexes.append(drop_index)
  549. dropped_blocks.append(block_to_remove)
  550. dropped_indexes.sort()
  551. for i in reversed(dropped_indexes):
  552. del blocks[i]
  553. return blocks, dropped_blocks
  554. def _text_median_width(blocks):
  555. widths = [
  556. block["layout_bbox"][2] - block["layout_bbox"][0]
  557. for block in blocks
  558. if block["label"] in ["text"]
  559. ]
  560. return np.median(widths) if widths else float("inf")
  561. def _get_layout_property(blocks, median_width, no_mask_labels, threshold=0.8):
  562. """
Classify blocks as single- or double-column and decide the page's dominant layout.

Args:
    blocks (list): Block dicts with 'label' and 'layout_bbox' ([x_min, y_min, x_max, y_max]).
        The list is sorted in place, and every block whose label is in
        ``no_mask_labels`` gets a "layout" key set to "single" or "double".
    median_width (float): Median width of text blocks.
    no_mask_labels: Collection of labels that take part in the classification.
        Blocks with other labels are skipped and are not used as neighbours,
        though they still count towards the page width.
    threshold (float): Horizontal coverage ratio a neighbour must exceed to be
        recorded as a strong match.

Returns:
    tuple: ``(blocks, is_double)``, where ``is_double`` is True when the area
    accumulated for "double" blocks exceeds the area accumulated for "single"
    blocks in the promotion pass.
  570. """
# NOTE(review): this copy of the file lost its indentation, so two nestings cannot
# be read off the source and should be confirmed upstream:
#   (1) whether `x_min_i = x_match_max` and the following `break` check sit inside
#       `if match_block_iou > threshold:` or directly under `if match_block_iou > 0:`;
#   (2) whether the `else:` that adds to `single_label_area` pairs with
#       `if match_iou > 0.9 ...` or with `if single_layout:`.
# Sort left-to-right; ties are broken by narrower width first.
  571. blocks.sort(
  572. key=lambda x: (
  573. x["layout_bbox"][0],
  574. (x["layout_bbox"][2] - x["layout_bbox"][0]),
  575. ),
  576. )
# Maps index of a "single" block -> its list of (neighbour index, coverage ratio) strong matches.
  577. check_single_layout = {}
  578. page_min_x, page_max_x = float("inf"), 0
# NOTE(review): double_label_height is accumulated below but never read in this function.
  579. double_label_height = 0
  580. double_label_area = 0
  581. single_label_area = 0
# Page width = horizontal extent spanned by all blocks, regardless of label.
  582. for i, block in enumerate(blocks):
  583. page_min_x = min(page_min_x, block["layout_bbox"][0])
  584. page_max_x = max(page_max_x, block["layout_bbox"][2])
  585. page_width = page_max_x - page_min_x
  586. for i, block in enumerate(blocks):
# Blocks outside no_mask_labels are never classified (no "layout" key is set).
  587. if block["label"] not in no_mask_labels:
  588. continue
  589. x_min_i, _, x_max_i, _ = block["layout_bbox"]
  590. layout_length = x_max_i - x_min_i
  591. cover_count, cover_with_threshold_count = 0, 0
  592. match_block_with_threshold_indexes = []
  593. for j, other_block in enumerate(blocks):
  594. if i == j or other_block["label"] not in no_mask_labels:
  595. continue
  596. x_min_j, _, x_max_j, _ = other_block["layout_bbox"]
  597. x_match_min, x_match_max = max(
  598. x_min_i,
  599. x_min_j,
  600. ), min(x_max_i, x_max_j)
# Fraction of other_block's width lying inside the current [x_min_i, x_max_i]
# span (negative when they are disjoint).
# TODO(review): a zero-width other_block divides by zero here.
  601. match_block_iou = (x_match_max - x_match_min) / (x_max_j - x_min_j)
  602. if match_block_iou > 0:
  603. cover_count += 1
  604. if match_block_iou > threshold:
  605. cover_with_threshold_count += 1
  606. match_block_with_threshold_indexes.append(
  607. (j, match_block_iou),
  608. )
# Advance the sweep start past the matched span, so later neighbours are measured
# against the remaining part of block i; stop once nothing remains.
  609. x_min_i = x_match_max
  610. if x_min_i >= x_max_i:
  611. break
# "double" if the block is wider than 1.3x the median text width AND overlaps at
# least two neighbours, or if it spans more than 60% of the page width. For
# threshold >= 0, cover_with_threshold_count <= cover_count, so the inner `or`
# reduces to cover_count >= 2.
  612. if (
  613. layout_length > median_width * 1.3
  614. and (cover_with_threshold_count >= 2 or cover_count >= 2)
  615. ) or layout_length > 0.6 * page_width:
  616. # if layout_length > median_width * 1.3 and (cover_with_threshold_count >= 2):
  617. block["layout"] = "double"
  618. double_label_height += block["layout_bbox"][3] - block["layout_bbox"][1]
  619. double_label_area += (block["layout_bbox"][2] - block["layout_bbox"][0]) * (
  620. block["layout_bbox"][3] - block["layout_bbox"][1]
  621. )
  622. else:
  623. block["layout"] = "single"
  624. check_single_layout[i] = match_block_with_threshold_indexes
  625. # Check single-layout block
# Promote a "single" block to "double" when its last strong match (coverage > 0.9)
# is itself "double". Indices refer to the sorted `blocks` list.
  626. for i, single_layout in check_single_layout.items():
  627. if single_layout:
  628. index, match_iou = single_layout[-1]
  629. if match_iou > 0.9 and blocks[index]["layout"] == "double":
  630. blocks[i]["layout"] = "double"
  631. double_label_height += (
  632. blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1]
  633. )
  634. double_label_area += (
  635. blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
  636. ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
  637. else:
  638. single_label_area += (
  639. blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
  640. ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
# The second element is True when the accumulated "double" area beats the "single" area.
  641. return blocks, (double_label_area > single_label_area)
  642. def _get_bbox_direction(input_bbox, ratio=1):
  643. """
  644. Determine if a bounding box is horizontal or vertical.
  645. Args:
  646. input_bbox (list): Bounding box [x_min, y_min, x_max, y_max].
  647. ratio (float): Ratio for determining orientation.
  648. Returns:
  649. bool: True if horizontal, False if vertical.
  650. """
  651. return (input_bbox[2] - input_bbox[0]) * ratio >= (input_bbox[3] - input_bbox[1])
  652. def _get_projection_iou(input_bbox, match_bbox, is_horizontal=True):
  653. """
  654. Calculate the IoU of lines between two bounding boxes.
  655. Args:
  656. input_bbox (list): First bounding box [x_min, y_min, x_max, y_max].
  657. match_bbox (list): Second bounding box [x_min, y_min, x_max, y_max].
  658. is_horizontal (bool): Whether to compare horizontally or vertically.
  659. Returns:
  660. float: Line IoU.
  661. """
  662. if is_horizontal:
  663. x_match_min = max(input_bbox[0], match_bbox[0])
  664. x_match_max = min(input_bbox[2], match_bbox[2])
  665. return (x_match_max - x_match_min) / (input_bbox[2] - input_bbox[0])
  666. else:
  667. y_match_min = max(input_bbox[1], match_bbox[1])
  668. y_match_max = min(input_bbox[3], match_bbox[3])
  669. return (y_match_max - y_match_min) / (input_bbox[3] - input_bbox[1])
  def _get_sub_category(blocks, title_labels):
      """
      Attach neighbouring blocks to title, sub-title and vision "anchor" blocks.

      A block is an anchor when its label is in ``title_labels``, is
      "paragraph_title", or is one of the local ``vision_labels``. For each
      anchor, the closest qualifying block is searched on each side of its main
      axis. That is above/below for a horizontal anchor and left/right for a
      vertical one. The chosen neighbours may get a new ``sub_label``
      ("title_text" or "vision_footnote"), and their boxes are recorded on the
      anchor.

      Every block, anchor or not, is initialised (when missing) with empty
      ``title_text``, ``sub_title`` and ``vision_footnote`` lists and with
      ``sub_label`` set to its ``label``.

      Args:
          blocks (list): List of block dicts, read for ``label`` and
              ``layout_bbox`` ([x_min, y_min, x_max, y_max]).
          title_labels (list): Labels whose blocks may receive ``title_text``.

      Returns:
          list: The same ``blocks`` list, mutated in place. On anchor blocks:
          - ``title_text`` is a list of ``(direction_code, bbox)`` tuples. The
            code is 1/2 (above/below a horizontal anchor) or 3/4 (left/right
            of a vertical anchor).
          - ``sub_title`` and ``vision_footnote`` are lists of bboxes.
          These fields are only written if they are still empty.
      """
      sub_title_labels = ["paragraph_title"]
      vision_labels = ["image", "table", "chart", "figure"]
      for i, block1 in enumerate(blocks):
          # Ensure every block carries the attachment fields.
          if block1.get("title_text") is None:
              block1["title_text"] = []
          if block1.get("sub_title") is None:
              block1["sub_title"] = []
          if block1.get("vision_footnote") is None:
              block1["vision_footnote"] = []
          if block1.get("sub_label") is None:
              block1["sub_label"] = block1["label"]
          # Only titles, sub-titles and vision blocks act as anchors.
          if (
              block1["label"] not in title_labels
              and block1["label"] not in sub_title_labels
              and block1["label"] not in vision_labels
          ):
              continue
          bbox1 = block1["layout_bbox"]
          x1, y1, x2, y2 = bbox1
          is_horizontal_1 = _get_bbox_direction(block1["layout_bbox"])
          # Best candidate before (above / left of) and after (below / right of)
          # the anchor along its main axis; index -1 means "none found".
          left_up_title_text_distance = float("inf")
          left_up_title_text_index = -1
          left_up_title_text_direction = None
          right_down_title_text_distance = float("inf")
          right_down_title_text_index = -1
          right_down_title_text_direction = None
          for j, block2 in enumerate(blocks):
              if i == j:
                  continue
              bbox2 = block2["layout_bbox"]
              x1_prime, y1_prime, x2_prime, y2_prime = bbox2
              is_horizontal_2 = _get_bbox_direction(bbox2)
              # Share of block2's extent (on the anchor's main axis) that
              # overlaps the anchor.
              match_block_iou = _get_projection_iou(
                  bbox2,
                  bbox1,
                  is_horizontal_1,
              )

              def distance_(is_horizontal, is_left_up):
                  # Gap between anchor and block2 along the main axis, rounded
                  # to units of 5 (via (gap + 2) // 5). The small
                  # position / 5000 term breaks ties between equal gaps.
                  if is_horizontal:
                      if is_left_up:
                          return (y1 - y2_prime + 2) // 5 + x1_prime / 5000
                      else:
                          return (y1_prime - y2 + 2) // 5 + x1_prime / 5000
                  else:
                      if is_left_up:
                          return (x1 - x2_prime + 2) // 5 + y1_prime / 5000
                      else:
                          return (x1_prime - x2 + 2) // 5 + y1_prime / 5000

              block_iou_threshold = 0.1
              if block1["label"] in sub_title_labels:
                  # Sub-title anchors use area overlap with a stricter threshold.
                  match_block_iou = _calculate_overlap_area_2_minbox_area_ratio(
                      bbox2,
                      bbox1,
                  )
                  block_iou_threshold = 0.7
              if is_horizontal_1:
                  if match_block_iou >= block_iou_threshold:
                      left_up_distance = distance_(True, True)
                      right_down_distance = distance_(True, False)
                      if (
                          y2_prime <= y1
                          and left_up_distance <= left_up_title_text_distance
                      ):
                          # block2 lies above the anchor.
                          left_up_title_text_distance = left_up_distance
                          left_up_title_text_index = j
                          left_up_title_text_direction = is_horizontal_2
                      elif (
                          y1_prime > y2
                          and right_down_distance < right_down_title_text_distance
                      ):
                          # block2 lies below the anchor.
                          right_down_title_text_distance = right_down_distance
                          right_down_title_text_index = j
                          right_down_title_text_direction = is_horizontal_2
              else:
                  if match_block_iou >= block_iou_threshold:
                      left_up_distance = distance_(False, True)
                      right_down_distance = distance_(False, False)
                      if (
                          x2_prime <= x1
                          and left_up_distance <= left_up_title_text_distance
                      ):
                          # block2 lies left of the anchor.
                          left_up_title_text_distance = left_up_distance
                          left_up_title_text_index = j
                          left_up_title_text_direction = is_horizontal_2
                      elif (
                          x1_prime > x2
                          and right_down_distance < right_down_title_text_distance
                      ):
                          # block2 lies right of the anchor.
                          right_down_title_text_distance = right_down_distance
                          right_down_title_text_index = j
                          right_down_title_text_direction = is_horizontal_2
          height = bbox1[3] - bbox1[1]
          width = bbox1[2] - bbox1[0]
          title_text_weight = [0.8, 0.8]
          # title_text_weight = [2, 2]
          title_text = []
          sub_title = []
          vision_footnote = []

          def get_sub_category_(
              title_text_direction,
              title_text_index,
              label,
              is_left_up=True,
          ):
              # Classify the chosen neighbour blocks[title_text_index] and record
              # it on the anchor. Only neighbours with the same orientation as
              # the anchor and label "text"/"paragraph_title" are considered.
              # direction_[0] is used for horizontal anchors, direction_[1] for
              # vertical ones.
              direction_ = [1, 3] if is_left_up else [2, 4]
              if (
                  title_text_direction == is_horizontal_1
                  and title_text_index != -1
                  and (label == "text" or label == "paragraph_title")
              ):
                  bbox2 = blocks[title_text_index]["layout_bbox"]
                  if is_horizontal_1:
                      height1 = bbox2[3] - bbox2[1]
                      width1 = bbox2[2] - bbox2[0]
                      if label == "text":
                          if (
                              _nearest_edge_distance(bbox1, bbox2)[0] <= 15
                              and block1["label"] in vision_labels
                              and width1 < width
                              and height1 < 0.5 * height
                          ):
                              # Close, narrower and much shorter text next to a
                              # vision block becomes its footnote.
                              blocks[title_text_index]["sub_label"] = "vision_footnote"
                              vision_footnote.append(bbox2)
                          elif (
                              height1 < height * title_text_weight[0]
                              and (width1 < width or width1 > 1.5 * width)
                              and block1["label"] in title_labels
                          ):
                              blocks[title_text_index]["sub_label"] = "title_text"
                              title_text.append((direction_[0], bbox2))
                      elif (
                          label == "paragraph_title"
                          and block1["label"] in sub_title_labels
                      ):
                          sub_title.append(bbox2)
                  else:
                      height1 = bbox2[3] - bbox2[1]
                      width1 = bbox2[2] - bbox2[0]
                      if label == "text":
                          if (
                              _nearest_edge_distance(bbox1, bbox2)[0] <= 15
                              and block1["label"] in vision_labels
                              and height1 < height
                              and width1 < 0.5 * width
                          ):
                              blocks[title_text_index]["sub_label"] = "vision_footnote"
                              vision_footnote.append(bbox2)
                          elif (
                              width1 < width * title_text_weight[1]
                              and block1["label"] in title_labels
                          ):
                              blocks[title_text_index]["sub_label"] = "title_text"
                              title_text.append((direction_[1], bbox2))
                      elif (
                          label == "paragraph_title"
                          and block1["label"] in sub_title_labels
                      ):
                          sub_title.append(bbox2)

          # If one side is clearly closer (distance difference, scaled back by
          # 5, exceeds the anchor's size on its main axis), attach only that
          # side; otherwise try both sides. When both distances are inf the
          # difference is nan, so both sides are tried (and both are no-ops).
          # With an index of -1, blocks[-1]["label"] is read, but
          # get_sub_category_ ignores it because of its index check.
          if (
              is_horizontal_1
              and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
              > height
          ) or (
              not is_horizontal_1
              and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
              > width
          ):
              if left_up_title_text_distance < right_down_title_text_distance:
                  get_sub_category_(
                      left_up_title_text_direction,
                      left_up_title_text_index,
                      blocks[left_up_title_text_index]["label"],
                      True,
                  )
              else:
                  get_sub_category_(
                      right_down_title_text_direction,
                      right_down_title_text_index,
                      blocks[right_down_title_text_index]["label"],
                      False,
                  )
          else:
              get_sub_category_(
                  left_up_title_text_direction,
                  left_up_title_text_index,
                  blocks[left_up_title_text_index]["label"],
                  True,
              )
              get_sub_category_(
                  right_down_title_text_direction,
                  right_down_title_text_index,
                  blocks[right_down_title_text_index]["label"],
                  False,
              )
          # Store results only if the fields have not been filled already.
          if block1["label"] in title_labels:
              if blocks[i].get("title_text") == []:
                  blocks[i]["title_text"] = title_text
          if block1["label"] in sub_title_labels:
              if blocks[i].get("sub_title") == []:
                  blocks[i]["sub_title"] = sub_title
          if block1["label"] in vision_labels:
              if blocks[i].get("vision_footnote") == []:
                  blocks[i]["vision_footnote"] = vision_footnote
      return blocks
  def get_layout_ordering(data, no_mask_labels=None, already_sorted=False):
      """
      Assign a reading order to the sub-blocks of one layout region.

      ``data["sub_blocks"]`` is processed as follows:
      1. Overlapping blocks are removed.
      2. Blocks are sub-categorised (title text, vision footnotes, ...).
      3. Single-layout text-like blocks are ordered with an XY-cut.
      4. Every other group is inserted next to its nearest already-ordered
         block.

      Indices are 1-based:
      - ``index`` is the order among text and title blocks.
      - ``sub_index`` is the order that also includes vision blocks.
      Vision-title, vision, footnote and "other" blocks are only given
      ``sub_index`` here. The final "other" blocks are not renumbered; they
      share the ``sub_index`` of their nearest match.

      Args:
          data (dict): Region dict. ``data["sub_blocks"]`` is a list of block
              dicts with at least ``label`` and ``layout_bbox``.
          no_mask_labels (list, optional): Text-like labels that take part in
              the XY-cut ordering. Defaults to an empty list.
          already_sorted (bool): When True, ``data`` is returned untouched.

      Returns:
          dict: ``data``, whose ``sub_blocks`` list is reordered in place.
      """
      if already_sorted:
          return data
      if no_mask_labels is None:
          # Avoid a shared mutable default argument.
          no_mask_labels = []
      title_text_labels = ["doc_title"]
      title_labels = ["doc_title", "paragraph_title"]
      vision_labels = ["image", "table", "seal", "chart", "figure"]
      vision_title_labels = ["table_title", "chart_title", "figure_title"]
      parsing_result = data["sub_blocks"]
      parsing_result, _ = _remove_overlap_blocks(
          parsing_result,
          threshold=0.5,
          smaller=True,
      )
      parsing_result = _get_sub_category(parsing_result, title_text_labels)
      doc_flag = False
      median_width = _text_median_width(parsing_result)
      parsing_result, projection_direction = _get_layout_property(
          parsing_result,
          median_width,
          no_mask_labels=no_mask_labels,
          threshold=0.3,
      )
      # Convert bounding boxes to int and split blocks into groups. Only
      # single-layout blocks whose sub_label is in no_mask_labels stay in
      # parsing_result for the XY-cut pass; every other group is re-inserted
      # later by nearest matching.
      (
          double_text_blocks,
          title_text_blocks,
          title_blocks,
          vision_blocks,
          vision_title_blocks,
          vision_footnote_blocks,
          other_blocks,
      ) = ([], [], [], [], [], [], [])
      drop_indexes = []
      for index, block in enumerate(parsing_result):
          label = block["sub_label"]
          block["layout_bbox"] = list(map(int, block["layout_bbox"]))
          if label == "doc_title":
              doc_flag = True
          if label in no_mask_labels:
              if block["layout"] == "double":
                  double_text_blocks.append(block)
                  drop_indexes.append(index)
          elif label == "title_text":
              title_text_blocks.append(block)
              drop_indexes.append(index)
          elif label == "vision_footnote":
              vision_footnote_blocks.append(block)
              drop_indexes.append(index)
          elif label in vision_title_labels:
              vision_title_blocks.append(block)
              drop_indexes.append(index)
          elif label in title_labels:
              title_blocks.append(block)
              drop_indexes.append(index)
          elif label in vision_labels:
              vision_blocks.append(block)
              drop_indexes.append(index)
          else:
              other_blocks.append(block)
              drop_indexes.append(index)
      for index in sorted(drop_indexes, reverse=True):
          del parsing_result[index]
      if parsing_result:
          # single text label
          if len(double_text_blocks) > len(parsing_result) or projection_direction:
              # "double"-layout text dominates: order titles, double- and
              # single-layout text together in one XY-cut pass.
              parsing_result.extend(title_blocks + double_text_blocks)
              title_blocks = []
              double_text_blocks = []
              block_bboxes = [block["layout_bbox"] for block in parsing_result]
              block_bboxes.sort(
                  key=lambda x: (
                      x[0] // max(20, median_width),
                      x[1],
                  ),
              )
              block_bboxes = np.array(block_bboxes)
              sorted_indices = sort_by_xycut(
                  block_bboxes,
                  direction=1,
                  min_gap=1,
              )
          else:
              block_bboxes = [block["layout_bbox"] for block in parsing_result]
              block_bboxes.sort(key=lambda x: (x[0] // 20, x[1]))
              block_bboxes = np.array(block_bboxes)
              sorted_indices = sort_by_xycut(
                  block_bboxes,
                  direction=0,
                  min_gap=20,
              )
          sorted_boxes = block_bboxes[sorted_indices].tolist()
          for block in parsing_result:
              # list.index returns the first match, so blocks with identical
              # boxes share the same position.
              position = sorted_boxes.index(block["layout_bbox"]) + 1
              block["index"] = position
              block["sub_index"] = position

      def nearest_match_(input_blocks, distance_type="manhattan", is_add_index=True):
          # Give each block the index (or sub_index) of its closest block in
          # parsing_result under ``distance_type``, then append it so that
          # later blocks can be matched against it too.
          for block in input_blocks:
              bbox = block["layout_bbox"]
              min_distance = float("inf")
              min_distance_config = [
                  [float("inf"), float("inf")],
                  float("inf"),
                  float("inf"),
              ]  # for double text
              nearest_gt_index = 0
              for match_block in parsing_result:
                  match_bbox = match_block["layout_bbox"]
                  if distance_type == "nearest_iou_edge_distance":
                      distance, min_distance_config = _nearest_iou_edge_distance(
                          bbox,
                          match_bbox,
                          block["sub_label"],
                          vision_labels=vision_labels,
                          no_mask_labels=no_mask_labels,
                          median_width=median_width,
                          title_labels=title_labels,
                          title_text=block["title_text"],
                          sub_title=block["sub_title"],
                          min_distance_config=min_distance_config,
                          tolerance_len=10,
                      )
                  elif distance_type == "title_text":
                      # Favour titles whose attached title_text overlaps this block.
                      if (
                          match_block["label"] in title_labels + ["abstract"]
                          and match_block["title_text"] != []
                      ):
                          iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
                              bbox,
                              match_block["title_text"][0][1],
                          )
                          iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
                              bbox,
                              match_block["title_text"][-1][1],
                          )
                          iou = 1 - max(iou_left_up, iou_right_down)
                          distance = _manhattan_distance(bbox, match_bbox) * iou
                      else:
                          distance = float("inf")
                  elif distance_type == "manhattan":
                      distance = _manhattan_distance(bbox, match_bbox)
                  elif distance_type == "vision_footnote":
                      # Favour vision blocks whose recorded footnotes overlap this block.
                      if (
                          match_block["label"] in vision_labels
                          and match_block["vision_footnote"] != []
                      ):
                          iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
                              bbox,
                              match_block["vision_footnote"][0],
                          )
                          iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
                              bbox,
                              match_block["vision_footnote"][-1],
                          )
                          iou = 1 - max(iou_left_up, iou_right_down)
                          distance = _manhattan_distance(bbox, match_bbox) * iou
                      else:
                          distance = float("inf")
                  elif distance_type == "vision_body":
                      if (
                          match_block["label"] in vision_title_labels
                          and block["vision_footnote"] != []
                      ):
                          iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
                              match_bbox,
                              block["vision_footnote"][0],
                          )
                          iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
                              match_bbox,
                              block["vision_footnote"][-1],
                          )
                          iou = 1 - max(iou_left_up, iou_right_down)
                          distance = _manhattan_distance(bbox, match_bbox) * iou
                      else:
                          distance = float("inf")
                  else:
                      raise NotImplementedError(
                          f"Unsupported distance_type: {distance_type!r}"
                      )
                  if distance < min_distance:
                      min_distance = distance
                      if is_add_index:
                          nearest_gt_index = match_block.get("index", 999)
                      else:
                          nearest_gt_index = match_block.get("sub_index", 999)
              if is_add_index:
                  block["index"] = nearest_gt_index
              else:
                  block["sub_index"] = nearest_gt_index
              parsing_result.append(block)

      # double text label
      double_text_blocks.sort(
          key=lambda x: (
              x["layout_bbox"][1] // 10,
              x["layout_bbox"][0] // median_width,
              x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
          ),
      )
      nearest_match_(
          double_text_blocks,
          distance_type="nearest_iou_edge_distance",
      )
      parsing_result.sort(
          key=lambda x: (x["index"], x["layout_bbox"][1], x["layout_bbox"][0]),
      )
      for idx, block in enumerate(parsing_result):
          block["index"] = idx + 1
          block["sub_index"] = idx + 1
      # title label
      title_blocks.sort(
          key=lambda x: (
              x["layout_bbox"][1] // 10,
              x["layout_bbox"][0] // median_width,
              x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
          ),
      )
      nearest_match_(title_blocks, distance_type="nearest_iou_edge_distance")
      if doc_flag:
          # Force the top-most doc_title to index 1; ties on index are then
          # broken by label priority so that it sorts first.
          text_sort_labels = ["doc_title"]
          text_label_priority = {
              label: priority for priority, label in enumerate(text_sort_labels)
          }
          doc_titles = []
          for i, block in enumerate(parsing_result):
              if block["label"] == "doc_title":
                  doc_titles.append(
                      (i, block["layout_bbox"][1], block["layout_bbox"][0]),
                  )
          doc_titles.sort(key=lambda x: (x[1], x[2]))
          first_doc_title_index = doc_titles[0][0]
          parsing_result[first_doc_title_index]["index"] = 1
          parsing_result.sort(
              key=lambda x: (
                  x["index"],
                  text_label_priority.get(x["label"], 9999),
                  x["layout_bbox"][1],
                  x["layout_bbox"][0],
              ),
          )
      else:
          parsing_result.sort(
              key=lambda x: (
                  x["index"],
                  x["layout_bbox"][1],
                  x["layout_bbox"][0],
              ),
          )
      for idx, block in enumerate(parsing_result):
          block["index"] = idx + 1
          block["sub_index"] = idx + 1
      # title-text label
      nearest_match_(title_text_blocks, distance_type="title_text")
      text_sort_labels = ["doc_title", "paragraph_title", "title_text"]
      text_label_priority = {
          label: priority for priority, label in enumerate(text_sort_labels)
      }
      parsing_result.sort(
          key=lambda x: (
              x["index"],
              text_label_priority.get(x["sub_label"], 9999),
              x["layout_bbox"][1],
              x["layout_bbox"][0],
          ),
      )
      for idx, block in enumerate(parsing_result):
          block["index"] = idx + 1
          block["sub_index"] = idx + 1
      # vision title labels (table/chart/figure titles): only sub_index from here on
      nearest_match_(
          vision_title_blocks,
          distance_type="nearest_iou_edge_distance",
          is_add_index=False,
      )
      parsing_result.sort(
          key=lambda x: (
              x["sub_index"],
              x["layout_bbox"][1],
              x["layout_bbox"][0],
          ),
      )
      for idx, block in enumerate(parsing_result):
          block["sub_index"] = idx + 1
      # image, table, seal, chart, figure labels
      nearest_match_(
          vision_blocks,
          distance_type="nearest_iou_edge_distance",
          is_add_index=False,
      )
      parsing_result.sort(
          key=lambda x: (
              x["sub_index"],
              x["layout_bbox"][1],
              x["layout_bbox"][0],
          ),
      )
      for idx, block in enumerate(parsing_result):
          block["sub_index"] = idx + 1
      # vision footnote label: footnotes sort after blocks sharing their sub_index
      nearest_match_(
          vision_footnote_blocks,
          distance_type="vision_footnote",
          is_add_index=False,
      )
      text_label_priority = {"vision_footnote": 9999}
      parsing_result.sort(
          key=lambda x: (
              x["sub_index"],
              text_label_priority.get(x["sub_label"], 0),
              x["layout_bbox"][1],
              x["layout_bbox"][0],
          ),
      )
      for idx, block in enumerate(parsing_result):
          block["sub_index"] = idx + 1
      # header, footnote, header_image, ... labels (not renumbered afterwards)
      nearest_match_(other_blocks, distance_type="manhattan", is_add_index=False)
      return data
  def _generate_input_data(parsing_result):
      """
      Build evaluation input data from layout parsing results.

      Each region's ``sub_blocks`` list is ordered in place by
      ``get_layout_ordering``. The integer boxes and reading-order indices of
      the sub-blocks are then collected.

      :param parsing_result: List of region dicts, each with ``block_bbox`` and
          ``sub_blocks`` (a list of dicts with ``layout_bbox``).
      :return: One dict per region, in input order:
          ``{"block_bbox": ..., "sub_indices": [...], "sub_bboxes": [...]}``.
      """
      input_data = [
          {
              "block_bbox": block["block_bbox"],
              "sub_indices": [],
              "sub_bboxes": [],
          }
          for block in parsing_result
      ]
      for block_index, block in enumerate(parsing_result):
          # get_layout_ordering takes the region dict itself as ``data`` and
          # reorders ``data["sub_blocks"]``; it has no ``block_index`` parameter.
          get_layout_ordering(
              block,
              no_mask_labels=[
                  "text",
                  "formula",
                  "algorithm",
                  "reference",
                  "content",
                  "abstract",
              ],
          )
          sub_blocks = block["sub_blocks"]
          for sub_block in sub_blocks:
              input_data[block_index]["sub_bboxes"].append(
                  list(map(int, sub_block["layout_bbox"])),
              )
              # NOTE(review): get_layout_ordering places vision and "other"
              # blocks with is_add_index=False, which sets only "sub_index";
              # confirm every sub-block reaching here has an "index" key.
              input_data[block_index]["sub_indices"].append(
                  int(sub_block["index"]),
              )
      return input_data
  1239. def _manhattan_distance(point1, point2, weight_x=1, weight_y=1):
  1240. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
  1241. def _calculate_horizontal_distance(
  1242. input_bbox,
  1243. match_bbox,
  1244. height,
  1245. disperse,
  1246. title_text,
  1247. ):
  1248. """
  1249. Calculate the horizontal distance between two bounding boxes, considering title text adjustments.
  1250. Args:
  1251. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1252. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1253. height (int): The height of the input bounding box used for normalization.
  1254. disperse (int): The dispersion factor used to normalize the horizontal distance.
  1255. title_text (list): A list of tuples containing title text information and their bounding box coordinates.
  1256. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1257. Returns:
  1258. float: The calculated horizontal distance taking into account the title text adjustments.
  1259. """
  1260. x1, y1, x2, y2 = input_bbox
  1261. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1262. if y2 < y1_prime:
  1263. if title_text and title_text[-1][0] == 2:
  1264. y2 += title_text[-1][1][3] - title_text[-1][1][1]
  1265. distance1 = (y1_prime - y2) * 0.5
  1266. else:
  1267. if title_text and title_text[0][0] == 1:
  1268. y1 -= title_text[0][1][3] - title_text[0][1][1]
  1269. distance1 = y1 - y2_prime
  1270. return (
  1271. abs(x2_prime - x1) // disperse + distance1 // height + distance1 / 5000
  1272. ) # if page max size == 5000
  1273. def _calculate_vertical_distance(input_bbox, match_bbox, width, disperse, title_text):
  1274. """
  1275. Calculate the vertical distance between two bounding boxes, considering title text adjustments.
  1276. Args:
  1277. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1278. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1279. width (int): The width of the input bounding box used for normalization.
  1280. disperse (int): The dispersion factor used to normalize the vertical distance.
  1281. title_text (list): A list of tuples containing title text information and their bounding box coordinates.
  1282. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1283. Returns:
  1284. float: The calculated vertical distance taking into account the title text adjustments.
  1285. """
  1286. x1, y1, x2, y2 = input_bbox
  1287. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1288. if x1 > x2_prime:
  1289. if title_text and title_text[0][0] == 3:
  1290. x1 -= title_text[0][1][2] - title_text[0][1][0]
  1291. distance2 = (x1 - x2_prime) * 0.5
  1292. else:
  1293. if title_text and title_text[-1][0] == 4:
  1294. x2 += title_text[-1][1][2] - title_text[-1][1][0]
  1295. distance2 = x1_prime - x2
  1296. return abs(y2_prime - y1) // disperse + distance2 // width + distance2 / 5000
  def _nearest_edge_distance(
      input_bbox,
      match_bbox,
      weight=[1, 1, 1, 1],
      label="text",
      no_mask_labels=[],
      min_edge_distances_config=[],
      tolerance_len=10,
  ):
      """
      Weighted distance between the nearest edges of two bounding boxes.

      Args:
          input_bbox (list): [x1, y1, x2, y2] of the input object.
          match_bbox (list): [x1', y1', x2', y2'] of the candidate object.
          weight (list, optional): Multiplier for the gap in each case, indexed
              by where ``input_bbox`` lies relative to ``match_bbox``:
              [0] input left of match, [1] input right of match,
              [2] input above match, [3] input below match.
              Defaults to [1, 1, 1, 1].
          label (str, optional): Label of the input object. Defaults to 'text'.
          no_mask_labels (list, optional): For these labels, overlapping boxes do
              not short-circuit to distance 0. In addition, a gap where the input
              lies above the match is penalised as ``max(0.1, d) * 100``.
              Defaults to an empty list.
          min_edge_distances_config (list, optional): [x_gap, y_gap] of the best
              match so far. A gap within ``tolerance_len`` of it is replaced by
              that value before weighting. An empty list means [inf, inf].
          tolerance_len (int, optional): Tolerance for that replacement.
              Defaults to 10.

      Returns:
          tuple: ``(distance, [x_gap, y_gap])``.
          - ``distance`` is the sum of the weighted x and y gaps when the boxes
            are separated on both axes, and the smaller of the two otherwise.
          - The list holds the raw per-axis gaps (unweighted, not replaced).
            A gap is 0 when the boxes overlap and inf when there is no gap on
            that axis.
          Overlapping boxes whose label is not in ``no_mask_labels`` return
          ``(0, [0, 0])``.
      """
      match_bbox_iou = _calculate_overlap_area_2_minbox_area_ratio(
          input_bbox,
          match_bbox,
      )
      # Overlapping boxes count as distance 0, except for text-like labels.
      if match_bbox_iou > 0 and label not in no_mask_labels:
          return 0, [0, 0]
      if not min_edge_distances_config:
          min_edge_distances_config = [float("inf"), float("inf")]
      min_edge_distance_x, min_edge_distance_y = min_edge_distances_config
      x1, y1, x2, y2 = input_bbox
      x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
      # Number of axes on which the boxes are strictly separated (0, 1 or 2).
      direction_num = 0
      distance_x = float("inf")
      distance_y = float("inf")
      # Raw gaps: [input left of match, input right of match,
      #            input above match, input below match].
      distance = [float("inf")] * 4
      # input_bbox is to the left of match_bbox
      if x2 < x1_prime:
          direction_num += 1
          distance[0] = x1_prime - x2
          # Replace by the best x-gap so far when within tolerance.
          if abs(distance[0] - min_edge_distance_x) <= tolerance_len:
              distance_x = min_edge_distance_x * weight[0]
          else:
              distance_x = distance[0] * weight[0]
      # input_bbox is to the right of match_bbox
      elif x1 > x2_prime:
          direction_num += 1
          distance[1] = x1 - x2_prime
          if abs(distance[1] - min_edge_distance_x) <= tolerance_len:
              distance_x = min_edge_distance_x * weight[1]
          else:
              distance_x = distance[1] * weight[1]
      elif match_bbox_iou > 0:
          # Overlapping boxes; only reachable for labels in no_mask_labels.
          distance[0] = 0
          distance_x = 0
      # input_bbox is above match_bbox
      if y2 < y1_prime:
          direction_num += 1
          distance[2] = y1_prime - y2
          if abs(distance[2] - min_edge_distance_y) <= tolerance_len:
              distance_y = min_edge_distance_y * weight[2]
          else:
              distance_y = distance[2] * weight[2]
          if label in no_mask_labels:
              # Heavily penalise text-like inputs lying above the match.
              distance_y = max(0.1, distance_y) * 100
      # input_bbox is below match_bbox
      elif y1 > y2_prime:
          direction_num += 1
          distance[3] = y1 - y2_prime
          if abs(distance[3] - min_edge_distance_y) <= tolerance_len:
              distance_y = min_edge_distance_y * weight[3]
          else:
              distance_y = distance[3] * weight[3]
      elif match_bbox_iou > 0:
          distance[2] = 0
          distance_y = 0
      if direction_num == 2:
          # Diagonal neighbour: combine both gaps.
          return (distance_x + distance_y), [
              min(distance[0], distance[1]),
              min(distance[2], distance[3]),
          ]
      else:
          # NOTE(review): boxes that only share an edge (no strict gap on either
          # axis, and assuming the overlap ratio is 0 for them) leave both gaps
          # at inf, so the result is inf. Confirm this is intended.
          return min(distance_x, distance_y), [
              min(distance[0], distance[1]),
              min(distance[2], distance[3]),
          ]
  1386. def _get_weights(label, horizontal):
  1387. """Define weights based on the label and orientation."""
  1388. if label == "doc_title":
  1389. return (
  1390. [1, 0.1, 0.1, 1] if horizontal else [0.2, 0.1, 1, 1]
  1391. ) # left-down , right-left
  1392. elif label in [
  1393. "paragraph_title",
  1394. "abstract",
  1395. "figure_title",
  1396. "chart_title",
  1397. "image",
  1398. "seal",
  1399. "chart",
  1400. "figure",
  1401. ]:
  1402. return [1, 1, 0.1, 1] # down
  1403. else:
  1404. return [1, 1, 1, 0.1] # up
  1405. def _nearest_iou_edge_distance(
  1406. input_bbox,
  1407. match_bbox,
  1408. label,
  1409. vision_labels,
  1410. no_mask_labels,
  1411. median_width=-1,
  1412. title_labels=[],
  1413. title_text=[],
  1414. sub_title=[],
  1415. min_distance_config=[],
  1416. tolerance_len=10,
  1417. ):
  1418. """
  1419. Calculate the nearest IOU edge distance between two bounding boxes.
  1420. Args:
  1421. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1422. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1423. label (str): The label/type of the object in the bounding box (e.g., 'image', 'text', etc.).
  1424. no_mask_labels (list): Labels for which no masking is applied when calculating edge distances.
  1425. median_width (int, optional): The median width for title dispersion calculation. Defaults to -1.
  1426. title_labels (list, optional): Labels that indicate the object is a title. Defaults to an empty list.
  1427. title_text (list, optional): Text content associated with title labels. Defaults to an empty list.
  1428. sub_title (list, optional): List of subtitle bounding boxes to adjust the input_bbox. Defaults to an empty list.
  1429. min_distance_config (list, optional): Configuration for minimum distances [min_edge_distances_config, up_edge_distances_config, total_distance].
  1430. Returns:
  1431. tuple: A tuple containing the calculated distance and updated minimum distance configuration.
  1432. """
  1433. x1, y1, x2, y2 = input_bbox
  1434. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1435. min_edge_distances_config, up_edge_distances_config, total_distance = (
  1436. min_distance_config
  1437. )
  1438. iou_distance = 0
  1439. if label in vision_labels:
  1440. horizontal1 = horizontal2 = True
  1441. else:
  1442. horizontal1 = _get_bbox_direction(input_bbox)
  1443. horizontal2 = _get_bbox_direction(match_bbox, 3)
  1444. if (
  1445. horizontal1 != horizontal2
  1446. or _get_projection_iou(input_bbox, match_bbox, horizontal1) < 0.01
  1447. ):
  1448. iou_distance = 1
  1449. elif label == "doc_title" or (label in title_labels and title_text):
  1450. # Calculate distance for titles
  1451. disperse = max(1, median_width)
  1452. width = x2 - x1
  1453. height = y2 - y1
  1454. if horizontal1:
  1455. return (
  1456. _calculate_horizontal_distance(
  1457. input_bbox,
  1458. match_bbox,
  1459. height,
  1460. disperse,
  1461. title_text,
  1462. ),
  1463. min_distance_config,
  1464. )
  1465. else:
  1466. return (
  1467. _calculate_vertical_distance(
  1468. input_bbox,
  1469. match_bbox,
  1470. width,
  1471. disperse,
  1472. title_text,
  1473. ),
  1474. min_distance_config,
  1475. )
  1476. # Adjust input_bbox based on sub_title
  1477. if sub_title:
  1478. for sub in sub_title:
  1479. x1_, y1_, x2_, y2_ = sub
  1480. x1, y1, x2, y2 = (
  1481. min(x1, x1_),
  1482. min(
  1483. y1,
  1484. y1_,
  1485. ),
  1486. max(x2, x2_),
  1487. max(y2, y2_),
  1488. )
  1489. input_bbox = [x1, y1, x2, y2]
  1490. # Calculate edge distance
  1491. weight = _get_weights(label, horizontal1)
  1492. if label == "abstract":
  1493. tolerance_len *= 3
  1494. edge_distance, edge_distance_config = _nearest_edge_distance(
  1495. input_bbox,
  1496. match_bbox,
  1497. weight,
  1498. label=label,
  1499. no_mask_labels=no_mask_labels,
  1500. min_edge_distances_config=min_edge_distances_config,
  1501. tolerance_len=tolerance_len,
  1502. )
  1503. # Weights for combining distances
  1504. iou_edge_weight = [10**6, 10**3, 1, 0.001]
  1505. # Calculate up and left edge distances
  1506. up_edge_distance = y1_prime
  1507. left_edge_distance = x1_prime
  1508. if (
  1509. label in no_mask_labels or label == "paragraph_title" or label in vision_labels
  1510. ) and y1 > y2_prime:
  1511. up_edge_distance = -y2_prime
  1512. left_edge_distance = -x2_prime
  1513. min_up_edge_distance = up_edge_distances_config
  1514. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  1515. up_edge_distance = min_up_edge_distance
  1516. # Calculate total distance
  1517. distance = (
  1518. iou_distance * iou_edge_weight[0]
  1519. + edge_distance * iou_edge_weight[1]
  1520. + up_edge_distance * iou_edge_weight[2]
  1521. + left_edge_distance * iou_edge_weight[3]
  1522. )
  1523. # Update minimum distance configuration if a smaller distance is found
  1524. if total_distance > distance:
  1525. edge_distance_config = [
  1526. min(min_edge_distances_config[0], edge_distance_config[0]),
  1527. min(min_edge_distances_config[1], edge_distance_config[1]),
  1528. ]
  1529. min_distance_config = [
  1530. edge_distance_config,
  1531. min(up_edge_distance, up_edge_distances_config),
  1532. distance,
  1533. ]
  1534. return distance, min_distance_config
  1535. def get_show_color(label):
  1536. label_colors = {
  1537. # Medium Blue (from 'titles_list')
  1538. "paragraph_title": (102, 102, 255, 100),
  1539. "doc_title": (255, 248, 220, 100), # Cornsilk
  1540. # Light Yellow (from 'tables_caption_list')
  1541. "table_title": (255, 255, 102, 100),
  1542. # Sky Blue (from 'imgs_caption_list')
  1543. "figure_title": (102, 178, 255, 100),
  1544. "chart_title": (221, 160, 221, 100), # Plum
  1545. "vision_footnote": (144, 238, 144, 100), # Light Green
  1546. # Deep Purple (from 'texts_list')
  1547. "text": (153, 0, 76, 100),
  1548. # Bright Green (from 'interequations_list')
  1549. "formula": (0, 255, 0, 100),
  1550. "abstract": (255, 239, 213, 100), # Papaya Whip
  1551. # Medium Green (from 'lists_list' and 'indexs_list')
  1552. "content": (40, 169, 92, 100),
  1553. # Neutral Gray (from 'dropped_bbox_list')
  1554. "seal": (158, 158, 158, 100),
  1555. # Olive Yellow (from 'tables_body_list')
  1556. "table": (204, 204, 0, 100),
  1557. # Bright Green (from 'imgs_body_list')
  1558. "image": (153, 255, 51, 100),
  1559. # Bright Green (from 'imgs_body_list')
  1560. "figure": (153, 255, 51, 100),
  1561. "chart": (216, 191, 216, 100), # Thistle
  1562. # Pale Yellow-Green (from 'tables_footnote_list')
  1563. "reference": (229, 255, 204, 100),
  1564. "algorithm": (255, 250, 240, 100), # Floral White
  1565. }
  1566. default_color = (158, 158, 158, 100)
  1567. return label_colors.get(label, default_color)