# utils.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Names exported when this module is star-imported.
__all__ = [
    "get_sub_regions_ocr_res",
    "get_show_color",
    "sorted_layout_boxes",
]
import re
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from PIL import Image

from ..components import convert_points_to_boxes
from ..ocr.result import OCRResult
from .setting import REGION_SETTINGS
  27. def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
  28. """
  29. Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
  30. Args:
  31. src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
  32. ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
  33. Returns:
  34. match_idx_list (list): A list of indices of source boxes that overlap with reference boxes.
  35. """
  36. match_idx_list = []
  37. src_boxes_num = len(src_boxes)
  38. if src_boxes_num > 0 and len(ref_boxes) > 0:
  39. for rno in range(len(ref_boxes)):
  40. ref_box = ref_boxes[rno]
  41. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  42. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  43. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  44. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  45. pub_w = x2 - x1
  46. pub_h = y2 - y1
  47. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  48. match_idx_list.extend(match_idx)
  49. return match_idx_list
  50. def get_sub_regions_ocr_res(
  51. overall_ocr_res: OCRResult,
  52. object_boxes: List,
  53. flag_within: bool = True,
  54. return_match_idx: bool = False,
  55. ) -> OCRResult:
  56. """
  57. Filters OCR results to only include text boxes within specified object boxes based on a flag.
  58. Args:
  59. overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
  60. object_boxes (list): A list of bounding boxes for the objects of interest.
  61. flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
  62. return_match_idx (bool): If True, return the list of matching indices.
  63. Returns:
  64. OCRResult: A filtered OCR result containing only the relevant text boxes.
  65. """
  66. sub_regions_ocr_res = {}
  67. sub_regions_ocr_res["rec_polys"] = []
  68. sub_regions_ocr_res["rec_texts"] = []
  69. sub_regions_ocr_res["rec_scores"] = []
  70. sub_regions_ocr_res["rec_boxes"] = []
  71. overall_text_boxes = overall_ocr_res["rec_boxes"]
  72. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  73. match_idx_list = list(set(match_idx_list))
  74. for box_no in range(len(overall_text_boxes)):
  75. if flag_within:
  76. if box_no in match_idx_list:
  77. flag_match = True
  78. else:
  79. flag_match = False
  80. else:
  81. if box_no not in match_idx_list:
  82. flag_match = True
  83. else:
  84. flag_match = False
  85. if flag_match:
  86. sub_regions_ocr_res["rec_polys"].append(
  87. overall_ocr_res["rec_polys"][box_no]
  88. )
  89. sub_regions_ocr_res["rec_texts"].append(
  90. overall_ocr_res["rec_texts"][box_no]
  91. )
  92. sub_regions_ocr_res["rec_scores"].append(
  93. overall_ocr_res["rec_scores"][box_no]
  94. )
  95. sub_regions_ocr_res["rec_boxes"].append(
  96. overall_ocr_res["rec_boxes"][box_no]
  97. )
  98. for key in ["rec_polys", "rec_scores", "rec_boxes"]:
  99. sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])
  100. return (
  101. (sub_regions_ocr_res, match_idx_list)
  102. if return_match_idx
  103. else sub_regions_ocr_res
  104. )
  105. def sorted_layout_boxes(res, w):
  106. """
  107. Sort text boxes in order from top to bottom, left to right
  108. Args:
  109. res: List of dictionaries containing layout information.
  110. w: Width of image.
  111. Returns:
  112. List of dictionaries containing sorted layout information.
  113. """
  114. num_boxes = len(res)
  115. if num_boxes == 1:
  116. return res
  117. # Sort on the y axis first or sort it on the x axis
  118. sorted_boxes = sorted(res, key=lambda x: (x["block_bbox"][1], x["block_bbox"][0]))
  119. _boxes = list(sorted_boxes)
  120. new_res = []
  121. res_left = []
  122. res_right = []
  123. i = 0
  124. while True:
  125. if i >= num_boxes:
  126. break
  127. # Check that the bbox is on the left
  128. elif (
  129. _boxes[i]["block_bbox"][0] < w / 4
  130. and _boxes[i]["block_bbox"][2] < 3 * w / 5
  131. ):
  132. res_left.append(_boxes[i])
  133. i += 1
  134. elif _boxes[i]["block_bbox"][0] > 2 * w / 5:
  135. res_right.append(_boxes[i])
  136. i += 1
  137. else:
  138. new_res += res_left
  139. new_res += res_right
  140. new_res.append(_boxes[i])
  141. res_left = []
  142. res_right = []
  143. i += 1
  144. res_left = sorted(res_left, key=lambda x: (x["block_bbox"][1]))
  145. res_right = sorted(res_right, key=lambda x: (x["block_bbox"][1]))
  146. if res_left:
  147. new_res += res_left
  148. if res_right:
  149. new_res += res_right
  150. return new_res
  151. def calculate_projection_overlap_ratio(
  152. bbox1: List[float],
  153. bbox2: List[float],
  154. direction: str = "horizontal",
  155. mode="union",
  156. ) -> float:
  157. """
  158. Calculate the IoU of lines between two bounding boxes.
  159. Args:
  160. bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
  161. bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
  162. direction (str): direction of the projection, "horizontal" or "vertical".
  163. Returns:
  164. float: Line overlap ratio. Returns 0 if there is no overlap.
  165. """
  166. start_index, end_index = 1, 3
  167. if direction == "horizontal":
  168. start_index, end_index = 0, 2
  169. intersection_start = max(bbox1[start_index], bbox2[start_index])
  170. intersection_end = min(bbox1[end_index], bbox2[end_index])
  171. overlap = intersection_end - intersection_start
  172. if overlap <= 0:
  173. return 0
  174. if mode == "union":
  175. ref_width = max(bbox1[end_index], bbox2[end_index]) - min(
  176. bbox1[start_index], bbox2[start_index]
  177. )
  178. elif mode == "small":
  179. ref_width = min(
  180. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  181. )
  182. elif mode == "large":
  183. ref_width = max(
  184. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  185. )
  186. else:
  187. raise ValueError(
  188. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  189. )
  190. return overlap / ref_width if ref_width > 0 else 0.0
  191. def calculate_overlap_ratio(
  192. bbox1: Union[list, tuple], bbox2: Union[list, tuple], mode="union"
  193. ) -> float:
  194. """
  195. Calculate the overlap ratio between two bounding boxes.
  196. Args:
  197. bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max]
  198. bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max]
  199. mode (str): The mode of calculation, either 'union', 'small', or 'large'.
  200. Returns:
  201. float: The overlap ratio value between the two bounding boxes
  202. """
  203. x_min_inter = max(bbox1[0], bbox2[0])
  204. y_min_inter = max(bbox1[1], bbox2[1])
  205. x_max_inter = min(bbox1[2], bbox2[2])
  206. y_max_inter = min(bbox1[3], bbox2[3])
  207. inter_width = max(0, x_max_inter - x_min_inter)
  208. inter_height = max(0, y_max_inter - y_min_inter)
  209. inter_area = inter_width * inter_height
  210. bbox1_area = caculate_bbox_area(bbox1)
  211. bbox2_area = caculate_bbox_area(bbox2)
  212. if mode == "union":
  213. ref_area = bbox1_area + bbox2_area - inter_area
  214. elif mode == "small":
  215. ref_area = min(bbox1_area, bbox2_area)
  216. elif mode == "large":
  217. ref_area = max(bbox1_area, bbox2_area)
  218. else:
  219. raise ValueError(
  220. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  221. )
  222. if ref_area == 0:
  223. return 0.0
  224. return inter_area / ref_area
  225. def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
  226. rec_boxes = ocr_rec_res["boxes"]
  227. rec_texts = ocr_rec_res["rec_texts"]
  228. rec_labels = ocr_rec_res["rec_labels"]
  229. text_boxes = [
  230. rec_boxes[i] for i in range(len(rec_boxes)) if rec_labels[i] == "text"
  231. ]
  232. text_orientation = calculate_text_orientation(text_boxes)
  233. match_direction = "vertical" if text_orientation == "horizontal" else "horizontal"
  234. spans = list(zip(rec_boxes, rec_texts, rec_labels))
  235. sort_index = 1
  236. reverse = False
  237. if text_orientation == "vertical":
  238. sort_index = 0
  239. reverse = True
  240. spans.sort(key=lambda span: span[0][sort_index], reverse=reverse)
  241. spans = [list(span) for span in spans]
  242. lines = []
  243. line = [spans[0]]
  244. line_region_box = spans[0][0].copy()
  245. # merge line
  246. for span in spans[1:]:
  247. rec_bbox = span[0]
  248. if (
  249. calculate_projection_overlap_ratio(
  250. line_region_box, rec_bbox, match_direction, mode="small"
  251. )
  252. >= line_height_iou_threshold
  253. ):
  254. line.append(span)
  255. line_region_box[1] = min(line_region_box[1], rec_bbox[1])
  256. line_region_box[3] = max(line_region_box[3], rec_bbox[3])
  257. else:
  258. lines.append(line)
  259. line = [span]
  260. line_region_box = rec_bbox.copy()
  261. lines.append(line)
  262. return lines, text_orientation
  263. def calculate_minimum_enclosing_bbox(bboxes):
  264. """
  265. Calculate the minimum enclosing bounding box for a list of bounding boxes.
  266. Args:
  267. bboxes (list): A list of bounding boxes represented as lists of four integers [x1, y1, x2, y2].
  268. Returns:
  269. list: The minimum enclosing bounding box represented as a list of four integers [x1, y1, x2, y2].
  270. """
  271. if not bboxes:
  272. raise ValueError("The list of bounding boxes is empty.")
  273. # Convert the list of bounding boxes to a NumPy array
  274. bboxes_array = np.array(bboxes)
  275. # Compute the minimum and maximum values along the respective axes
  276. min_x = np.min(bboxes_array[:, 0])
  277. min_y = np.min(bboxes_array[:, 1])
  278. max_x = np.max(bboxes_array[:, 2])
  279. max_y = np.max(bboxes_array[:, 3])
  280. # Return the minimum enclosing bounding box
  281. return [min_x, min_y, max_x, max_y]
  282. def calculate_text_orientation(
  283. bboxes: List[List[int]], orientation_ratio: float = 1.5
  284. ) -> bool:
  285. """
  286. Calculate the orientation of the text based on the bounding boxes.
  287. Args:
  288. bboxes (list): A list of bounding boxes.
  289. orientation_ratio (float): Ratio for determining orientation. Default is 1.5.
  290. Returns:
  291. str: "horizontal" or "vertical".
  292. """
  293. horizontal_box_num = 0
  294. for bbox in bboxes:
  295. if len(bbox) != 4:
  296. raise ValueError(
  297. "Invalid bounding box format. Expected a list of length 4."
  298. )
  299. x1, y1, x2, y2 = bbox
  300. width = x2 - x1
  301. height = y2 - y1
  302. horizontal_box_num += 1 if width * orientation_ratio >= height else 0
  303. return "horizontal" if horizontal_box_num >= len(bboxes) * 0.5 else "vertical"
  304. def is_english_letter(char):
  305. return bool(re.match(r"^[A-Za-z]$", char))
  306. def is_non_breaking_punctuation(char):
  307. """
  308. 判断一个字符是否是不需要换行的标点符号,包括全角和半角的符号。
  309. :param char: str, 单个字符
  310. :return: bool, 如果字符是不需要换行的标点符号,返回True,否则返回False
  311. """
  312. non_breaking_punctuations = {
  313. ",", # 半角逗号
  314. ",", # 全角逗号
  315. "、", # 顿号
  316. ";", # 半角分号
  317. ";", # 全角分号
  318. ":", # 半角冒号
  319. ":", # 全角冒号
  320. }
  321. return char in non_breaking_punctuations
  322. def format_line(
  323. line: List[List[Union[List[int], str]]],
  324. block_right_coordinate: int,
  325. last_line_span_limit: int = 10,
  326. block_label: str = "text",
  327. ) -> None:
  328. """
  329. Format a line of text spans based on layout constraints.
  330. Args:
  331. line (list): A list of spans, where each span is a list containing a bounding box and text.
  332. block_left_coordinate (int): The minimum x-coordinate of the layout bounding box.
  333. block_right_coordinate (int): The maximum x-coordinate of the layout bounding box.
  334. first_line_span_limit (int): The limit for the number of pixels before the first span that should be considered part of the first line. Default is 10.
  335. last_line_span_limit (int): The limit for the number of pixels after the last span that should be considered part of the last line. Default is 10.
  336. block_label (str): The label associated with the entire block. Default is 'text'.
  337. Returns:
  338. None: The function modifies the line in place.
  339. """
  340. last_span_box = line[-1][0]
  341. for span in line:
  342. if span[2] == "formula" and block_label != "formula":
  343. if len(line) > 1:
  344. span[1] = f"${span[1]}$"
  345. else:
  346. span[1] = f"\n${span[1]}$"
  347. line_text = " ".join([span[1] for span in line])
  348. need_new_line = False
  349. if (
  350. block_right_coordinate - last_span_box[2] > last_line_span_limit
  351. and not line_text.endswith("-")
  352. and len(line_text) > 0
  353. and not is_english_letter(line_text[-1])
  354. and not is_non_breaking_punctuation(line_text[-1])
  355. ):
  356. need_new_line = True
  357. if line_text.endswith("-"):
  358. line_text = line_text[:-1]
  359. elif (
  360. len(line_text) > 0 and is_english_letter(line_text[-1])
  361. ) or line_text.endswith("$"):
  362. line_text += " "
  363. return line_text, need_new_line
