utils.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. __all__ = [
  15. "get_sub_regions_ocr_res",
  16. "get_show_color",
  17. "sorted_layout_boxes",
  18. ]
  19. import re
  20. from copy import deepcopy
  21. from typing import Dict, List, Optional, Tuple, Union
  22. import numpy as np
  23. from PIL import Image
  24. from ..components import convert_points_to_boxes
  25. from ..ocr.result import OCRResult
  26. def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
  27. """
  28. Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
  29. Args:
  30. src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
  31. ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
  32. Returns:
  33. match_idx_list (list): A list of indices of source boxes that overlap with reference boxes.
  34. """
  35. match_idx_list = []
  36. src_boxes_num = len(src_boxes)
  37. if src_boxes_num > 0 and len(ref_boxes) > 0:
  38. for rno in range(len(ref_boxes)):
  39. ref_box = ref_boxes[rno]
  40. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  41. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  42. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  43. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  44. pub_w = x2 - x1
  45. pub_h = y2 - y1
  46. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  47. match_idx_list.extend(match_idx)
  48. return match_idx_list
  49. def get_sub_regions_ocr_res(
  50. overall_ocr_res: OCRResult,
  51. object_boxes: List,
  52. flag_within: bool = True,
  53. return_match_idx: bool = False,
  54. ) -> OCRResult:
  55. """
  56. Filters OCR results to only include text boxes within specified object boxes based on a flag.
  57. Args:
  58. overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
  59. object_boxes (list): A list of bounding boxes for the objects of interest.
  60. flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
  61. return_match_idx (bool): If True, return the list of matching indices.
  62. Returns:
  63. OCRResult: A filtered OCR result containing only the relevant text boxes.
  64. """
  65. sub_regions_ocr_res = {}
  66. sub_regions_ocr_res["rec_polys"] = []
  67. sub_regions_ocr_res["rec_texts"] = []
  68. sub_regions_ocr_res["rec_scores"] = []
  69. sub_regions_ocr_res["rec_boxes"] = []
  70. overall_text_boxes = overall_ocr_res["rec_boxes"]
  71. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  72. match_idx_list = list(set(match_idx_list))
  73. for box_no in range(len(overall_text_boxes)):
  74. if flag_within:
  75. if box_no in match_idx_list:
  76. flag_match = True
  77. else:
  78. flag_match = False
  79. else:
  80. if box_no not in match_idx_list:
  81. flag_match = True
  82. else:
  83. flag_match = False
  84. if flag_match:
  85. sub_regions_ocr_res["rec_polys"].append(
  86. overall_ocr_res["rec_polys"][box_no]
  87. )
  88. sub_regions_ocr_res["rec_texts"].append(
  89. overall_ocr_res["rec_texts"][box_no]
  90. )
  91. sub_regions_ocr_res["rec_scores"].append(
  92. overall_ocr_res["rec_scores"][box_no]
  93. )
  94. sub_regions_ocr_res["rec_boxes"].append(
  95. overall_ocr_res["rec_boxes"][box_no]
  96. )
  97. for key in ["rec_polys", "rec_scores", "rec_boxes"]:
  98. sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])
  99. return (
  100. (sub_regions_ocr_res, match_idx_list)
  101. if return_match_idx
  102. else sub_regions_ocr_res
  103. )
  104. def sorted_layout_boxes(res, w):
  105. """
  106. Sort text boxes in order from top to bottom, left to right
  107. Args:
  108. res: List of dictionaries containing layout information.
  109. w: Width of image.
  110. Returns:
  111. List of dictionaries containing sorted layout information.
  112. """
  113. num_boxes = len(res)
  114. if num_boxes == 1:
  115. return res
  116. # Sort on the y axis first or sort it on the x axis
  117. sorted_boxes = sorted(res, key=lambda x: (x["block_bbox"][1], x["block_bbox"][0]))
  118. _boxes = list(sorted_boxes)
  119. new_res = []
  120. res_left = []
  121. res_right = []
  122. i = 0
  123. while True:
  124. if i >= num_boxes:
  125. break
  126. # Check that the bbox is on the left
  127. elif (
  128. _boxes[i]["block_bbox"][0] < w / 4
  129. and _boxes[i]["block_bbox"][2] < 3 * w / 5
  130. ):
  131. res_left.append(_boxes[i])
  132. i += 1
  133. elif _boxes[i]["block_bbox"][0] > 2 * w / 5:
  134. res_right.append(_boxes[i])
  135. i += 1
  136. else:
  137. new_res += res_left
  138. new_res += res_right
  139. new_res.append(_boxes[i])
  140. res_left = []
  141. res_right = []
  142. i += 1
  143. res_left = sorted(res_left, key=lambda x: (x["block_bbox"][1]))
  144. res_right = sorted(res_right, key=lambda x: (x["block_bbox"][1]))
  145. if res_left:
  146. new_res += res_left
  147. if res_right:
  148. new_res += res_right
  149. return new_res
  150. def calculate_projection_overlap_ratio(
  151. bbox1: List[float],
  152. bbox2: List[float],
  153. orientation: str = "horizontal",
  154. mode="union",
  155. ) -> float:
  156. """
  157. Calculate the IoU of lines between two bounding boxes.
  158. Args:
  159. bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
  160. bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
  161. orientation (str): orientation of the projection, "horizontal" or "vertical".
  162. Returns:
  163. float: Line overlap ratio. Returns 0 if there is no overlap.
  164. """
  165. start_index, end_index = 1, 3
  166. if orientation == "horizontal":
  167. start_index, end_index = 0, 2
  168. intersection_start = max(bbox1[start_index], bbox2[start_index])
  169. intersection_end = min(bbox1[end_index], bbox2[end_index])
  170. overlap = intersection_end - intersection_start
  171. if overlap <= 0:
  172. return 0
  173. if mode == "union":
  174. ref_width = max(bbox1[end_index], bbox2[end_index]) - min(
  175. bbox1[start_index], bbox2[start_index]
  176. )
  177. elif mode == "small":
  178. ref_width = min(
  179. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  180. )
  181. elif mode == "large":
  182. ref_width = max(
  183. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  184. )
  185. else:
  186. raise ValueError(
  187. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  188. )
  189. return overlap / ref_width if ref_width > 0 else 0.0
  190. def calculate_overlap_ratio(
  191. bbox1: Union[list, tuple], bbox2: Union[list, tuple], mode="union"
  192. ) -> float:
  193. """
  194. Calculate the overlap ratio between two bounding boxes.
  195. Args:
  196. bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max]
  197. bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max]
  198. mode (str): The mode of calculation, either 'union', 'small', or 'large'.
  199. Returns:
  200. float: The overlap ratio value between the two bounding boxes
  201. """
  202. x_min_inter = max(bbox1[0], bbox2[0])
  203. y_min_inter = max(bbox1[1], bbox2[1])
  204. x_max_inter = min(bbox1[2], bbox2[2])
  205. y_max_inter = min(bbox1[3], bbox2[3])
  206. inter_width = max(0, x_max_inter - x_min_inter)
  207. inter_height = max(0, y_max_inter - y_min_inter)
  208. inter_area = inter_width * inter_height
  209. bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  210. bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  211. if mode == "union":
  212. ref_area = bbox1_area + bbox2_area - inter_area
  213. elif mode == "small":
  214. ref_area = min(bbox1_area, bbox2_area)
  215. elif mode == "large":
  216. ref_area = max(bbox1_area, bbox2_area)
  217. else:
  218. raise ValueError(
  219. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  220. )
  221. if ref_area == 0:
  222. return 0.0
  223. return inter_area / ref_area
  224. def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
  225. rec_boxes = ocr_rec_res["boxes"]
  226. rec_texts = ocr_rec_res["rec_texts"]
  227. rec_labels = ocr_rec_res["rec_labels"]
  228. text_boxes = [
  229. rec_boxes[i] for i in range(len(rec_boxes)) if rec_labels[i] == "text"
  230. ]
  231. text_orientation = calculate_text_orientation(text_boxes)
  232. match_orientation = "vertical" if text_orientation == "horizontal" else "horizontal"
  233. spans = list(zip(rec_boxes, rec_texts, rec_labels))
  234. sort_index = 1
  235. reverse = False
  236. if text_orientation == "vertical":
  237. sort_index = 0
  238. reverse = True
  239. spans.sort(key=lambda span: span[0][sort_index], reverse=reverse)
  240. spans = [list(span) for span in spans]
  241. lines = []
  242. line = [spans[0]]
  243. line_region_box = spans[0][0][:]
  244. # merge line
  245. for span in spans[1:]:
  246. rec_bbox = span[0]
  247. if (
  248. calculate_projection_overlap_ratio(
  249. line_region_box, rec_bbox, match_orientation, mode="small"
  250. )
  251. >= line_height_iou_threshold
  252. ):
  253. line.append(span)
  254. line_region_box[1] = min(line_region_box[1], rec_bbox[1])
  255. line_region_box[3] = max(line_region_box[3], rec_bbox[3])
  256. else:
  257. lines.append(line)
  258. line = [span]
  259. line_region_box = rec_bbox[:]
  260. lines.append(line)
  261. return lines, text_orientation
  262. def calculate_minimum_enclosing_bbox(bboxes):
  263. """
  264. Calculate the minimum enclosing bounding box for a list of bounding boxes.
  265. Args:
  266. bboxes (list): A list of bounding boxes represented as lists of four integers [x1, y1, x2, y2].
  267. Returns:
  268. list: The minimum enclosing bounding box represented as a list of four integers [x1, y1, x2, y2].
  269. """
  270. if not bboxes:
  271. raise ValueError("The list of bounding boxes is empty.")
  272. # Convert the list of bounding boxes to a NumPy array
  273. bboxes_array = np.array(bboxes)
  274. # Compute the minimum and maximum values along the respective axes
  275. min_x = np.min(bboxes_array[:, 0])
  276. min_y = np.min(bboxes_array[:, 1])
  277. max_x = np.max(bboxes_array[:, 2])
  278. max_y = np.max(bboxes_array[:, 3])
  279. # Return the minimum enclosing bounding box
  280. return [min_x, min_y, max_x, max_y]
  281. def calculate_text_orientation(
  282. bboxes: List[List[int]], orientation_ratio: float = 1.5
  283. ) -> bool:
  284. """
  285. Calculate the orientation of the text based on the bounding boxes.
  286. Args:
  287. bboxes (list): A list of bounding boxes.
  288. orientation_ratio (float): Ratio for determining orientation. Default is 1.5.
  289. Returns:
  290. str: "horizontal" or "vertical".
  291. """
  292. horizontal_box_num = 0
  293. for bbox in bboxes:
  294. if len(bbox) != 4:
  295. raise ValueError(
  296. "Invalid bounding box format. Expected a list of length 4."
  297. )
  298. x1, y1, x2, y2 = bbox
  299. width = x2 - x1
  300. height = y2 - y1
  301. horizontal_box_num += 1 if width * orientation_ratio >= height else 0
  302. return "horizontal" if horizontal_box_num >= len(bboxes) * 0.5 else "vertical"
  303. def is_english_letter(char):
  304. return bool(re.match(r"^[A-Za-z]$", char))
  305. def format_line(
  306. line: List[List[Union[List[int], str]]],
  307. block_right_coordinate: int,
  308. last_line_span_limit: int = 10,
  309. block_label: str = "text",
  310. # delimiter_map: Dict = {},
  311. ) -> None:
  312. """
  313. Format a line of text spans based on layout constraints.
  314. Args:
  315. line (list): A list of spans, where each span is a list containing a bounding box and text.
  316. block_left_coordinate (int): The minimum x-coordinate of the layout bounding box.
  317. block_right_coordinate (int): The maximum x-coordinate of the layout bounding box.
  318. first_line_span_limit (int): The limit for the number of pixels before the first span that should be considered part of the first line. Default is 10.
  319. last_line_span_limit (int): The limit for the number of pixels after the last span that should be considered part of the last line. Default is 10.
  320. block_label (str): The label associated with the entire block. Default is 'text'.
  321. Returns:
  322. None: The function modifies the line in place.
  323. """
  324. last_span_box = line[-1][0]
  325. for span in line:
  326. if span[2] == "formula" and block_label != "formula":
  327. if len(line) > 1:
  328. span[1] = f"${span[1]}$"
  329. else:
  330. span[1] = f"\n${span[1]}$"
  331. line_text = " ".join([span[1] for span in line])
  332. need_new_line = False
  333. if (
  334. block_right_coordinate - last_span_box[2] > last_line_span_limit
  335. and not line_text.endswith("-")
  336. and len(line_text) > 0
  337. and not is_english_letter(line_text[-1])
  338. ):
  339. need_new_line = True
  340. if line_text.endswith("-"):
  341. line_text = line_text[:-1]
  342. elif (
  343. len(line_text) > 0 and is_english_letter(line_text[-1])
  344. ) or line_text.endswith("$"):
  345. line_text += " "
  346. return line_text, need_new_line
  347. def split_boxes_by_projection(spans: List[List[int]], orientation, offset=1e-5):
  348. """
  349. Check if there is any complete containment in the x-orientation
  350. between the bounding boxes and split the containing box accordingly.
  351. Args:
  352. spans (list of lists): Each element is a list containing an ndarray of length 4, a text string, and a label.
  353. orientation: 'horizontal' or 'vertical', indicating whether the spans are arranged horizontally or vertically.
  354. offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
  355. Returns:
  356. A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
  357. """
  358. def is_projection_contained(box_a, box_b, start_idx, end_idx):
  359. """Check if box_a completely contains box_b in the x-orientation."""
  360. return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx]
  361. new_boxes = []
  362. if orientation == "horizontal":
  363. projection_start_index, projection_end_index = 0, 2
  364. else:
  365. projection_start_index, projection_end_index = 1, 3
  366. for i in range(len(spans)):
  367. span = spans[i]
  368. box_a, text, label = span
  369. is_split = False
  370. for j in range(len(spans)):
  371. if i == j:
  372. continue
  373. box_b = spans[j][0]
  374. if is_projection_contained(
  375. box_a, box_b, projection_start_index, projection_end_index
  376. ):
  377. is_split = True
  378. # Split box_a based on the x-coordinates of box_b
  379. if box_a[projection_start_index] < box_b[projection_start_index]:
  380. w = (
  381. box_b[projection_start_index]
  382. - offset
  383. - box_a[projection_start_index]
  384. )
  385. if w > 1:
  386. box_a[projection_end_index] = (
  387. box_b[projection_start_index] - offset
  388. )
  389. new_boxes.append(
  390. [
  391. np.array(box_a),
  392. text,
  393. label,
  394. ]
  395. )
  396. if box_a[projection_end_index] > box_b[projection_end_index]:
  397. w = (
  398. box_a[projection_end_index]
  399. - box_b[projection_end_index]
  400. + offset
  401. )
  402. if w > 1:
  403. box_a[projection_start_index] = (
  404. box_b[projection_end_index] + offset
  405. )
  406. span = [
  407. np.array(box_a),
  408. text,
  409. label,
  410. ]
  411. if j == len(spans) - 1 and is_split:
  412. new_boxes.append(span)
  413. if not is_split:
  414. new_boxes.append(span)
  415. return new_boxes
  416. def remove_extra_space(input_text: str) -> str:
  417. """
  418. Process the input text to handle spaces.
  419. The function removes multiple consecutive spaces between Chinese characters and ensures that
  420. only a single space is retained between Chinese and non-Chinese characters.
  421. Args:
  422. input_text (str): The text to be processed.
  423. Returns:
  424. str: The processed text with properly formatted spaces.
  425. """
  426. # Remove spaces between Chinese characters
  427. text_without_spaces = re.sub(
  428. r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", input_text
  429. )
  430. # Ensure single space between Chinese and non-Chinese characters
  431. text_with_single_spaces = re.sub(
  432. r"(?<=[\u4e00-\u9fff])\s+(?=[^\u4e00-\u9fff])|(?<=[^\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])",
  433. " ",
  434. text_without_spaces,
  435. )
  436. # Reduce any remaining consecutive spaces to a single space
  437. final_text = re.sub(r"\s+", " ", text_with_single_spaces).strip()
  438. return final_text
  439. def gather_imgs(original_img, layout_det_objs):
  440. imgs_in_doc = []
  441. for det_obj in layout_det_objs:
  442. if det_obj["label"] in ("image", "chart"):
  443. x_min, y_min, x_max, y_max = list(map(int, det_obj["coordinate"]))
  444. img_path = f"imgs/img_in_table_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg"
  445. img = Image.fromarray(original_img[y_min:y_max, x_min:x_max, ::-1])
  446. imgs_in_doc.append(
  447. {
  448. "path": img_path,
  449. "img": img,
  450. "coordinate": (x_min, y_min, x_max, y_max),
  451. "score": det_obj["score"],
  452. }
  453. )
  454. return imgs_in_doc
  455. def _get_minbox_if_overlap_by_ratio(
  456. bbox1: Union[List[int], Tuple[int, int, int, int]],
  457. bbox2: Union[List[int], Tuple[int, int, int, int]],
  458. ratio: float,
  459. smaller: bool = True,
  460. ) -> Optional[Union[List[int], Tuple[int, int, int, int]]]:
  461. """
  462. Determine if the overlap area between two bounding boxes exceeds a given ratio
  463. and return the smaller (or larger) bounding box based on the `smaller` flag.
  464. Args:
  465. bbox1 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  466. bbox2 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  467. ratio (float): The overlap ratio threshold.
  468. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  469. Returns:
  470. Optional[Union[List[int], Tuple[int, int, int, int]]]:
  471. The selected bounding box or None if the overlap ratio is not exceeded.
  472. """
  473. # Calculate the areas of both bounding boxes
  474. area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
  475. area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
  476. # Calculate the overlap ratio using a helper function
  477. overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small")
  478. # Check if the overlap ratio exceeds the threshold
  479. if overlap_ratio > ratio:
  480. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  481. return 1
  482. else:
  483. return 2
  484. return None
  485. def remove_overlap_blocks(
  486. blocks: List[Dict[str, List[int]]], threshold: float = 0.65, smaller: bool = True
  487. ) -> Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  488. """
  489. Remove overlapping blocks based on a specified overlap ratio threshold.
  490. Args:
  491. blocks (List[Dict[str, List[int]]]): List of block dictionaries, each containing a 'block_bbox' key.
  492. threshold (float): Ratio threshold to determine significant overlap.
  493. smaller (bool): If True, the smaller block in overlap is removed.
  494. Returns:
  495. Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  496. A tuple containing the updated list of blocks and a list of dropped blocks.
  497. """
  498. dropped_indexes = set()
  499. blocks = deepcopy(blocks)
  500. # Iterate over each pair of blocks to find overlaps
  501. for i, block1 in enumerate(blocks["boxes"]):
  502. for j in range(i + 1, len(blocks["boxes"])):
  503. block2 = blocks["boxes"][j]
  504. # Skip blocks that are already marked for removal
  505. if i in dropped_indexes or j in dropped_indexes:
  506. continue
  507. # Check for overlap and determine which block to remove
  508. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  509. block1["coordinate"],
  510. block2["coordinate"],
  511. threshold,
  512. smaller=smaller,
  513. )
  514. if overlap_box_index is not None:
  515. if block1["label"] == "image" and block2["label"] == "image":
  516. # Determine which block to remove based on overlap_box_index
  517. if overlap_box_index == 1:
  518. drop_index = i
  519. else:
  520. drop_index = j
  521. elif block1["label"] == "image" and block2["label"] != "image":
  522. drop_index = i
  523. elif block1["label"] != "image" and block2["label"] == "image":
  524. drop_index = j
  525. elif overlap_box_index == 1:
  526. drop_index = i
  527. else:
  528. drop_index = j
  529. dropped_indexes.add(drop_index)
  530. # Remove marked blocks from the original list
  531. for index in sorted(dropped_indexes, reverse=True):
  532. del blocks["boxes"][index]
  533. return blocks
  534. def get_bbox_intersection(bbox1, bbox2, return_format="bbox"):
  535. """
  536. Compute the intersection of two bounding boxes, supporting both 4-coordinate and 8-coordinate formats.
  537. Args:
  538. bbox1 (tuple): The first bounding box, either in 4-coordinate format (x_min, y_min, x_max, y_max)
  539. or 8-coordinate format (x1, y1, x2, y2, x3, y3, x4, y4).
  540. bbox2 (tuple): The second bounding box in the same format as bbox1.
  541. return_format (str): The format of the output intersection, either 'bbox' or 'poly'.
  542. Returns:
  543. tuple or None: The intersection bounding box in the specified format, or None if there is no intersection.
  544. """
  545. bbox1 = np.array(bbox1)
  546. bbox2 = np.array(bbox2)
  547. # Convert both bounding boxes to rectangles
  548. rect1 = bbox1 if len(bbox1.shape) == 1 else convert_points_to_boxes([bbox1])[0]
  549. rect2 = bbox2 if len(bbox2.shape) == 1 else convert_points_to_boxes([bbox2])[0]
  550. # Calculate the intersection rectangle
  551. x_min_inter = max(rect1[0], rect2[0])
  552. y_min_inter = max(rect1[1], rect2[1])
  553. x_max_inter = min(rect1[2], rect2[2])
  554. y_max_inter = min(rect1[3], rect2[3])
  555. # Check if there is an intersection
  556. if x_min_inter >= x_max_inter or y_min_inter >= y_max_inter:
  557. return None
  558. if return_format == "bbox":
  559. return np.array([x_min_inter, y_min_inter, x_max_inter, y_max_inter])
  560. elif return_format == "poly":
  561. return np.array(
  562. [
  563. [x_min_inter, y_min_inter],
  564. [x_max_inter, y_min_inter],
  565. [x_max_inter, y_max_inter],
  566. [x_min_inter, y_max_inter],
  567. ],
  568. dtype=np.int16,
  569. )
  570. else:
  571. raise ValueError("return_format must be either 'bbox' or 'poly'.")
