utils.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Public names exported by ``from <this module> import *``.
# NOTE(review): get_layout_ordering and get_show_color are not defined in the
# visible part of this file — presumably defined further down; verify.
__all__ = [
    "get_sub_regions_ocr_res",
    "get_layout_ordering",
    "recursive_img_array2path",
    "get_show_color",
]
  20. import numpy as np
  21. import copy
  22. import cv2
  23. import uuid
  24. from pathlib import Path
  25. from ..ocr.result import OCRResult
  26. from ...models_new.object_detection.result import DetResult
  27. def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> list:
  28. """
  29. Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
  30. Args:
  31. src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
  32. ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
  33. Returns:
  34. list: A list of indices of source boxes that overlap with any reference box.
  35. """
  36. match_idx_list = []
  37. src_boxes_num = len(src_boxes)
  38. if src_boxes_num > 0 and len(ref_boxes) > 0:
  39. for rno in range(len(ref_boxes)):
  40. ref_box = ref_boxes[rno]
  41. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  42. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  43. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  44. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  45. pub_w = x2 - x1
  46. pub_h = y2 - y1
  47. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  48. match_idx_list.extend(match_idx)
  49. return match_idx_list
  50. def get_sub_regions_ocr_res(
  51. overall_ocr_res: OCRResult, object_boxes: list, flag_within: bool = True
  52. ) -> OCRResult:
  53. """
  54. Filters OCR results to only include text boxes within specified object boxes based on a flag.
  55. Args:
  56. overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
  57. object_boxes (list): A list of bounding boxes for the objects of interest.
  58. flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
  59. Returns:
  60. OCRResult: A filtered OCR result containing only the relevant text boxes.
  61. """
  62. sub_regions_ocr_res = {}
  63. sub_regions_ocr_res["rec_polys"] = []
  64. sub_regions_ocr_res["rec_texts"] = []
  65. sub_regions_ocr_res["rec_scores"] = []
  66. sub_regions_ocr_res["rec_boxes"] = []
  67. overall_text_boxes = overall_ocr_res["rec_boxes"]
  68. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  69. match_idx_list = list(set(match_idx_list))
  70. for box_no in range(len(overall_text_boxes)):
  71. if flag_within:
  72. if box_no in match_idx_list:
  73. flag_match = True
  74. else:
  75. flag_match = False
  76. else:
  77. if box_no not in match_idx_list:
  78. flag_match = True
  79. else:
  80. flag_match = False
  81. if flag_match:
  82. sub_regions_ocr_res["rec_polys"].append(
  83. overall_ocr_res["rec_polys"][box_no]
  84. )
  85. sub_regions_ocr_res["rec_texts"].append(
  86. overall_ocr_res["rec_texts"][box_no]
  87. )
  88. sub_regions_ocr_res["rec_scores"].append(
  89. overall_ocr_res["rec_scores"][box_no]
  90. )
  91. sub_regions_ocr_res["rec_boxes"].append(
  92. overall_ocr_res["rec_boxes"][box_no]
  93. )
  94. return sub_regions_ocr_res
  95. def calculate_iou(box1, box2):
  96. """
  97. Calculate Intersection over Union (IoU) between two bounding boxes.
  98. Args:
  99. box1, box2: Lists or tuples representing bounding boxes [x_min, y_min, x_max, y_max].
  100. Returns:
  101. float: The IoU value.
  102. """
  103. box1 = list(map(int, box1))
  104. box2 = list(map(int, box2))
  105. x1_min, y1_min, x1_max, y1_max = box1
  106. x2_min, y2_min, x2_max, y2_max = box2
  107. inter_x_min = max(x1_min, x2_min)
  108. inter_y_min = max(y1_min, y2_min)
  109. inter_x_max = min(x1_max, x2_max)
  110. inter_y_max = min(y1_max, y2_max)
  111. if inter_x_max <= inter_x_min or inter_y_max <= inter_y_min:
  112. return 0.0
  113. inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min)
  114. box1_area = (x1_max - x1_min) * (y1_max - y1_min)
  115. box2_area = (x2_max - x2_min) * (y2_max - y2_min)
  116. min_area = min(box1_area, box2_area)
  117. if min_area <= 0:
  118. return 0.0
  119. iou = inter_area / min_area
  120. return iou
  121. def _whether_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.6):
  122. _, y0_1, _, y1_1 = bbox1
  123. _, y0_2, _, y1_2 = bbox2
  124. overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
  125. min_height = min(y1_1 - y0_1, y1_2 - y0_2)
  126. return (overlap / min_height) > overlap_ratio_threshold
  127. def _sort_box_by_y_projection(layout_bbox, ocr_res, line_height_threshold=0.7):
  128. assert ocr_res["boxes"] and ocr_res["rec_texts"]
  129. # span->line->block
  130. boxes = ocr_res["boxes"]
  131. rec_text = ocr_res["rec_texts"]
  132. x_min, x_max = layout_bbox[0], layout_bbox[2]
  133. spans = list(zip(boxes, rec_text))
  134. spans.sort(key=lambda span: span[0][1])
  135. spans = [list(span) for span in spans]
  136. lines = []
  137. first_span = spans[0]
  138. current_line = [first_span]
  139. current_y0, current_y1 = first_span[0][1], first_span[0][3]
  140. for span in spans[1:]:
  141. y0, y1 = span[0][1], span[0][3]
  142. if _whether_overlaps_y_exceeds_threshold(
  143. (0, current_y0, 0, current_y1),
  144. (0, y0, 0, y1),
  145. line_height_threshold,
  146. ):
  147. current_line.append(span)
  148. current_y0 = min(current_y0, y0)
  149. current_y1 = max(current_y1, y1)
  150. else:
  151. lines.append(current_line)
  152. current_line = [span]
  153. current_y0, current_y1 = y0, y1
  154. if current_line:
  155. lines.append(current_line)
  156. for line in lines:
  157. line.sort(key=lambda span: span[0][0])
  158. first_span = line[0]
  159. end_span = line[-1]
  160. if first_span[0][0] - x_min > 20:
  161. first_span[1] = "\n " + first_span[1]
  162. if x_max - end_span[0][2] > 20:
  163. end_span[1] = end_span[1] + "\n"
  164. ocr_res["boxes"] = [span[0] for line in lines for span in line]
  165. ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
  166. return ocr_res
  167. def get_structure_res(
  168. overall_ocr_res: OCRResult,
  169. layout_det_res: DetResult,
  170. table_res_list,
  171. ) -> OCRResult:
  172. """
  173. Extract structured information from OCR and layout detection results.
  174. Args:
  175. overall_ocr_res (OCRResult): An object containing the overall OCR results, including detected text boxes and recognized text. The structure is expected to have:
  176. - "input_img": The image on which OCR was performed.
  177. - "dt_boxes": A list of detected text box coordinates.
  178. - "rec_texts": A list of recognized text corresponding to the detected boxes.
  179. layout_det_res (DetResult): An object containing the layout detection results, including detected layout boxes and their labels. The structure is expected to have:
  180. - "boxes": A list of dictionaries with keys "coordinate" for box coordinates and "label" for the type of content.
  181. table_res_list (list): A list of table detection results, where each item is a dictionary containing:
  182. - "layout_bbox": The bounding box of the table layout.
  183. - "pred_html": The predicted HTML representation of the table.
  184. Returns:
  185. list: A list of structured boxes where each item is a dictionary containing:
  186. - "label": The label of the content (e.g., 'table', 'chart', 'image').
  187. - The label as a key with either table HTML or image data and text.
  188. - "layout_bbox": The coordinates of the layout box.
  189. """
  190. structure_boxes = []
  191. input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
  192. for box_info in layout_det_res["boxes"]:
  193. layout_bbox = box_info["coordinate"]
  194. label = box_info["label"]
  195. rec_res = {"boxes": [], "rec_texts": [], "flag": False}
  196. drop_index = []
  197. seg_start_flag = True
  198. seg_end_flag = True
  199. if label == "table":
  200. for i, table_res in enumerate(table_res_list):
  201. if calculate_iou(layout_bbox, table_res["layout_bbox"]) > 0.5:
  202. structure_boxes.append(
  203. {
  204. "label": label,
  205. f"{label}": table_res["pred_html"],
  206. "layout_bbox": layout_bbox,
  207. "seg_start_flag": seg_start_flag,
  208. "seg_end_flag": seg_end_flag,
  209. },
  210. )
  211. del table_res_list[i]
  212. break
  213. else:
  214. overall_text_boxes = overall_ocr_res["rec_boxes"]
  215. for box_no in range(len(overall_text_boxes)):
  216. if calculate_iou(layout_bbox, overall_text_boxes[box_no]) > 0.5:
  217. rec_res["boxes"].append(overall_text_boxes[box_no])
  218. rec_res["rec_texts"].append(
  219. overall_ocr_res["rec_texts"][box_no],
  220. )
  221. rec_res["flag"] = True
  222. drop_index.append(box_no)
  223. if rec_res["flag"]:
  224. rec_res = _sort_box_by_y_projection(layout_bbox, rec_res, 0.7)
  225. rec_res_first_bbox = rec_res["boxes"][0]
  226. rec_res_end_bbox = rec_res["boxes"][-1]
  227. if rec_res_first_bbox[0] - layout_bbox[0] < 20:
  228. seg_start_flag = False
  229. if layout_bbox[2] - rec_res_end_bbox[2] < 20:
  230. seg_end_flag = False
  231. if label in ["chart", "image"]:
  232. structure_boxes.append(
  233. {
  234. "label": label,
  235. f"{label}": {
  236. "img": input_img[
  237. int(layout_bbox[1]) : int(layout_bbox[3]),
  238. int(layout_bbox[0]) : int(layout_bbox[2]),
  239. ],
  240. },
  241. "layout_bbox": layout_bbox,
  242. "seg_start_flag": seg_start_flag,
  243. "seg_end_flag": seg_end_flag,
  244. },
  245. )
  246. else:
  247. structure_boxes.append(
  248. {
  249. "label": label,
  250. f"{label}": "".join(rec_res["rec_texts"]),
  251. "layout_bbox": layout_bbox,
  252. "seg_start_flag": seg_start_flag,
  253. "seg_end_flag": seg_end_flag,
  254. },
  255. )
  256. return structure_boxes
  257. def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
  258. """
  259. Generate a 1D projection histogram from bounding boxes along a specified axis.
  260. Args:
  261. boxes: A (N, 4) array of bounding boxes defined by [x_min, y_min, x_max, y_max].
  262. axis: Axis for projection; 0 for horizontal (x-axis), 1 for vertical (y-axis).
  263. Returns:
  264. A 1D numpy array representing the projection histogram based on bounding box intervals.
  265. """
  266. assert axis in [0, 1]
  267. max_length = np.max(boxes[:, axis::2])
  268. projection = np.zeros(max_length, dtype=int)
  269. # Increment projection histogram over the interval defined by each bounding box
  270. for start, end in boxes[:, axis::2]:
  271. projection[start:end] += 1
  272. return projection
  273. def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float):
  274. """
  275. Split the projection profile into segments based on specified thresholds.
  276. Args:
  277. arr_values: 1D array representing the projection profile.
  278. min_value: Minimum value threshold to consider a profile segment significant.
  279. min_gap: Minimum gap width to consider a separation between segments.
  280. Returns:
  281. A tuple of start and end indices for each segment that meets the criteria.
  282. """
  283. # Identify indices where the projection exceeds the minimum value
  284. significant_indices = np.where(arr_values > min_value)[0]
  285. if not len(significant_indices):
  286. return
  287. # Calculate gaps between significant indices
  288. index_diffs = significant_indices[1:] - significant_indices[:-1]
  289. gap_indices = np.where(index_diffs > min_gap)[0]
  290. # Determine start and end indices of segments
  291. segment_starts = np.insert(
  292. significant_indices[gap_indices + 1],
  293. 0,
  294. significant_indices[0],
  295. )
  296. segment_ends = np.append(
  297. significant_indices[gap_indices],
  298. significant_indices[-1] + 1,
  299. )
  300. return segment_starts, segment_ends
  301. def recursive_yx_cut(boxes: np.ndarray, indices: list[int], res: list[int], min_gap=1):
  302. """
  303. Recursively project and segment bounding boxes, starting with Y-axis and followed by X-axis.
  304. Args:
  305. boxes: A (N, 4) array representing bounding boxes.
  306. indices: List of indices indicating the original position of boxes.
  307. res: List to store indices of the final segmented bounding boxes.
  308. """
  309. assert len(boxes) == len(indices)
  310. # Sort by y_min for Y-axis projection
  311. y_sorted_indices = boxes[:, 1].argsort()
  312. y_sorted_boxes = boxes[y_sorted_indices]
  313. y_sorted_indices = np.array(indices)[y_sorted_indices]
  314. # Perform Y-axis projection
  315. y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
  316. y_intervals = split_projection_profile(y_projection, 0, 1)
  317. if not y_intervals:
  318. return
  319. # Process each segment defined by Y-axis projection
  320. for y_start, y_end in zip(*y_intervals):
  321. # Select boxes within the current y interval
  322. y_interval_indices = (y_start <= y_sorted_boxes[:, 1]) & (
  323. y_sorted_boxes[:, 1] < y_end
  324. )
  325. y_boxes_chunk = y_sorted_boxes[y_interval_indices]
  326. y_indices_chunk = y_sorted_indices[y_interval_indices]
  327. # Sort by x_min for X-axis projection
  328. x_sorted_indices = y_boxes_chunk[:, 0].argsort()
  329. x_sorted_boxes_chunk = y_boxes_chunk[x_sorted_indices]
  330. x_sorted_indices_chunk = y_indices_chunk[x_sorted_indices]
  331. # Perform X-axis projection
  332. x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
  333. x_intervals = split_projection_profile(x_projection, 0, min_gap)
  334. if not x_intervals:
  335. continue
  336. # If X-axis cannot be further segmented, add current indices to results
  337. if len(x_intervals[0]) == 1:
  338. res.extend(x_sorted_indices_chunk)
  339. continue
  340. # Recursively process each segment defined by X-axis projection
  341. for x_start, x_end in zip(*x_intervals):
  342. x_interval_indices = (x_start <= x_sorted_boxes_chunk[:, 0]) & (
  343. x_sorted_boxes_chunk[:, 0] < x_end
  344. )
  345. recursive_yx_cut(
  346. x_sorted_boxes_chunk[x_interval_indices],
  347. x_sorted_indices_chunk[x_interval_indices],
  348. res,
  349. )
  350. def recursive_xy_cut(boxes: np.ndarray, indices: list[int], res: list[int], min_gap=1):
  351. """
  352. Recursively performs X-axis projection followed by Y-axis projection to segment bounding boxes.
  353. Args:
  354. boxes: A (N, 4) array representing bounding boxes with [x_min, y_min, x_max, y_max].
  355. indices: A list of indices representing the position of boxes in the original data.
  356. res: A list to store indices of bounding boxes that meet the criteria.
  357. """
  358. # Ensure boxes and indices have the same length
  359. assert len(boxes) == len(indices)
  360. # Sort by x_min to prepare for X-axis projection
  361. x_sorted_indices = boxes[:, 0].argsort()
  362. x_sorted_boxes = boxes[x_sorted_indices]
  363. x_sorted_indices = np.array(indices)[x_sorted_indices]
  364. # Perform X-axis projection
  365. x_projection = projection_by_bboxes(boxes=x_sorted_boxes, axis=0)
  366. x_intervals = split_projection_profile(x_projection, 0, 1)
  367. if not x_intervals:
  368. return
  369. # Process each segment defined by X-axis projection
  370. for x_start, x_end in zip(*x_intervals):
  371. # Select boxes within the current x interval
  372. x_interval_indices = (x_start <= x_sorted_boxes[:, 0]) & (
  373. x_sorted_boxes[:, 0] < x_end
  374. )
  375. x_boxes_chunk = x_sorted_boxes[x_interval_indices]
  376. x_indices_chunk = x_sorted_indices[x_interval_indices]
  377. # Sort selected boxes by y_min to prepare for Y-axis projection
  378. y_sorted_indices = x_boxes_chunk[:, 1].argsort()
  379. y_sorted_boxes_chunk = x_boxes_chunk[y_sorted_indices]
  380. y_sorted_indices_chunk = x_indices_chunk[y_sorted_indices]
  381. # Perform Y-axis projection
  382. y_projection = projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1)
  383. y_intervals = split_projection_profile(y_projection, 0, min_gap)
  384. if not y_intervals:
  385. continue
  386. # If Y-axis cannot be further segmented, add current indices to results
  387. if len(y_intervals[0]) == 1:
  388. res.extend(y_sorted_indices_chunk)
  389. continue
  390. # Recursively process each segment defined by Y-axis projection
  391. for y_start, y_end in zip(*y_intervals):
  392. y_interval_indices = (y_start <= y_sorted_boxes_chunk[:, 1]) & (
  393. y_sorted_boxes_chunk[:, 1] < y_end
  394. )
  395. recursive_xy_cut(
  396. y_sorted_boxes_chunk[y_interval_indices],
  397. y_sorted_indices_chunk[y_interval_indices],
  398. res,
  399. )
  400. def sort_by_xycut(block_bboxes, direction=0, min_gap=1):
  401. block_bboxes = np.asarray(block_bboxes).astype(int)
  402. res = []
  403. if direction == 1:
  404. recursive_yx_cut(
  405. block_bboxes,
  406. np.arange(
  407. len(block_bboxes),
  408. ),
  409. res,
  410. min_gap,
  411. )
  412. else:
  413. recursive_xy_cut(
  414. block_bboxes,
  415. np.arange(
  416. len(block_bboxes),
  417. ),
  418. res,
  419. min_gap,
  420. )
  421. return res
  422. def _img_array2path(data, save_path):
  423. """
  424. Save an image array to disk and return the file path.
  425. Args:
  426. data (np.ndarray): An image represented as a numpy array.
  427. save_path (str or Path): The base path where images should be saved.
  428. Returns:
  429. str: The relative path of the saved image file.
  430. """
  431. if isinstance(data, np.ndarray) and data.ndim == 3:
  432. # Generate a unique filename using UUID
  433. img_name = f"image_{uuid.uuid4().hex}.png"
  434. img_path = Path(save_path) / "imgs" / img_name
  435. img_path.parent.mkdir(
  436. parents=True,
  437. exist_ok=True,
  438. ) # Ensure the directory exists
  439. cv2.imwrite(str(img_path), data)
  440. return f"imgs/{img_name}"
  441. else:
  442. return ValueError
  443. def recursive_img_array2path(data, save_path, labels=[]):
  444. """
  445. Process a dictionary or list to save image arrays to disk and replace them with file paths.
  446. Args:
  447. data (dict or list): The data structure that may contain image arrays.
  448. save_path (str or Path): The base path where images should be saved.
  449. """
  450. if isinstance(data, dict):
  451. for k, v in data.items():
  452. if k in labels and isinstance(v, np.ndarray) and v.ndim == 3:
  453. data[k] = _img_array2path(v, save_path)
  454. else:
  455. recursive_img_array2path(v, save_path, labels)
  456. elif isinstance(data, list):
  457. for item in data:
  458. recursive_img_array2path(item, save_path, labels)
  459. def _calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
  460. """
  461. Calculate the ratio of the overlap area between bbox1 and bbox2
  462. to the area of the smaller bounding box.
  463. Args:
  464. bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  465. bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  466. Returns:
  467. float: The ratio of the overlap area to the area of the smaller bounding box.
  468. """
  469. x_left = max(bbox1[0], bbox2[0])
  470. y_top = max(bbox1[1], bbox2[1])
  471. x_right = min(bbox1[2], bbox2[2])
  472. y_bottom = min(bbox1[3], bbox2[3])
  473. if x_right <= x_left or y_bottom <= y_top:
  474. return 0.0
  475. # Calculate the area of the overlap
  476. intersection_area = (x_right - x_left) * (y_bottom - y_top)
  477. # Calculate the areas of both bounding boxes
  478. area_bbox1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  479. area_bbox2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  480. # Determine the minimum non-zero box area
  481. min_box_area = min(area_bbox1, area_bbox2)
  482. # Avoid division by zero in case of zero-area boxes
  483. if min_box_area == 0:
  484. return 0.0
  485. return intersection_area / min_box_area
  486. def _get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio, smaller=True):
  487. """
  488. Determine if the overlap area between two bounding boxes exceeds a given ratio
  489. and return the smaller (or larger) bounding box based on the `smaller` flag.
  490. Args:
  491. bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  492. bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  493. ratio (float): The overlap ratio threshold.
  494. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  495. Returns:
  496. list or tuple: The selected bounding box or None if the overlap ratio is not exceeded.
  497. """
  498. # Calculate the areas of both bounding boxes
  499. area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  500. area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  501. # Calculate the overlap ratio using a helper function
  502. overlap_ratio = _calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
  503. # Check if the overlap ratio exceeds the threshold
  504. if overlap_ratio > ratio:
  505. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  506. return 1
  507. else:
  508. return 2
  509. return None
  510. def _remove_overlap_blocks(blocks, threshold=0.65, smaller=True):
  511. """
  512. Remove overlapping blocks based on a specified overlap ratio threshold.
  513. Args:
  514. blocks (list): List of block dictionaries, each containing a 'layout_bbox' key.
  515. threshold (float): Ratio threshold to determine significant overlap.
  516. smaller (bool): If True, the smaller block in overlap is removed.
  517. Returns:
  518. tuple: A tuple containing the updated list of blocks and a list of dropped blocks.
  519. """
  520. dropped_blocks = []
  521. dropped_indexes = []
  522. # Iterate over each pair of blocks to find overlaps
  523. for i in range(len(blocks)):
  524. block1 = blocks[i]
  525. for j in range(i + 1, len(blocks)):
  526. block2 = blocks[j]
  527. # Skip blocks that are already marked for removal
  528. if i in dropped_indexes or j in dropped_indexes:
  529. continue
  530. # Check for overlap and determine which block to remove
  531. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  532. block1["layout_bbox"],
  533. block2["layout_bbox"],
  534. threshold,
  535. smaller=smaller,
  536. )
  537. if overlap_box_index is not None:
  538. if overlap_box_index == 1:
  539. block_to_remove = block1
  540. drop_index = i
  541. else:
  542. block_to_remove = block2
  543. drop_index = j
  544. if drop_index not in dropped_indexes:
  545. dropped_indexes.append(drop_index)
  546. dropped_blocks.append(block_to_remove)
  547. dropped_indexes.sort()
  548. for i in reversed(dropped_indexes):
  549. del blocks[i]
  550. return blocks, dropped_blocks
  551. def _text_median_width(blocks):
  552. widths = [
  553. block["layout_bbox"][2] - block["layout_bbox"][0]
  554. for block in blocks
  555. if block["label"] in ["text"]
  556. ]
  557. return np.median(widths) if widths else float("inf")
  558. def _get_layout_property(blocks, median_width, no_mask_labels, threshold=0.8):
  559. """
  560. Determine the layout (single or double column) of text blocks.
  561. Args:
  562. blocks (list): List of block dictionaries containing 'label' and 'layout_bbox'.
  563. median_width (float): Median width of text blocks.
  564. threshold (float): Threshold for determining layout overlap.
  565. Returns:
  566. list: Updated list of blocks with layout information.
  567. """
  568. blocks.sort(
  569. key=lambda x: (
  570. x["layout_bbox"][0],
  571. (x["layout_bbox"][2] - x["layout_bbox"][0]),
  572. ),
  573. )
  574. check_single_layout = {}
  575. page_min_x, page_max_x = float("inf"), 0
  576. double_label_height = 0
  577. double_label_area = 0
  578. single_label_area = 0
  579. for i, block in enumerate(blocks):
  580. page_min_x = min(page_min_x, block["layout_bbox"][0])
  581. page_max_x = max(page_max_x, block["layout_bbox"][2])
  582. page_width = page_max_x - page_min_x
  583. for i, block in enumerate(blocks):
  584. if block["label"] not in no_mask_labels:
  585. continue
  586. x_min_i, _, x_max_i, _ = block["layout_bbox"]
  587. layout_length = x_max_i - x_min_i
  588. cover_count, cover_with_threshold_count = 0, 0
  589. match_block_with_threshold_indexes = []
  590. for j, other_block in enumerate(blocks):
  591. if i == j or other_block["label"] not in no_mask_labels:
  592. continue
  593. x_min_j, _, x_max_j, _ = other_block["layout_bbox"]
  594. x_match_min, x_match_max = max(
  595. x_min_i,
  596. x_min_j,
  597. ), min(x_max_i, x_max_j)
  598. match_block_iou = (x_match_max - x_match_min) / (x_max_j - x_min_j)
  599. if match_block_iou > 0:
  600. cover_count += 1
  601. if match_block_iou > threshold:
  602. cover_with_threshold_count += 1
  603. match_block_with_threshold_indexes.append(
  604. (j, match_block_iou),
  605. )
  606. x_min_i = x_match_max
  607. if x_min_i >= x_max_i:
  608. break
  609. if (
  610. layout_length > median_width * 1.3
  611. and (cover_with_threshold_count >= 2 or cover_count >= 2)
  612. ) or layout_length > 0.6 * page_width:
  613. # if layout_length > median_width * 1.3 and (cover_with_threshold_count >= 2):
  614. block["layout"] = "double"
  615. double_label_height += block["layout_bbox"][3] - block["layout_bbox"][1]
  616. double_label_area += (block["layout_bbox"][2] - block["layout_bbox"][0]) * (
  617. block["layout_bbox"][3] - block["layout_bbox"][1]
  618. )
  619. else:
  620. block["layout"] = "single"
  621. check_single_layout[i] = match_block_with_threshold_indexes
  622. # Check single-layout block
  623. for i, single_layout in check_single_layout.items():
  624. if single_layout:
  625. index, match_iou = single_layout[-1]
  626. if match_iou > 0.9 and blocks[index]["layout"] == "double":
  627. blocks[i]["layout"] = "double"
  628. double_label_height += (
  629. blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1]
  630. )
  631. double_label_area += (
  632. blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
  633. ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
  634. else:
  635. single_label_area += (
  636. blocks[i]["layout_bbox"][2] - blocks[i]["layout_bbox"][0]
  637. ) * (blocks[i]["layout_bbox"][3] - blocks[i]["layout_bbox"][1])
  638. return blocks, (double_label_area > single_label_area)
  639. def _get_bbox_direction(input_bbox, ratio=1):
  640. """
  641. Determine if a bounding box is horizontal or vertical.
  642. Args:
  643. input_bbox (list): Bounding box [x_min, y_min, x_max, y_max].
  644. ratio (float): Ratio for determining orientation.
  645. Returns:
  646. bool: True if horizontal, False if vertical.
  647. """
  648. return (input_bbox[2] - input_bbox[0]) * ratio >= (input_bbox[3] - input_bbox[1])
  649. def _get_projection_iou(input_bbox, match_bbox, is_horizontal=True):
  650. """
  651. Calculate the IoU of lines between two bounding boxes.
  652. Args:
  653. input_bbox (list): First bounding box [x_min, y_min, x_max, y_max].
  654. match_bbox (list): Second bounding box [x_min, y_min, x_max, y_max].
  655. is_horizontal (bool): Whether to compare horizontally or vertically.
  656. Returns:
  657. float: Line IoU.
  658. """
  659. if is_horizontal:
  660. x_match_min = max(input_bbox[0], match_bbox[0])
  661. x_match_max = min(input_bbox[2], match_bbox[2])
  662. return (x_match_max - x_match_min) / (input_bbox[2] - input_bbox[0])
  663. else:
  664. y_match_min = max(input_bbox[1], match_bbox[1])
  665. y_match_max = min(input_bbox[3], match_bbox[3])
  666. return (y_match_max - y_match_min) / (input_bbox[3] - input_bbox[1])
def _get_sub_category(blocks, title_labels):
    """
    Link neighbouring blocks to title, sub-title and vision anchor blocks.

    Every block whose label is in ``title_labels``, ``"paragraph_title"`` or
    one of the local vision labels acts as an anchor. For each anchor, the
    closest aligned block before it ("left-up": above, or to the left for
    vertical anchors) and after it ("right-down": below, or to the right) is
    searched. The chosen neighbour(s) may be re-labelled through
    ``"sub_label"`` (``"title_text"`` or ``"vision_footnote"``), and their
    bboxes are recorded on the anchor.

    Args:
        blocks (list): List of block dictionaries. Each needs ``"label"`` and
            ``"layout_bbox"`` ([x_min, y_min, x_max, y_max]).
        title_labels (list): Labels considered as titles that may own
            ``title_text`` blocks.

    Returns:
        list: The same ``blocks`` list, mutated in place. Every block gets
        ``"title_text"``, ``"sub_title"`` and ``"vision_footnote"`` lists and a
        ``"sub_label"`` (defaulting to ``"label"``) when these are missing.
    """
    sub_title_labels = ["paragraph_title"]
    # NOTE(review): get_layout_ordering also treats "seal" as a vision label;
    # it is absent here, so seals never own vision footnotes. Confirm intended.
    vision_labels = ["image", "table", "chart", "figure"]
    for i, block1 in enumerate(blocks):
        # Make sure every block carries the keys that later stages read.
        if block1.get("title_text") is None:
            block1["title_text"] = []
        if block1.get("sub_title") is None:
            block1["sub_title"] = []
        if block1.get("vision_footnote") is None:
            block1["vision_footnote"] = []
        if block1.get("sub_label") is None:
            block1["sub_label"] = block1["label"]
        # Only titles, sub-titles and vision blocks act as anchors.
        if (
            block1["label"] not in title_labels
            and block1["label"] not in sub_title_labels
            and block1["label"] not in vision_labels
        ):
            continue
        bbox1 = block1["layout_bbox"]
        x1, y1, x2, y2 = bbox1
        is_horizontal_1 = _get_bbox_direction(block1["layout_bbox"])
        # Best candidate found so far on each side (-1 = none found).
        left_up_title_text_distance = float("inf")
        left_up_title_text_index = -1
        left_up_title_text_direction = None
        right_down_title_text_distance = float("inf")
        right_down_title_text_index = -1
        right_down_title_text_direction = None
        for j, block2 in enumerate(blocks):
            if i == j:
                continue
            bbox2 = block2["layout_bbox"]
            x1_prime, y1_prime, x2_prime, y2_prime = bbox2
            is_horizontal_2 = _get_bbox_direction(bbox2)
            # Fraction of block2's projection covered by block1, measured
            # along block1's orientation.
            match_block_iou = _get_projection_iou(
                bbox2,
                bbox1,
                is_horizontal_1,
            )

            def distance_(is_horizontal, is_left_up):
                # Gap between block1 and block2, floor-divided into steps of 5
                # coordinate units, plus block2's position / 5000 as a small
                # tie-breaker between equal buckets.
                if is_horizontal:
                    if is_left_up:
                        return (y1 - y2_prime + 2) // 5 + x1_prime / 5000
                    else:
                        return (y1_prime - y2 + 2) // 5 + x1_prime / 5000
                else:
                    if is_left_up:
                        return (x1 - x2_prime + 2) // 5 + y1_prime / 5000
                    else:
                        return (x1_prime - x2 + 2) // 5 + y1_prime / 5000

            block_iou_threshold = 0.1
            if block1["label"] in sub_title_labels:
                # Sub-title anchors use _calculate_overlap_area_2_minbox_area_ratio
                # instead of the projection ratio, with a stricter threshold.
                match_block_iou = _calculate_overlap_area_2_minbox_area_ratio(
                    bbox2,
                    bbox1,
                )
                block_iou_threshold = 0.7
            if is_horizontal_1:
                if match_block_iou >= block_iou_threshold:
                    left_up_distance = distance_(True, True)
                    right_down_distance = distance_(True, False)
                    # block2 ends above block1's top edge.
                    if (
                        y2_prime <= y1
                        and left_up_distance <= left_up_title_text_distance
                    ):
                        left_up_title_text_distance = left_up_distance
                        left_up_title_text_index = j
                        left_up_title_text_direction = is_horizontal_2
                    # block2 starts below block1's bottom edge.
                    elif (
                        y1_prime > y2
                        and right_down_distance < right_down_title_text_distance
                    ):
                        right_down_title_text_distance = right_down_distance
                        right_down_title_text_index = j
                        right_down_title_text_direction = is_horizontal_2
            else:
                if match_block_iou >= block_iou_threshold:
                    left_up_distance = distance_(False, True)
                    right_down_distance = distance_(False, False)
                    # block2 ends left of block1's left edge.
                    if (
                        x2_prime <= x1
                        and left_up_distance <= left_up_title_text_distance
                    ):
                        left_up_title_text_distance = left_up_distance
                        left_up_title_text_index = j
                        left_up_title_text_direction = is_horizontal_2
                    # block2 starts right of block1's right edge.
                    elif (
                        x1_prime > x2
                        and right_down_distance < right_down_title_text_distance
                    ):
                        right_down_title_text_distance = right_down_distance
                        right_down_title_text_index = j
                        right_down_title_text_direction = is_horizontal_2
        height = bbox1[3] - bbox1[1]
        width = bbox1[2] - bbox1[0]
        title_text_weight = [0.8, 0.8]
        # title_text_weight = [2, 2]
        title_text = []
        sub_title = []
        vision_footnote = []

        def get_sub_category_(
            title_text_direction,
            title_text_index,
            label,
            is_left_up=True,
        ):
            # Position codes stored with each title_text entry:
            # 1 = above / 2 = below (horizontal anchor),
            # 3 = left / 4 = right (vertical anchor).
            direction_ = [1, 3] if is_left_up else [2, 4]
            # Only neighbours with the same orientation as the anchor, that
            # were actually found (index != -1), and that are text or
            # paragraph titles are considered.
            if (
                title_text_direction == is_horizontal_1
                and title_text_index != -1
                and (label == "text" or label == "paragraph_title")
            ):
                bbox2 = blocks[title_text_index]["layout_bbox"]
                if is_horizontal_1:
                    height1 = bbox2[3] - bbox2[1]
                    width1 = bbox2[2] - bbox2[0]
                    if label == "text":
                        # Small text hugging a vision block -> its footnote.
                        if (
                            _nearest_edge_distance(bbox1, bbox2)[0] <= 15
                            and block1["label"] in vision_labels
                            and width1 < width
                            and height1 < 0.5 * height
                        ):
                            blocks[title_text_index]["sub_label"] = "vision_footnote"
                            vision_footnote.append(bbox2)
                        # Short text next to a title -> text attached to the title.
                        elif (
                            height1 < height * title_text_weight[0]
                            and (width1 < width or width1 > 1.5 * width)
                            and block1["label"] in title_labels
                        ):
                            blocks[title_text_index]["sub_label"] = "title_text"
                            title_text.append((direction_[0], bbox2))
                    elif (
                        label == "paragraph_title"
                        and block1["label"] in sub_title_labels
                    ):
                        sub_title.append(bbox2)
                else:
                    height1 = bbox2[3] - bbox2[1]
                    width1 = bbox2[2] - bbox2[0]
                    if label == "text":
                        if (
                            _nearest_edge_distance(bbox1, bbox2)[0] <= 15
                            and block1["label"] in vision_labels
                            and height1 < height
                            and width1 < 0.5 * width
                        ):
                            blocks[title_text_index]["sub_label"] = "vision_footnote"
                            vision_footnote.append(bbox2)
                        elif (
                            width1 < width * title_text_weight[1]
                            and block1["label"] in title_labels
                        ):
                            blocks[title_text_index]["sub_label"] = "title_text"
                            title_text.append((direction_[1], bbox2))
                    elif (
                        label == "paragraph_title"
                        and block1["label"] in sub_title_labels
                    ):
                        sub_title.append(bbox2)

        # When one neighbour is clearly closer (gap difference * 5 exceeds the
        # anchor's extent along its orientation) only that side is used;
        # otherwise both sides are tried.
        # Note: an index of -1 makes blocks[-1]["label"] read the last block,
        # but get_sub_category_ ignores it because of its index != -1 check.
        if (
            is_horizontal_1
            and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
            > height
        ) or (
            not is_horizontal_1
            and abs(left_up_title_text_distance - right_down_title_text_distance) * 5
            > width
        ):
            if left_up_title_text_distance < right_down_title_text_distance:
                get_sub_category_(
                    left_up_title_text_direction,
                    left_up_title_text_index,
                    blocks[left_up_title_text_index]["label"],
                    True,
                )
            else:
                get_sub_category_(
                    right_down_title_text_direction,
                    right_down_title_text_index,
                    blocks[right_down_title_text_index]["label"],
                    False,
                )
        else:
            get_sub_category_(
                left_up_title_text_direction,
                left_up_title_text_index,
                blocks[left_up_title_text_index]["label"],
                True,
            )
            get_sub_category_(
                right_down_title_text_direction,
                right_down_title_text_index,
                blocks[right_down_title_text_index]["label"],
                False,
            )
        # Only fill the anchor's lists if an earlier pass left them empty.
        if block1["label"] in title_labels:
            if blocks[i].get("title_text") == []:
                blocks[i]["title_text"] = title_text
        if block1["label"] in sub_title_labels:
            if blocks[i].get("sub_title") == []:
                blocks[i]["sub_title"] = sub_title
        if block1["label"] in vision_labels:
            if blocks[i].get("vision_footnote") == []:
                blocks[i]["vision_footnote"] = vision_footnote
    return blocks
  880. def get_layout_ordering(data, no_mask_labels=[], already_sorted=False):
  881. """
  882. Process layout parsing results to remove overlapping bounding boxes
  883. and assign an ordering index based on their positions.
  884. Modifies:
  885. The 'parsing_result' list in 'layout_parsing_result' by adding an 'index' to each block.
  886. """
  887. if already_sorted:
  888. return data
  889. title_text_labels = ["doc_title"]
  890. title_labels = ["doc_title", "paragraph_title"]
  891. vision_labels = ["image", "table", "seal", "chart", "figure"]
  892. vision_title_labels = ["table_title", "chart_title", "figure_title"]
  893. parsing_result = data["sub_blocks"]
  894. parsing_result, _ = _remove_overlap_blocks(
  895. parsing_result,
  896. threshold=0.5,
  897. smaller=True,
  898. )
  899. parsing_result = _get_sub_category(parsing_result, title_text_labels)
  900. doc_flag = False
  901. median_width = _text_median_width(parsing_result)
  902. parsing_result, projection_direction = _get_layout_property(
  903. parsing_result,
  904. median_width,
  905. no_mask_labels=no_mask_labels,
  906. threshold=0.3,
  907. )
  908. # Convert bounding boxes to float and remove overlaps
  909. (
  910. double_text_blocks,
  911. title_text_blocks,
  912. title_blocks,
  913. vision_blocks,
  914. vision_title_blocks,
  915. vision_footnote_blocks,
  916. other_blocks,
  917. ) = ([], [], [], [], [], [], [])
  918. drop_indexes = []
  919. for index, block in enumerate(parsing_result):
  920. label = block["sub_label"]
  921. block["layout_bbox"] = list(map(int, block["layout_bbox"]))
  922. if label == "doc_title":
  923. doc_flag = True
  924. if label in no_mask_labels:
  925. if block["layout"] == "double":
  926. double_text_blocks.append(block)
  927. drop_indexes.append(index)
  928. elif label == "title_text":
  929. title_text_blocks.append(block)
  930. drop_indexes.append(index)
  931. elif label == "vision_footnote":
  932. vision_footnote_blocks.append(block)
  933. drop_indexes.append(index)
  934. elif label in vision_title_labels:
  935. vision_title_blocks.append(block)
  936. drop_indexes.append(index)
  937. elif label in title_labels:
  938. title_blocks.append(block)
  939. drop_indexes.append(index)
  940. elif label in vision_labels:
  941. vision_blocks.append(block)
  942. drop_indexes.append(index)
  943. else:
  944. other_blocks.append(block)
  945. drop_indexes.append(index)
  946. for index in sorted(drop_indexes, reverse=True):
  947. del parsing_result[index]
  948. if len(parsing_result) > 0:
  949. # single text label
  950. if len(double_text_blocks) > len(parsing_result) or projection_direction:
  951. parsing_result.extend(title_blocks + double_text_blocks)
  952. title_blocks = []
  953. double_text_blocks = []
  954. block_bboxes = [block["layout_bbox"] for block in parsing_result]
  955. block_bboxes.sort(
  956. key=lambda x: (
  957. x[0] // max(20, median_width),
  958. x[1],
  959. ),
  960. )
  961. block_bboxes = np.array(block_bboxes)
  962. print("sort by yxcut...")
  963. sorted_indices = sort_by_xycut(
  964. block_bboxes,
  965. direction=1,
  966. min_gap=1,
  967. )
  968. else:
  969. block_bboxes = [block["layout_bbox"] for block in parsing_result]
  970. block_bboxes.sort(key=lambda x: (x[0] // 20, x[1]))
  971. block_bboxes = np.array(block_bboxes)
  972. sorted_indices = sort_by_xycut(
  973. block_bboxes,
  974. direction=0,
  975. min_gap=20,
  976. )
  977. sorted_boxes = block_bboxes[sorted_indices].tolist()
  978. for block in parsing_result:
  979. block["index"] = sorted_boxes.index(block["layout_bbox"]) + 1
  980. block["sub_index"] = sorted_boxes.index(block["layout_bbox"]) + 1
  981. def nearest_match_(input_blocks, distance_type="manhattan", is_add_index=True):
  982. for block in input_blocks:
  983. bbox = block["layout_bbox"]
  984. min_distance = float("inf")
  985. min_distance_config = [
  986. [float("inf"), float("inf")],
  987. float("inf"),
  988. float("inf"),
  989. ] # for double text
  990. nearest_gt_index = 0
  991. for match_block in parsing_result:
  992. match_bbox = match_block["layout_bbox"]
  993. if distance_type == "nearest_iou_edge_distance":
  994. distance, min_distance_config = _nearest_iou_edge_distance(
  995. bbox,
  996. match_bbox,
  997. block["sub_label"],
  998. vision_labels=vision_labels,
  999. no_mask_labels=no_mask_labels,
  1000. median_width=median_width,
  1001. title_labels=title_labels,
  1002. title_text=block["title_text"],
  1003. sub_title=block["sub_title"],
  1004. min_distance_config=min_distance_config,
  1005. tolerance_len=10,
  1006. )
  1007. elif distance_type == "title_text":
  1008. if (
  1009. match_block["label"] in title_labels + ["abstract"]
  1010. and match_block["title_text"] != []
  1011. ):
  1012. iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
  1013. bbox,
  1014. match_block["title_text"][0][1],
  1015. )
  1016. iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
  1017. bbox,
  1018. match_block["title_text"][-1][1],
  1019. )
  1020. iou = 1 - max(iou_left_up, iou_right_down)
  1021. distance = _manhattan_distance(bbox, match_bbox) * iou
  1022. else:
  1023. distance = float("inf")
  1024. elif distance_type == "manhattan":
  1025. distance = _manhattan_distance(bbox, match_bbox)
  1026. elif distance_type == "vision_footnote":
  1027. if (
  1028. match_block["label"] in vision_labels
  1029. and match_block["vision_footnote"] != []
  1030. ):
  1031. iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
  1032. bbox,
  1033. match_block["vision_footnote"][0],
  1034. )
  1035. iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
  1036. bbox,
  1037. match_block["vision_footnote"][-1],
  1038. )
  1039. iou = 1 - max(iou_left_up, iou_right_down)
  1040. distance = _manhattan_distance(bbox, match_bbox) * iou
  1041. else:
  1042. distance = float("inf")
  1043. elif distance_type == "vision_body":
  1044. if (
  1045. match_block["label"] in vision_title_labels
  1046. and block["vision_footnote"] != []
  1047. ):
  1048. iou_left_up = _calculate_overlap_area_2_minbox_area_ratio(
  1049. match_bbox,
  1050. block["vision_footnote"][0],
  1051. )
  1052. iou_right_down = _calculate_overlap_area_2_minbox_area_ratio(
  1053. match_bbox,
  1054. block["vision_footnote"][-1],
  1055. )
  1056. iou = 1 - max(iou_left_up, iou_right_down)
  1057. distance = _manhattan_distance(bbox, match_bbox) * iou
  1058. else:
  1059. distance = float("inf")
  1060. else:
  1061. raise NotImplementedError
  1062. if distance < min_distance:
  1063. min_distance = distance
  1064. if is_add_index:
  1065. nearest_gt_index = match_block.get("index", 999)
  1066. else:
  1067. nearest_gt_index = match_block.get("sub_index", 999)
  1068. if is_add_index:
  1069. block["index"] = nearest_gt_index
  1070. else:
  1071. block["sub_index"] = nearest_gt_index
  1072. parsing_result.append(block)
  1073. # double text label
  1074. double_text_blocks.sort(
  1075. key=lambda x: (
  1076. x["layout_bbox"][1] // 10,
  1077. x["layout_bbox"][0] // median_width,
  1078. x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
  1079. ),
  1080. )
  1081. nearest_match_(
  1082. double_text_blocks,
  1083. distance_type="nearest_iou_edge_distance",
  1084. )
  1085. parsing_result.sort(
  1086. key=lambda x: (x["index"], x["layout_bbox"][1], x["layout_bbox"][0]),
  1087. )
  1088. for idx, block in enumerate(parsing_result):
  1089. block["index"] = idx + 1
  1090. block["sub_index"] = idx + 1
  1091. # title label
  1092. title_blocks.sort(
  1093. key=lambda x: (
  1094. x["layout_bbox"][1] // 10,
  1095. x["layout_bbox"][0] // median_width,
  1096. x["layout_bbox"][1] ** 2 + x["layout_bbox"][0] ** 2,
  1097. ),
  1098. )
  1099. nearest_match_(title_blocks, distance_type="nearest_iou_edge_distance")
  1100. if doc_flag:
  1101. # text_sort_labels = ["doc_title","paragraph_title","abstract"]
  1102. text_sort_labels = ["doc_title"]
  1103. text_label_priority = {
  1104. label: priority for priority, label in enumerate(text_sort_labels)
  1105. }
  1106. doc_titles = []
  1107. for i, block in enumerate(parsing_result):
  1108. if block["label"] == "doc_title":
  1109. doc_titles.append(
  1110. (i, block["layout_bbox"][1], block["layout_bbox"][0]),
  1111. )
  1112. doc_titles.sort(key=lambda x: (x[1], x[2]))
  1113. first_doc_title_index = doc_titles[0][0]
  1114. parsing_result[first_doc_title_index]["index"] = 1
  1115. parsing_result.sort(
  1116. key=lambda x: (
  1117. x["index"],
  1118. text_label_priority.get(x["label"], 9999),
  1119. x["layout_bbox"][1],
  1120. x["layout_bbox"][0],
  1121. ),
  1122. )
  1123. else:
  1124. parsing_result.sort(
  1125. key=lambda x: (
  1126. x["index"],
  1127. x["layout_bbox"][1],
  1128. x["layout_bbox"][0],
  1129. ),
  1130. )
  1131. for idx, block in enumerate(parsing_result):
  1132. block["index"] = idx + 1
  1133. block["sub_index"] = idx + 1
  1134. # title-text label
  1135. nearest_match_(title_text_blocks, distance_type="title_text")
  1136. text_sort_labels = ["doc_title", "paragraph_title", "title_text"]
  1137. text_label_priority = {
  1138. label: priority for priority, label in enumerate(text_sort_labels)
  1139. }
  1140. parsing_result.sort(
  1141. key=lambda x: (
  1142. x["index"],
  1143. text_label_priority.get(x["sub_label"], 9999),
  1144. x["layout_bbox"][1],
  1145. x["layout_bbox"][0],
  1146. ),
  1147. )
  1148. for idx, block in enumerate(parsing_result):
  1149. block["index"] = idx + 1
  1150. block["sub_index"] = idx + 1
  1151. # image,figure,chart,seal label
  1152. nearest_match_(
  1153. vision_title_blocks,
  1154. distance_type="nearest_iou_edge_distance",
  1155. is_add_index=False,
  1156. )
  1157. parsing_result.sort(
  1158. key=lambda x: (
  1159. x["sub_index"],
  1160. x["layout_bbox"][1],
  1161. x["layout_bbox"][0],
  1162. ),
  1163. )
  1164. for idx, block in enumerate(parsing_result):
  1165. block["sub_index"] = idx + 1
  1166. # image,figure,chart,seal label
  1167. nearest_match_(
  1168. vision_blocks,
  1169. distance_type="nearest_iou_edge_distance",
  1170. is_add_index=False,
  1171. )
  1172. parsing_result.sort(
  1173. key=lambda x: (
  1174. x["sub_index"],
  1175. x["layout_bbox"][1],
  1176. x["layout_bbox"][0],
  1177. ),
  1178. )
  1179. for idx, block in enumerate(parsing_result):
  1180. block["sub_index"] = idx + 1
  1181. # vision footnote label
  1182. nearest_match_(
  1183. vision_footnote_blocks,
  1184. distance_type="vision_footnote",
  1185. is_add_index=False,
  1186. )
  1187. text_label_priority = {"vision_footnote": 9999}
  1188. parsing_result.sort(
  1189. key=lambda x: (
  1190. x["sub_index"],
  1191. text_label_priority.get(x["sub_label"], 0),
  1192. x["layout_bbox"][1],
  1193. x["layout_bbox"][0],
  1194. ),
  1195. )
  1196. for idx, block in enumerate(parsing_result):
  1197. block["sub_index"] = idx + 1
  1198. # header、footnote、header_image... label
  1199. nearest_match_(other_blocks, distance_type="manhattan", is_add_index=False)
  1200. return data
  1201. def _generate_input_data(parsing_result):
  1202. """
  1203. The evaluation input data is generated based on the parsing results.
  1204. :param parsing_result: A list containing the results of the layout parsing
  1205. :return: A formatted list of input data
  1206. """
  1207. input_data = [
  1208. {
  1209. "block_bbox": block["block_bbox"],
  1210. "sub_indices": [],
  1211. "sub_bboxes": [],
  1212. }
  1213. for block in parsing_result
  1214. ]
  1215. for block_index, block in enumerate(parsing_result):
  1216. sub_blocks = block["sub_blocks"]
  1217. get_layout_ordering(
  1218. block_index=block_index,
  1219. no_mask_labels=[
  1220. "text",
  1221. "formula",
  1222. "algorithm",
  1223. "reference",
  1224. "content",
  1225. "abstract",
  1226. ],
  1227. )
  1228. for sub_block in sub_blocks:
  1229. input_data[block_index]["sub_bboxes"].append(
  1230. list(map(int, sub_block["layout_bbox"])),
  1231. )
  1232. input_data[block_index]["sub_indices"].append(
  1233. int(sub_block["index"]),
  1234. )
  1235. return input_data
  1236. def _manhattan_distance(point1, point2, weight_x=1, weight_y=1):
  1237. return weight_x * abs(point1[0] - point2[0]) + weight_y * abs(point1[1] - point2[1])
  1238. def _calculate_horizontal_distance(
  1239. input_bbox,
  1240. match_bbox,
  1241. height,
  1242. disperse,
  1243. title_text,
  1244. ):
  1245. """
  1246. Calculate the horizontal distance between two bounding boxes, considering title text adjustments.
  1247. Args:
  1248. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1249. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1250. height (int): The height of the input bounding box used for normalization.
  1251. disperse (int): The dispersion factor used to normalize the horizontal distance.
  1252. title_text (list): A list of tuples containing title text information and their bounding box coordinates.
  1253. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1254. Returns:
  1255. float: The calculated horizontal distance taking into account the title text adjustments.
  1256. """
  1257. x1, y1, x2, y2 = input_bbox
  1258. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1259. if y2 < y1_prime:
  1260. if title_text and title_text[-1][0] == 2:
  1261. y2 += title_text[-1][1][3] - title_text[-1][1][1]
  1262. distance1 = (y1_prime - y2) * 0.5
  1263. else:
  1264. if title_text and title_text[0][0] == 1:
  1265. y1 -= title_text[0][1][3] - title_text[0][1][1]
  1266. distance1 = y1 - y2_prime
  1267. return (
  1268. abs(x2_prime - x1) // disperse + distance1 // height + distance1 / 5000
  1269. ) # if page max size == 5000
  1270. def _calculate_vertical_distance(input_bbox, match_bbox, width, disperse, title_text):
  1271. """
  1272. Calculate the vertical distance between two bounding boxes, considering title text adjustments.
  1273. Args:
  1274. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1275. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1276. width (int): The width of the input bounding box used for normalization.
  1277. disperse (int): The dispersion factor used to normalize the vertical distance.
  1278. title_text (list): A list of tuples containing title text information and their bounding box coordinates.
  1279. Format: [(position_indicator, [x1, y1, x2, y2]), ...].
  1280. Returns:
  1281. float: The calculated vertical distance taking into account the title text adjustments.
  1282. """
  1283. x1, y1, x2, y2 = input_bbox
  1284. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1285. if x1 > x2_prime:
  1286. if title_text and title_text[0][0] == 3:
  1287. x1 -= title_text[0][1][2] - title_text[0][1][0]
  1288. distance2 = (x1 - x2_prime) * 0.5
  1289. else:
  1290. if title_text and title_text[-1][0] == 4:
  1291. x2 += title_text[-1][1][2] - title_text[-1][1][0]
  1292. distance2 = x1_prime - x2
  1293. return abs(y2_prime - y1) // disperse + distance2 // width + distance2 / 5000
  1294. def _nearest_edge_distance(
  1295. input_bbox,
  1296. match_bbox,
  1297. weight=[1, 1, 1, 1],
  1298. label="text",
  1299. no_mask_labels=[],
  1300. min_edge_distances_config=[],
  1301. tolerance_len=10,
  1302. ):
  1303. """
  1304. Calculate the nearest edge distance between two bounding boxes, considering directional weights.
  1305. Args:
  1306. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1307. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1308. weight (list, optional): Directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
  1309. label (str, optional): The label/type of the object in the bounding box (e.g., 'text'). Defaults to 'text'.
  1310. no_mask_labels (list, optional): Labels for which no masking is applied when calculating edge distances. Defaults to an empty list.
  1311. min_edge_distances_config (list, optional): Configuration for minimum edge distances [min_edge_distance_x, min_edge_distance_y].
  1312. Defaults to [float('inf'), float('inf')].
  1313. Returns:
  1314. tuple: A tuple containing:
  1315. - The calculated minimum edge distance between the bounding boxes.
  1316. - A list with the minimum edge distances in the x and y directions.
  1317. """
  1318. match_bbox_iou = _calculate_overlap_area_2_minbox_area_ratio(
  1319. input_bbox,
  1320. match_bbox,
  1321. )
  1322. if match_bbox_iou > 0 and label not in no_mask_labels:
  1323. return 0, [0, 0]
  1324. if not min_edge_distances_config:
  1325. min_edge_distances_config = [float("inf"), float("inf")]
  1326. min_edge_distance_x, min_edge_distance_y = min_edge_distances_config
  1327. x1, y1, x2, y2 = input_bbox
  1328. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1329. direction_num = 0
  1330. distance_x = float("inf")
  1331. distance_y = float("inf")
  1332. distance = [float("inf")] * 4
  1333. # input_bbox is to the left of match_bbox
  1334. if x2 < x1_prime:
  1335. direction_num += 1
  1336. distance[0] = x1_prime - x2
  1337. if abs(distance[0] - min_edge_distance_x) <= tolerance_len:
  1338. distance_x = min_edge_distance_x * weight[0]
  1339. else:
  1340. distance_x = distance[0] * weight[0]
  1341. # input_bbox is to the right of match_bbox
  1342. elif x1 > x2_prime:
  1343. direction_num += 1
  1344. distance[1] = x1 - x2_prime
  1345. if abs(distance[1] - min_edge_distance_x) <= tolerance_len:
  1346. distance_x = min_edge_distance_x * weight[1]
  1347. else:
  1348. distance_x = distance[1] * weight[1]
  1349. elif match_bbox_iou > 0:
  1350. distance[0] = 0
  1351. distance_x = 0
  1352. # input_bbox is above match_bbox
  1353. if y2 < y1_prime:
  1354. direction_num += 1
  1355. distance[2] = y1_prime - y2
  1356. if abs(distance[2] - min_edge_distance_y) <= tolerance_len:
  1357. distance_y = min_edge_distance_y * weight[2]
  1358. else:
  1359. distance_y = distance[2] * weight[2]
  1360. if label in no_mask_labels:
  1361. distance_y = max(0.1, distance_y) * 100
  1362. # input_bbox is below match_bbox
  1363. elif y1 > y2_prime:
  1364. direction_num += 1
  1365. distance[3] = y1 - y2_prime
  1366. if abs(distance[3] - min_edge_distance_y) <= tolerance_len:
  1367. distance_y = min_edge_distance_y * weight[3]
  1368. else:
  1369. distance_y = distance[3] * weight[3]
  1370. elif match_bbox_iou > 0:
  1371. distance[2] = 0
  1372. distance_y = 0
  1373. if direction_num == 2:
  1374. return (distance_x + distance_y), [
  1375. min(distance[0], distance[1]),
  1376. min(distance[2], distance[3]),
  1377. ]
  1378. else:
  1379. return min(distance_x, distance_y), [
  1380. min(distance[0], distance[1]),
  1381. min(distance[2], distance[3]),
  1382. ]
  1383. def _get_weights(label, horizontal):
  1384. """Define weights based on the label and orientation."""
  1385. if label == "doc_title":
  1386. return (
  1387. [1, 0.1, 0.1, 1] if horizontal else [0.2, 0.1, 1, 1]
  1388. ) # left-down , right-left
  1389. elif label in [
  1390. "paragraph_title",
  1391. "abstract",
  1392. "figure_title",
  1393. "chart_title",
  1394. "image",
  1395. "seal",
  1396. "chart",
  1397. "figure",
  1398. ]:
  1399. return [1, 1, 0.1, 1] # down
  1400. else:
  1401. return [1, 1, 1, 0.1] # up
  1402. def _nearest_iou_edge_distance(
  1403. input_bbox,
  1404. match_bbox,
  1405. label,
  1406. vision_labels,
  1407. no_mask_labels,
  1408. median_width=-1,
  1409. title_labels=[],
  1410. title_text=[],
  1411. sub_title=[],
  1412. min_distance_config=[],
  1413. tolerance_len=10,
  1414. ):
  1415. """
  1416. Calculate the nearest IOU edge distance between two bounding boxes.
  1417. Args:
  1418. input_bbox (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
  1419. match_bbox (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
  1420. label (str): The label/type of the object in the bounding box (e.g., 'image', 'text', etc.).
  1421. no_mask_labels (list): Labels for which no masking is applied when calculating edge distances.
  1422. median_width (int, optional): The median width for title dispersion calculation. Defaults to -1.
  1423. title_labels (list, optional): Labels that indicate the object is a title. Defaults to an empty list.
  1424. title_text (list, optional): Text content associated with title labels. Defaults to an empty list.
  1425. sub_title (list, optional): List of subtitle bounding boxes to adjust the input_bbox. Defaults to an empty list.
  1426. min_distance_config (list, optional): Configuration for minimum distances [min_edge_distances_config, up_edge_distances_config, total_distance].
  1427. Returns:
  1428. tuple: A tuple containing the calculated distance and updated minimum distance configuration.
  1429. """
  1430. x1, y1, x2, y2 = input_bbox
  1431. x1_prime, y1_prime, x2_prime, y2_prime = match_bbox
  1432. min_edge_distances_config, up_edge_distances_config, total_distance = (
  1433. min_distance_config
  1434. )
  1435. iou_distance = 0
  1436. if label in vision_labels:
  1437. horizontal1 = horizontal2 = True
  1438. else:
  1439. horizontal1 = _get_bbox_direction(input_bbox)
  1440. horizontal2 = _get_bbox_direction(match_bbox, 3)
  1441. if (
  1442. horizontal1 != horizontal2
  1443. or _get_projection_iou(input_bbox, match_bbox, horizontal1) < 0.01
  1444. ):
  1445. iou_distance = 1
  1446. elif label == "doc_title" or (label in title_labels and title_text):
  1447. # Calculate distance for titles
  1448. disperse = max(1, median_width)
  1449. width = x2 - x1
  1450. height = y2 - y1
  1451. if horizontal1:
  1452. return (
  1453. _calculate_horizontal_distance(
  1454. input_bbox,
  1455. match_bbox,
  1456. height,
  1457. disperse,
  1458. title_text,
  1459. ),
  1460. min_distance_config,
  1461. )
  1462. else:
  1463. return (
  1464. _calculate_vertical_distance(
  1465. input_bbox,
  1466. match_bbox,
  1467. width,
  1468. disperse,
  1469. title_text,
  1470. ),
  1471. min_distance_config,
  1472. )
  1473. # Adjust input_bbox based on sub_title
  1474. if sub_title:
  1475. for sub in sub_title:
  1476. x1_, y1_, x2_, y2_ = sub
  1477. x1, y1, x2, y2 = (
  1478. min(x1, x1_),
  1479. min(
  1480. y1,
  1481. y1_,
  1482. ),
  1483. max(x2, x2_),
  1484. max(y2, y2_),
  1485. )
  1486. input_bbox = [x1, y1, x2, y2]
  1487. # Calculate edge distance
  1488. weight = _get_weights(label, horizontal1)
  1489. if label == "abstract":
  1490. tolerance_len *= 3
  1491. edge_distance, edge_distance_config = _nearest_edge_distance(
  1492. input_bbox,
  1493. match_bbox,
  1494. weight,
  1495. label=label,
  1496. no_mask_labels=no_mask_labels,
  1497. min_edge_distances_config=min_edge_distances_config,
  1498. tolerance_len=tolerance_len,
  1499. )
  1500. # Weights for combining distances
  1501. iou_edge_weight = [10**6, 10**3, 1, 0.001]
  1502. # Calculate up and left edge distances
  1503. up_edge_distance = y1_prime
  1504. left_edge_distance = x1_prime
  1505. if (
  1506. label in no_mask_labels or label == "paragraph_title" or label in vision_labels
  1507. ) and y1 > y2_prime:
  1508. up_edge_distance = -y2_prime
  1509. left_edge_distance = -x2_prime
  1510. min_up_edge_distance = up_edge_distances_config
  1511. if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
  1512. up_edge_distance = min_up_edge_distance
  1513. # Calculate total distance
  1514. distance = (
  1515. iou_distance * iou_edge_weight[0]
  1516. + edge_distance * iou_edge_weight[1]
  1517. + up_edge_distance * iou_edge_weight[2]
  1518. + left_edge_distance * iou_edge_weight[3]
  1519. )
  1520. # Update minimum distance configuration if a smaller distance is found
  1521. if total_distance > distance:
  1522. edge_distance_config = [
  1523. min(min_edge_distances_config[0], edge_distance_config[0]),
  1524. min(min_edge_distances_config[1], edge_distance_config[1]),
  1525. ]
  1526. min_distance_config = [
  1527. edge_distance_config,
  1528. min(up_edge_distance, up_edge_distances_config),
  1529. distance,
  1530. ]
  1531. return distance, min_distance_config
  1532. def get_show_color(label):
  1533. label_colors = {
  1534. # Medium Blue (from 'titles_list')
  1535. "paragraph_title": (102, 102, 255, 100),
  1536. "doc_title": (255, 248, 220, 100), # Cornsilk
  1537. # Light Yellow (from 'tables_caption_list')
  1538. "table_title": (255, 255, 102, 100),
  1539. # Sky Blue (from 'imgs_caption_list')
  1540. "figure_title": (102, 178, 255, 100),
  1541. "chart_title": (221, 160, 221, 100), # Plum
  1542. "vision_footnote": (144, 238, 144, 100), # Light Green
  1543. # Deep Purple (from 'texts_list')
  1544. "text": (153, 0, 76, 100),
  1545. # Bright Green (from 'interequations_list')
  1546. "formula": (0, 255, 0, 100),
  1547. "abstract": (255, 239, 213, 100), # Papaya Whip
  1548. # Medium Green (from 'lists_list' and 'indexs_list')
  1549. "content": (40, 169, 92, 100),
  1550. # Neutral Gray (from 'dropped_bbox_list')
  1551. "seal": (158, 158, 158, 100),
  1552. # Olive Yellow (from 'tables_body_list')
  1553. "table": (204, 204, 0, 100),
  1554. # Bright Green (from 'imgs_body_list')
  1555. "image": (153, 255, 51, 100),
  1556. # Bright Green (from 'imgs_body_list')
  1557. "figure": (153, 255, 51, 100),
  1558. "chart": (216, 191, 216, 100), # Thistle
  1559. # Pale Yellow-Green (from 'tables_footnote_list')
  1560. "reference": (229, 255, 204, 100),
  1561. "algorithm": (255, 250, 240, 100), # Floral White
  1562. }
  1563. default_color = (158, 158, 158, 100)
  1564. return label_colors.get(label, default_color)