def split_boxes_by_projection(spans: List[List[int]], direction, offset=1e-5):
    """
    Check if there is any complete containment along the projection axis
    between the bounding boxes and split the containing box accordingly.

    Args:
        spans (list of lists): Each element is a list containing an ndarray of length 4, a text string, and a label.
        direction: 'horizontal' or 'vertical', indicating whether the spans are arranged horizontally or vertically.
            'horizontal' compares x-extents (indices 0 and 2); any other value
            compares y-extents (indices 1 and 3).
        offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.

    Returns:
        A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.

    Note:
        The box of a containing span can be modified in place (its start
        coordinate is moved past the contained box), so arrays inside `spans`
        may be changed by this call.
    """

    def is_projection_contained(box_a, box_b, start_idx, end_idx):
        """Check if box_a completely contains box_b along the [start_idx, end_idx] axis."""
        return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx]

    new_boxes = []
    if direction == "horizontal":
        projection_start_index, projection_end_index = 0, 2
    else:
        projection_start_index, projection_end_index = 1, 3

    for i in range(len(spans)):
        span = spans[i]
        is_split = False
        # NOTE(review): j starts at i, so the span is first compared with its
        # own box; with the inclusive comparison above that appears to always
        # set is_split -- confirm this is intended.
        for j in range(i, len(spans)):
            box_b = spans[j][0]
            # Re-unpacked every iteration because `span` may have been replaced
            # by its remaining right/bottom part below.
            box_a, text, label = span
            if is_projection_contained(
                box_a, box_b, projection_start_index, projection_end_index
            ):
                is_split = True
                # Split box_a based on the projection coordinates of box_b
                if box_a[projection_start_index] < box_b[projection_start_index]:
                    # Part of box_a lying before box_b; kept only if wider than 1.
                    w = (
                        box_b[projection_start_index]
                        - offset
                        - box_a[projection_start_index]
                    )
                    if w > 1:
                        new_bbox = box_a.copy()
                        new_bbox[projection_end_index] = (
                            box_b[projection_start_index] - offset
                        )
                        new_boxes.append(
                            [
                                np.array(new_bbox),
                                text,
                                label,
                            ]
                        )
                if box_a[projection_end_index] > box_b[projection_end_index]:
                    # Part of box_a lying after box_b; becomes the new `span`.
                    w = (
                        box_a[projection_end_index]
                        - box_b[projection_end_index]
                        + offset
                    )
                    if w > 1:
                        # In-place write: this is the same array object that
                        # `span[0]` refers to.
                        box_a[projection_start_index] = (
                            box_b[projection_end_index] + offset
                        )
                        span = [
                            np.array(box_a),
                            text,
                            label,
                        ]
            # After the last comparison, emit whatever remains of the span.
            if j == len(spans) - 1 and is_split:
                new_boxes.append(span)
        if not is_split:
            new_boxes.append(span)

    return new_boxes
  432. def remove_extra_space(input_text: str) -> str:
  433. """
  434. Process the input text to handle spaces.
  435. The function removes multiple consecutive spaces between Chinese characters and ensures that
  436. only a single space is retained between Chinese and non-Chinese characters.
  437. Args:
  438. input_text (str): The text to be processed.
  439. Returns:
  440. str: The processed text with properly formatted spaces.
  441. """
  442. # Remove spaces between Chinese characters
  443. text_without_spaces = re.sub(
  444. r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", input_text
  445. )
  446. # Ensure single space between Chinese and non-Chinese characters
  447. text_with_single_spaces = re.sub(
  448. r"(?<=[\u4e00-\u9fff])\s+(?=[^\u4e00-\u9fff])|(?<=[^\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])",
  449. " ",
  450. text_without_spaces,
  451. )
  452. # Reduce any remaining consecutive spaces to a single space
  453. final_text = re.sub(r"\s+", " ", text_with_single_spaces).strip()
  454. return final_text
  455. def gather_imgs(original_img, layout_det_objs):
  456. imgs_in_doc = []
  457. for det_obj in layout_det_objs:
  458. if det_obj["label"] in ("image", "chart"):
  459. x_min, y_min, x_max, y_max = list(map(int, det_obj["coordinate"]))
  460. img_path = f"imgs/img_in_table_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg"
  461. img = Image.fromarray(original_img[y_min:y_max, x_min:x_max, ::-1])
  462. imgs_in_doc.append(
  463. {
  464. "path": img_path,
  465. "img": img,
  466. "coordinate": (x_min, y_min, x_max, y_max),
  467. "score": det_obj["score"],
  468. }
  469. )
  470. return imgs_in_doc
  471. def _get_minbox_if_overlap_by_ratio(
  472. bbox1: Union[List[int], Tuple[int, int, int, int]],
  473. bbox2: Union[List[int], Tuple[int, int, int, int]],
  474. ratio: float,
  475. smaller: bool = True,
  476. ) -> Optional[Union[List[int], Tuple[int, int, int, int]]]:
  477. """
  478. Determine if the overlap area between two bounding boxes exceeds a given ratio
  479. and return the smaller (or larger) bounding box based on the `smaller` flag.
  480. Args:
  481. bbox1 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  482. bbox2 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  483. ratio (float): The overlap ratio threshold.
  484. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  485. Returns:
  486. Optional[Union[List[int], Tuple[int, int, int, int]]]:
  487. The selected bounding box or None if the overlap ratio is not exceeded.
  488. """
  489. # Calculate the areas of both bounding boxes
  490. area1 = caculate_bbox_area(bbox1)
  491. area2 = caculate_bbox_area(bbox2)
  492. # Calculate the overlap ratio using a helper function
  493. overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small")
  494. # Check if the overlap ratio exceeds the threshold
  495. if overlap_ratio > ratio:
  496. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  497. return 1
  498. else:
  499. return 2
  500. return None
  501. def remove_overlap_blocks(
  502. blocks: List[Dict[str, List[int]]], threshold: float = 0.65, smaller: bool = True
  503. ) -> Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  504. """
  505. Remove overlapping blocks based on a specified overlap ratio threshold.
  506. Args:
  507. blocks (List[Dict[str, List[int]]]): List of block dictionaries, each containing a 'block_bbox' key.
  508. threshold (float): Ratio threshold to determine significant overlap.
  509. smaller (bool): If True, the smaller block in overlap is removed.
  510. Returns:
  511. Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  512. A tuple containing the updated list of blocks and a list of dropped blocks.
  513. """
  514. dropped_indexes = set()
  515. blocks = deepcopy(blocks)
  516. # Iterate over each pair of blocks to find overlaps
  517. for i, block1 in enumerate(blocks["boxes"]):
  518. for j in range(i + 1, len(blocks["boxes"])):
  519. block2 = blocks["boxes"][j]
  520. # Skip blocks that are already marked for removal
  521. if i in dropped_indexes or j in dropped_indexes:
  522. continue
  523. # Check for overlap and determine which block to remove
  524. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  525. block1["coordinate"],
  526. block2["coordinate"],
  527. threshold,
  528. smaller=smaller,
  529. )
  530. if overlap_box_index is not None:
  531. if block1["label"] == "image" and block2["label"] == "image":
  532. # Determine which block to remove based on overlap_box_index
  533. if overlap_box_index == 1:
  534. drop_index = i
  535. else:
  536. drop_index = j
  537. elif block1["label"] == "image" and block2["label"] != "image":
  538. drop_index = i
  539. elif block1["label"] != "image" and block2["label"] == "image":
  540. drop_index = j
  541. elif overlap_box_index == 1:
  542. drop_index = i
  543. else:
  544. drop_index = j
  545. dropped_indexes.add(drop_index)
  546. # Remove marked blocks from the original list
  547. for index in sorted(dropped_indexes, reverse=True):
  548. del blocks["boxes"][index]
  549. return blocks
  550. def get_bbox_intersection(bbox1, bbox2, return_format="bbox"):
  551. """
  552. Compute the intersection of two bounding boxes, supporting both 4-coordinate and 8-coordinate formats.
  553. Args:
  554. bbox1 (tuple): The first bounding box, either in 4-coordinate format (x_min, y_min, x_max, y_max)
  555. or 8-coordinate format (x1, y1, x2, y2, x3, y3, x4, y4).
  556. bbox2 (tuple): The second bounding box in the same format as bbox1.
  557. return_format (str): The format of the output intersection, either 'bbox' or 'poly'.
  558. Returns:
  559. tuple or None: The intersection bounding box in the specified format, or None if there is no intersection.
  560. """
  561. bbox1 = np.array(bbox1)
  562. bbox2 = np.array(bbox2)
  563. # Convert both bounding boxes to rectangles
  564. rect1 = bbox1 if len(bbox1.shape) == 1 else convert_points_to_boxes([bbox1])[0]
  565. rect2 = bbox2 if len(bbox2.shape) == 1 else convert_points_to_boxes([bbox2])[0]
  566. # Calculate the intersection rectangle
  567. x_min_inter = max(rect1[0], rect2[0])
  568. y_min_inter = max(rect1[1], rect2[1])
  569. x_max_inter = min(rect1[2], rect2[2])
  570. y_max_inter = min(rect1[3], rect2[3])
  571. # Check if there is an intersection
  572. if x_min_inter >= x_max_inter or y_min_inter >= y_max_inter:
  573. return None
  574. if return_format == "bbox":
  575. return np.array([x_min_inter, y_min_inter, x_max_inter, y_max_inter])
  576. elif return_format == "poly":
  577. return np.array(
  578. [
  579. [x_min_inter, y_min_inter],
  580. [x_max_inter, y_min_inter],
  581. [x_max_inter, y_max_inter],
  582. [x_min_inter, y_max_inter],
  583. ],
  584. dtype=np.int16,
  585. )
  586. else:
  587. raise ValueError("return_format must be either 'bbox' or 'poly'.")