def shrink_supplement_region_bbox(
    supplement_region_bbox,
    ref_region_bbox,
    image_width,
    image_height,
    block_idxes_set,
    block_bboxes,
    parameters_config,
) -> List:
    """
    Shrink the supplement region bbox against the reference region bbox and match blocks.

    Up to three attempts are made. In each attempt, one edge k of
    ``supplement_region_bbox`` is replaced by the opposite edge of
    ``ref_region_bbox``: ref[2] for k=0, ref[3] for k=1, ref[0] for k=2 and
    ref[1] for k=3. The first attempt uses the edge with the smallest normalized
    distance between the two boxes, and later attempts use the next-smallest
    distance.

    Each candidate block is scored against the shrunk box with
    ``calculate_overlap_ratio(..., mode="small")``:
    * ratio above ``match_block_overlap_ratio_threshold`` (default 0.8): an
      inner block;
    * ratio above ``split_block_overlap_ratio_threshold`` (default 0.4): a split
      block.

    As soon as an attempt yields at least one inner block, the region becomes the
    minimum enclosing bbox of those inner blocks and the loop stops.

    Args:
        supplement_region_bbox (list): The supplement region bbox [x1, y1, x2, y2].
        ref_region_bbox (list): The reference region bbox [x1, y1, x2, y2].
        image_width (int): The width of the image (normalizes x distances).
        image_height (int): The height of the image (normalizes y distances).
        block_idxes_set (iterable): Indexes (into ``block_bboxes``) of the candidate blocks.
        block_bboxes (dict or list): Block bboxes [x1, y1, x2, y2], indexable by block index.
        parameters_config (dict): Must contain a "region" dict; both thresholds
            above are read from it with ``.get``.

    Returns:
        tuple: ``(region_bbox, inner_block_idxes)``.
            * If ``block_idxes_set`` is empty: the input bbox and ``[]``.
            * If no attempt found an inner block: the input bbox and the (empty)
              inner list of the last attempt.
    """
    x1, y1, x2, y2 = supplement_region_bbox
    x1_prime, y1_prime, x2_prime, y2_prime = ref_region_bbox
    # Maps an edge index to the index of the opposite edge (left<->right, top<->bottom).
    index_conversion_map = {0: 2, 1: 3, 2: 0, 3: 1}
    edge_distance_list = [
        (x1_prime - x1) / image_width,
        (y1_prime - y1) / image_height,
        (x2 - x2_prime) / image_width,
        (y2 - y2_prime) / image_height,
    ]
    edge_distance_list_tmp = edge_distance_list[:]
    min_distance = min(edge_distance_list)
    # NOTE(review): .index() returns the first match, so on tied distances the
    # later tied edge is never tried.
    src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    if len(block_idxes_set) == 0:
        return supplement_region_bbox, []
    for _ in range(3):
        dst_index = index_conversion_map[src_index]
        # NOTE(review): [:] copies only when supplement_region_bbox is a list; for a
        # NumPy array it is a view, and the assignment below would also modify
        # supplement_region_bbox. Confirm callers pass lists.
        tmp_region_bbox = supplement_region_bbox[:]
        tmp_region_bbox[dst_index] = ref_region_bbox[src_index]
        iner_block_idxes, split_block_idxes = [], []
        for block_idx in block_idxes_set:
            overlap_ratio = calculate_overlap_ratio(
                tmp_region_bbox, block_bboxes[block_idx], mode="small"
            )
            if overlap_ratio > parameters_config["region"].get(
                "match_block_overlap_ratio_threshold", 0.8
            ):
                iner_block_idxes.append(block_idx)
            elif overlap_ratio > parameters_config["region"].get(
                "split_block_overlap_ratio_threshold", 0.4
            ):
                split_block_idxes.append(block_idx)

        if len(iner_block_idxes) > 0:
            if len(split_block_idxes) > 0:
                for split_block_idx in split_block_idxes:
                    split_block_bbox = block_bboxes[split_block_idx]
                    # Re-binds x1..y2_prime to the current box / split block.
                    x1, y1, x2, y2 = tmp_region_bbox
                    x1_prime, y1_prime, x2_prime, y2_prime = split_block_bbox
                    edge_distance_list = [
                        (x1_prime - x1) / image_width,
                        (y1_prime - y1) / image_height,
                        (x2 - x2_prime) / image_width,
                        (y2 - y2_prime) / image_height,
                    ]
                    max_distance = max(edge_distance_list)
                    src_index = edge_distance_list.index(max_distance)
                    dst_index = index_conversion_map[src_index]
                    # Move the edge farthest from the split block onto the split
                    # block's opposite edge, then recurse on the inner blocks.
                    tmp_region_bbox[dst_index] = split_block_bbox[src_index]
                    # The recursion's result only feeds tmp_region_bbox for the next
                    # split block; the region returned below is built from this
                    # level's iner_block_idxes.
                    # NOTE(review): termination relies on the recursion eventually
                    # finding no split blocks — confirm.
                    tmp_region_bbox, iner_idxes = shrink_supplement_region_bbox(
                        tmp_region_bbox,
                        ref_region_bbox,
                        image_width,
                        image_height,
                        iner_block_idxes,
                        block_bboxes,
                        parameters_config,
                    )
                    # NOTE(review): indentation was lost in this copy of the file; here
                    # `continue` is placed as the last statement of the split-block loop,
                    # where it is a no-op. Confirm against upstream.
                    if len(iner_idxes) == 0:
                        continue
            matched_bboxes = [block_bboxes[idx] for idx in iner_block_idxes]
            supplement_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes)
            break
        else:
            # Try the next-smallest edge distance. Every value equal to the current
            # minimum is removed.
            # NOTE(review): with tied distances this list can become empty, and
            # min() then raises ValueError. On the third attempt the new src_index
            # computed here is never used.
            edge_distance_list_tmp = [
                x for x in edge_distance_list_tmp if x != min_distance
            ]
            min_distance = min(edge_distance_list_tmp)
            src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    return supplement_region_bbox, iner_block_idxes
  662. def update_region_box(bbox, region_box):
  663. if region_box is None:
  664. return bbox
  665. x1, y1, x2, y2 = bbox
  666. x1_region, y1_region, x2_region, y2_region = region_box
  667. x1_region = int(min(x1, x1_region))
  668. y1_region = int(min(y1, y1_region))
  669. x2_region = int(max(x2, x2_region))
  670. y2_region = int(max(y2, y2_region))
  671. region_box = [x1_region, y1_region, x2_region, y2_region]
  672. return region_box
  673. def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
  674. for formula_res in formula_res_list:
  675. x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
  676. poly_points = [
  677. (x_min, y_min),
  678. (x_max, y_min),
  679. (x_max, y_max),
  680. (x_min, y_max),
  681. ]
  682. ocr_res["dt_polys"].append(poly_points)
  683. ocr_res["rec_texts"].append(f"{formula_res['rec_formula']}")
  684. ocr_res["rec_boxes"] = np.vstack(
  685. (ocr_res["rec_boxes"], [formula_res["dt_polys"]])
  686. )
  687. ocr_res["rec_labels"].append("formula")
  688. ocr_res["rec_polys"].append(poly_points)
  689. ocr_res["rec_scores"].append(1)
  690. def caculate_bbox_area(bbox):
  691. x1, y1, x2, y2 = bbox
  692. area = abs((x2 - x1) * (y2 - y1))
  693. return area
  694. def get_show_color(label: str) -> Tuple:
  695. label_colors = {
  696. # Medium Blue (from 'titles_list')
  697. "paragraph_title": (102, 102, 255, 100),
  698. "doc_title": (255, 248, 220, 100), # Cornsilk
  699. # Light Yellow (from 'tables_caption_list')
  700. "table_title": (255, 255, 102, 100),
  701. # Sky Blue (from 'imgs_caption_list')
  702. "figure_title": (102, 178, 255, 100),
  703. "chart_title": (221, 160, 221, 100), # Plum
  704. "vision_footnote": (144, 238, 144, 100), # Light Green
  705. # Deep Purple (from 'texts_list')
  706. "text": (153, 0, 76, 100),
  707. # Bright Green (from 'interequations_list')
  708. "formula": (0, 255, 0, 100),
  709. "abstract": (255, 239, 213, 100), # Papaya Whip
  710. # Medium Green (from 'lists_list' and 'indexs_list')
  711. "content": (40, 169, 92, 100),
  712. # Neutral Gray (from 'dropped_bbox_list')
  713. "seal": (158, 158, 158, 100),
  714. # Olive Yellow (from 'tables_body_list')
  715. "table": (204, 204, 0, 100),
  716. # Bright Green (from 'imgs_body_list')
  717. "image": (153, 255, 51, 100),
  718. # Bright Green (from 'imgs_body_list')
  719. "figure": (153, 255, 51, 100),
  720. "chart": (216, 191, 216, 100), # Thistle
  721. # Pale Yellow-Green (from 'tables_footnote_list')
  722. "reference": (229, 255, 204, 100),
  723. "algorithm": (255, 250, 240, 100), # Floral White
  724. }
  725. default_color = (158, 158, 158, 100)
  726. return label_colors.get(label, default_color)