def shrink_supplement_region_bbox(
    supplement_region_bbox,
    ref_region_bbox,
    image_width,
    image_height,
    block_idxes_set,
    block_bboxes,
) -> Tuple[List, List]:
    """
    Shrink the supplement region bbox according to the reference region bbox and match the block bboxes.

    Up to three candidate cuts are tried. Each candidate moves one edge of the
    supplement region onto the opposite edge of the reference region, then
    classifies the given blocks by their overlap with the candidate region
    (mode="small"): above "match_block_overlap_ratio_threshold" (default 0.8)
    the block is inside; above "split_block_overlap_ratio_threshold"
    (default 0.4) the block is split by the cut. The first candidate with at
    least one inside block wins, and the result is the minimum enclosing box of
    those inside blocks.

    Args:
        supplement_region_bbox (list): The supplement region bbox.
        ref_region_bbox (list): The reference region bbox.
        image_width (int): The width of the image.
        image_height (int): The height of the image.
        block_idxes_set (set): The indexes of the blocks that intersect with the region bbox.
        block_bboxes (dict): The dictionary of block bboxes.

    Returns:
        tuple: The new region bbox and the matched block idxes. When
            `block_idxes_set` is empty, the input bbox and an empty list.
    """
    x1, y1, x2, y2 = supplement_region_bbox
    x1_prime, y1_prime, x2_prime, y2_prime = ref_region_bbox
    # Maps an edge index to the opposite edge: left <-> right, top <-> bottom.
    index_conversion_map = {0: 2, 1: 3, 2: 0, 3: 1}
    # Per-edge offset of the reference region from the supplement region,
    # normalized by image size; order is left, top, right, bottom.
    edge_distance_list = [
        (x1_prime - x1) / image_width,
        (y1_prime - y1) / image_height,
        (x2 - x2_prime) / image_width,
        (y2 - y2_prime) / image_height,
    ]
    edge_distance_list_tmp = edge_distance_list[:]
    min_distance = min(edge_distance_list)
    src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    if len(block_idxes_set) == 0:
        return supplement_region_bbox, []
    for _ in range(3):
        dst_index = index_conversion_map[src_index]
        # NOTE(review): `[:]` copies a list but yields a view of an ndarray, in
        # which case the assignment below would write through to the caller's
        # array -- confirm the input is always a list.
        tmp_region_bbox = supplement_region_bbox[:]
        # Move the chosen supplement edge onto the reference region's opposite edge.
        tmp_region_bbox[dst_index] = ref_region_bbox[src_index]
        iner_block_idxes, split_block_idxes = [], []
        for block_idx in block_idxes_set:
            overlap_ratio = calculate_overlap_ratio(
                tmp_region_bbox, block_bboxes[block_idx], mode="small"
            )
            if overlap_ratio > REGION_SETTINGS.get(
                "match_block_overlap_ratio_threshold", 0.8
            ):
                iner_block_idxes.append(block_idx)
            elif overlap_ratio > REGION_SETTINGS.get(
                "split_block_overlap_ratio_threshold", 0.4
            ):
                split_block_idxes.append(block_idx)

        if len(iner_block_idxes) > 0:
            if len(split_block_idxes) > 0:
                for split_block_idx in split_block_idxes:
                    split_block_bbox = block_bboxes[split_block_idx]
                    x1, y1, x2, y2 = tmp_region_bbox
                    x1_prime, y1_prime, x2_prime, y2_prime = split_block_bbox
                    edge_distance_list = [
                        (x1_prime - x1) / image_width,
                        (y1_prime - y1) / image_height,
                        (x2 - x2_prime) / image_width,
                        (y2 - y2_prime) / image_height,
                    ]
                    # Cut the candidate region at the split block, on the side
                    # with the largest normalized offset.
                    max_distance = max(edge_distance_list)
                    src_index = edge_distance_list.index(max_distance)
                    dst_index = index_conversion_map[src_index]
                    tmp_region_bbox[dst_index] = split_block_bbox[src_index]
                    # Re-run the shrink on the cut region using only the inside blocks.
                    tmp_region_bbox, iner_idxes = shrink_supplement_region_bbox(
                        tmp_region_bbox,
                        ref_region_bbox,
                        image_width,
                        image_height,
                        iner_block_idxes,
                        block_bboxes,
                    )
                    # NOTE(review): as the last statement of this loop body the
                    # `continue` has no effect; confirm the intended nesting of
                    # this check and of the recursive call above.
                    if len(iner_idxes) == 0:
                        continue
            matched_bboxes = [block_bboxes[idx] for idx in iner_block_idxes]
            supplement_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes)
            break
        else:
            # No block fits this candidate: discard the current minimum and
            # retry with the edge that has the next-smallest distance.
            # NOTE(review): if every remaining distance equals `min_distance`
            # the filtered list is empty and `min()` raises ValueError -- confirm.
            edge_distance_list_tmp = [
                x for x in edge_distance_list_tmp if x != min_distance
            ]
            min_distance = min(edge_distance_list_tmp)
            src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    return supplement_region_bbox, iner_block_idxes
  675. def update_region_box(bbox, region_box):
  676. if region_box is None:
  677. return bbox
  678. x1, y1, x2, y2 = bbox
  679. x1_region, y1_region, x2_region, y2_region = region_box
  680. x1_region = int(min(x1, x1_region))
  681. y1_region = int(min(y1, y1_region))
  682. x2_region = int(max(x2, x2_region))
  683. y2_region = int(max(y2, y2_region))
  684. region_box = [x1_region, y1_region, x2_region, y2_region]
  685. return region_box
  686. def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
  687. for formula_res in formula_res_list:
  688. x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
  689. poly_points = [
  690. (x_min, y_min),
  691. (x_max, y_min),
  692. (x_max, y_max),
  693. (x_min, y_max),
  694. ]
  695. ocr_res["dt_polys"].append(poly_points)
  696. ocr_res["rec_texts"].append(f"{formula_res['rec_formula']}")
  697. if ocr_res["rec_boxes"].size == 0:
  698. ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"])
  699. else:
  700. ocr_res["rec_boxes"] = np.vstack(
  701. (ocr_res["rec_boxes"], [formula_res["dt_polys"]])
  702. )
  703. ocr_res["rec_labels"].append("formula")
  704. ocr_res["rec_polys"].append(poly_points)
  705. ocr_res["rec_scores"].append(1)
  706. def caculate_bbox_area(bbox):
  707. x1, y1, x2, y2 = map(float, bbox)
  708. area = abs((x2 - x1) * (y2 - y1))
  709. return area
  710. def get_show_color(label: str, order_label=False) -> Tuple:
  711. if order_label:
  712. label_colors = {
  713. "doc_title": (255, 248, 220, 100), # Cornsilk
  714. "doc_title_text": (255, 239, 213, 100),
  715. "paragraph_title": (102, 102, 255, 100),
  716. "sub_paragraph_title": (102, 178, 255, 100),
  717. "vision": (153, 255, 51, 100),
  718. "vision_title": (144, 238, 144, 100), # Light Green
  719. "vision_footnote": (144, 238, 144, 100), # Light Green
  720. "normal_text": (153, 0, 76, 100),
  721. "cross_layout": (53, 218, 207, 100), # Thistle
  722. "cross_reference": (221, 160, 221, 100), # Floral White
  723. }
  724. else:
  725. label_colors = {
  726. # Medium Blue (from 'titles_list')
  727. "paragraph_title": (102, 102, 255, 100),
  728. "doc_title": (255, 248, 220, 100), # Cornsilk
  729. # Light Yellow (from 'tables_caption_list')
  730. "table_title": (255, 255, 102, 100),
  731. # Sky Blue (from 'imgs_caption_list')
  732. "figure_title": (102, 178, 255, 100),
  733. "chart_title": (221, 160, 221, 100), # Plum
  734. "vision_footnote": (144, 238, 144, 100), # Light Green
  735. # Deep Purple (from 'texts_list')
  736. "text": (153, 0, 76, 100),
  737. # Bright Green (from 'interequations_list')
  738. "formula": (0, 255, 0, 100),
  739. "abstract": (255, 239, 213, 100), # Papaya Whip
  740. # Medium Green (from 'lists_list' and 'indexs_list')
  741. "content": (40, 169, 92, 100),
  742. # Neutral Gray (from 'dropped_bbox_list')
  743. "seal": (158, 158, 158, 100),
  744. # Olive Yellow (from 'tables_body_list')
  745. "table": (204, 204, 0, 100),
  746. # Bright Green (from 'imgs_body_list')
  747. "image": (153, 255, 51, 100),
  748. # Bright Green (from 'imgs_body_list')
  749. "figure": (153, 255, 51, 100),
  750. "chart": (216, 191, 216, 100), # Thistle
  751. # Pale Yellow-Green (from 'tables_footnote_list')
  752. "reference": (229, 255, 204, 100),
  753. # "reference_content": (229, 255, 204, 100),
  754. "algorithm": (255, 250, 240, 100), # Floral White
  755. }
  756. default_color = (158, 158, 158, 100)
  757. return label_colors.get(label, default_color)