utils.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. __all__ = [
  15. "get_sub_regions_ocr_res",
  16. "get_show_color",
  17. "sorted_layout_boxes",
  18. ]
  19. import re
  20. from copy import deepcopy
  21. from typing import Dict, List, Optional, Tuple, Union
  22. import numpy as np
  23. from PIL import Image
  24. from ..components import convert_points_to_boxes
  25. from ..ocr.result import OCRResult
  26. from .setting import REGION_SETTINGS
  27. def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
  28. """
  29. Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
  30. Args:
  31. src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
  32. ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
  33. Returns:
  34. match_idx_list (list): A list of indices of source boxes that overlap with reference boxes.
  35. """
  36. match_idx_list = []
  37. src_boxes_num = len(src_boxes)
  38. if src_boxes_num > 0 and len(ref_boxes) > 0:
  39. for rno in range(len(ref_boxes)):
  40. ref_box = ref_boxes[rno]
  41. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  42. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  43. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  44. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  45. pub_w = x2 - x1
  46. pub_h = y2 - y1
  47. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  48. match_idx_list.extend(match_idx)
  49. return match_idx_list
  50. def get_sub_regions_ocr_res(
  51. overall_ocr_res: OCRResult,
  52. object_boxes: List,
  53. flag_within: bool = True,
  54. return_match_idx: bool = False,
  55. ) -> OCRResult:
  56. """
  57. Filters OCR results to only include text boxes within specified object boxes based on a flag.
  58. Args:
  59. overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
  60. object_boxes (list): A list of bounding boxes for the objects of interest.
  61. flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
  62. return_match_idx (bool): If True, return the list of matching indices.
  63. Returns:
  64. OCRResult: A filtered OCR result containing only the relevant text boxes.
  65. """
  66. sub_regions_ocr_res = {}
  67. sub_regions_ocr_res["rec_polys"] = []
  68. sub_regions_ocr_res["rec_texts"] = []
  69. sub_regions_ocr_res["rec_scores"] = []
  70. sub_regions_ocr_res["rec_boxes"] = []
  71. overall_text_boxes = overall_ocr_res["rec_boxes"]
  72. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  73. match_idx_list = list(set(match_idx_list))
  74. for box_no in range(len(overall_text_boxes)):
  75. if flag_within:
  76. if box_no in match_idx_list:
  77. flag_match = True
  78. else:
  79. flag_match = False
  80. else:
  81. if box_no not in match_idx_list:
  82. flag_match = True
  83. else:
  84. flag_match = False
  85. if flag_match:
  86. sub_regions_ocr_res["rec_polys"].append(
  87. overall_ocr_res["rec_polys"][box_no]
  88. )
  89. sub_regions_ocr_res["rec_texts"].append(
  90. overall_ocr_res["rec_texts"][box_no]
  91. )
  92. sub_regions_ocr_res["rec_scores"].append(
  93. overall_ocr_res["rec_scores"][box_no]
  94. )
  95. sub_regions_ocr_res["rec_boxes"].append(
  96. overall_ocr_res["rec_boxes"][box_no]
  97. )
  98. for key in ["rec_polys", "rec_scores", "rec_boxes"]:
  99. sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])
  100. return (
  101. (sub_regions_ocr_res, match_idx_list)
  102. if return_match_idx
  103. else sub_regions_ocr_res
  104. )
  105. def sorted_layout_boxes(res, w):
  106. """
  107. Sort text boxes in order from top to bottom, left to right
  108. Args:
  109. res: List of dictionaries containing layout information.
  110. w: Width of image.
  111. Returns:
  112. List of dictionaries containing sorted layout information.
  113. """
  114. num_boxes = len(res)
  115. if num_boxes == 1:
  116. return res
  117. # Sort on the y axis first or sort it on the x axis
  118. sorted_boxes = sorted(res, key=lambda x: (x["block_bbox"][1], x["block_bbox"][0]))
  119. _boxes = list(sorted_boxes)
  120. new_res = []
  121. res_left = []
  122. res_right = []
  123. i = 0
  124. while True:
  125. if i >= num_boxes:
  126. break
  127. # Check that the bbox is on the left
  128. elif (
  129. _boxes[i]["block_bbox"][0] < w / 4
  130. and _boxes[i]["block_bbox"][2] < 3 * w / 5
  131. ):
  132. res_left.append(_boxes[i])
  133. i += 1
  134. elif _boxes[i]["block_bbox"][0] > 2 * w / 5:
  135. res_right.append(_boxes[i])
  136. i += 1
  137. else:
  138. new_res += res_left
  139. new_res += res_right
  140. new_res.append(_boxes[i])
  141. res_left = []
  142. res_right = []
  143. i += 1
  144. res_left = sorted(res_left, key=lambda x: (x["block_bbox"][1]))
  145. res_right = sorted(res_right, key=lambda x: (x["block_bbox"][1]))
  146. if res_left:
  147. new_res += res_left
  148. if res_right:
  149. new_res += res_right
  150. return new_res
  151. def calculate_projection_overlap_ratio(
  152. bbox1: List[float],
  153. bbox2: List[float],
  154. direction: str = "horizontal",
  155. mode="union",
  156. ) -> float:
  157. """
  158. Calculate the IoU of lines between two bounding boxes.
  159. Args:
  160. bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
  161. bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
  162. direction (str): direction of the projection, "horizontal" or "vertical".
  163. Returns:
  164. float: Line overlap ratio. Returns 0 if there is no overlap.
  165. """
  166. start_index, end_index = 1, 3
  167. if direction == "horizontal":
  168. start_index, end_index = 0, 2
  169. intersection_start = max(bbox1[start_index], bbox2[start_index])
  170. intersection_end = min(bbox1[end_index], bbox2[end_index])
  171. overlap = intersection_end - intersection_start
  172. if overlap <= 0:
  173. return 0
  174. if mode == "union":
  175. ref_width = max(bbox1[end_index], bbox2[end_index]) - min(
  176. bbox1[start_index], bbox2[start_index]
  177. )
  178. elif mode == "small":
  179. ref_width = min(
  180. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  181. )
  182. elif mode == "large":
  183. ref_width = max(
  184. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  185. )
  186. else:
  187. raise ValueError(
  188. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  189. )
  190. return overlap / ref_width if ref_width > 0 else 0.0
  191. def calculate_overlap_ratio(
  192. bbox1: Union[list, tuple], bbox2: Union[list, tuple], mode="union"
  193. ) -> float:
  194. """
  195. Calculate the overlap ratio between two bounding boxes.
  196. Args:
  197. bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max]
  198. bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max]
  199. mode (str): The mode of calculation, either 'union', 'small', or 'large'.
  200. Returns:
  201. float: The overlap ratio value between the two bounding boxes
  202. """
  203. x_min_inter = max(bbox1[0], bbox2[0])
  204. y_min_inter = max(bbox1[1], bbox2[1])
  205. x_max_inter = min(bbox1[2], bbox2[2])
  206. y_max_inter = min(bbox1[3], bbox2[3])
  207. inter_width = max(0, x_max_inter - x_min_inter)
  208. inter_height = max(0, y_max_inter - y_min_inter)
  209. inter_area = inter_width * inter_height
  210. bbox1_area = caculate_bbox_area(bbox1)
  211. bbox2_area = caculate_bbox_area(bbox2)
  212. if mode == "union":
  213. ref_area = bbox1_area + bbox2_area - inter_area
  214. elif mode == "small":
  215. ref_area = min(bbox1_area, bbox2_area)
  216. elif mode == "large":
  217. ref_area = max(bbox1_area, bbox2_area)
  218. else:
  219. raise ValueError(
  220. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  221. )
  222. if ref_area == 0:
  223. return 0.0
  224. return inter_area / ref_area
  225. def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
  226. rec_boxes = ocr_rec_res["boxes"]
  227. rec_texts = ocr_rec_res["rec_texts"]
  228. rec_labels = ocr_rec_res["rec_labels"]
  229. text_boxes = [
  230. rec_boxes[i] for i in range(len(rec_boxes)) if rec_labels[i] == "text"
  231. ]
  232. text_orientation = calculate_text_orientation(text_boxes)
  233. match_direction = "vertical" if text_orientation == "horizontal" else "horizontal"
  234. line_start_index = 1 if text_orientation == "horizontal" else 0
  235. line_end_index = 3 if text_orientation == "horizontal" else 2
  236. spans = list(zip(rec_boxes, rec_texts, rec_labels))
  237. sort_index = 1
  238. reverse = False
  239. if text_orientation == "vertical":
  240. sort_index = 0
  241. reverse = True
  242. spans.sort(key=lambda span: span[0][sort_index], reverse=reverse)
  243. spans = [list(span) for span in spans]
  244. lines = []
  245. line = [spans[0]]
  246. line_region_box = spans[0][0].copy()
  247. line_heights = []
  248. # merge line
  249. for span in spans[1:]:
  250. rec_bbox = span[0]
  251. if (
  252. calculate_projection_overlap_ratio(
  253. line_region_box, rec_bbox, match_direction, mode="small"
  254. )
  255. >= line_height_iou_threshold
  256. ):
  257. line.append(span)
  258. line_region_box[line_start_index] = min(
  259. line_region_box[line_start_index], rec_bbox[line_start_index]
  260. )
  261. line_region_box[line_end_index] = max(
  262. line_region_box[line_end_index], rec_bbox[line_end_index]
  263. )
  264. else:
  265. line_heights.append(
  266. line_region_box[line_end_index] - line_region_box[line_start_index]
  267. )
  268. lines.append(line)
  269. line = [span]
  270. line_region_box = rec_bbox.copy()
  271. lines.append(line)
  272. line_heights.append(
  273. line_region_box[line_end_index] - line_region_box[line_start_index]
  274. )
  275. min_height = min(line_heights) if line_heights else 0
  276. max_height = max(line_heights) if line_heights else 0
  277. if max_height > min_height * 2 and text_orientation == "vertical":
  278. line_heights = np.array(line_heights)
  279. min_height_num = np.sum(line_heights < min_height * 1.1)
  280. if min_height_num < len(lines) * 0.4:
  281. condition = line_heights > min_height * 1.1
  282. lines = [value for value, keep in zip(lines, condition) if keep]
  283. return lines, text_orientation, np.mean(line_heights)
  284. def calculate_minimum_enclosing_bbox(bboxes):
  285. """
  286. Calculate the minimum enclosing bounding box for a list of bounding boxes.
  287. Args:
  288. bboxes (list): A list of bounding boxes represented as lists of four integers [x1, y1, x2, y2].
  289. Returns:
  290. list: The minimum enclosing bounding box represented as a list of four integers [x1, y1, x2, y2].
  291. """
  292. if not bboxes:
  293. raise ValueError("The list of bounding boxes is empty.")
  294. # Convert the list of bounding boxes to a NumPy array
  295. bboxes_array = np.array(bboxes)
  296. # Compute the minimum and maximum values along the respective axes
  297. min_x = np.min(bboxes_array[:, 0])
  298. min_y = np.min(bboxes_array[:, 1])
  299. max_x = np.max(bboxes_array[:, 2])
  300. max_y = np.max(bboxes_array[:, 3])
  301. # Return the minimum enclosing bounding box
  302. return [min_x, min_y, max_x, max_y]
  303. def calculate_text_orientation(
  304. bboxes: List[List[int]], orientation_ratio: float = 1.5
  305. ) -> bool:
  306. """
  307. Calculate the orientation of the text based on the bounding boxes.
  308. Args:
  309. bboxes (list): A list of bounding boxes.
  310. orientation_ratio (float): Ratio for determining orientation. Default is 1.5.
  311. Returns:
  312. str: "horizontal" or "vertical".
  313. """
  314. horizontal_box_num = 0
  315. for bbox in bboxes:
  316. if len(bbox) != 4:
  317. raise ValueError(
  318. "Invalid bounding box format. Expected a list of length 4."
  319. )
  320. x1, y1, x2, y2 = bbox
  321. width = x2 - x1
  322. height = y2 - y1
  323. horizontal_box_num += 1 if width * orientation_ratio >= height else 0
  324. return "horizontal" if horizontal_box_num >= len(bboxes) * 0.5 else "vertical"
  325. def is_english_letter(char):
  326. return bool(re.match(r"^[A-Za-z]$", char))
  327. def is_numeric(char):
  328. return bool(re.match(r"^[\d.]+$", char))
  329. def is_non_breaking_punctuation(char):
  330. """
  331. 判断一个字符是否是不需要换行的标点符号,包括全角和半角的符号。
  332. :param char: str, 单个字符
  333. :return: bool, 如果字符是不需要换行的标点符号,返回True,否则返回False
  334. """
  335. non_breaking_punctuations = {
  336. ",", # 半角逗号
  337. ",", # 全角逗号
  338. "、", # 顿号
  339. ";", # 半角分号
  340. ";", # 全角分号
  341. ":", # 半角冒号
  342. ":", # 全角冒号
  343. "-", # 连字符
  344. }
  345. return char in non_breaking_punctuations
  346. def format_line(
  347. line: List[List[Union[List[int], str]]],
  348. text_direction: int,
  349. block_width: int,
  350. block_start_coordinate: int,
  351. block_stop_coordinate: int,
  352. line_gap_limit: int = 10,
  353. block_label: str = "text",
  354. ) -> None:
  355. """
  356. Format a line of text spans based on layout constraints.
  357. Args:
  358. line (list): A list of spans, where each span is a list containing a bounding box and text.
  359. block_left_coordinate (int): The text line directional minimum coordinate of the layout bounding box.
  360. block_stop_coordinate (int): The text line directional maximum x-coordinate of the layout bounding box.
  361. first_line_span_limit (int): The limit for the number of pixels before the first span that should be considered part of the first line. Default is 10.
  362. line_gap_limit (int): The limit for the number of pixels after the last span that should be considered part of the last line. Default is 10.
  363. block_label (str): The label associated with the entire block. Default is 'text'.
  364. Returns:
  365. None: The function modifies the line in place.
  366. """
  367. first_span_box = line[0][0]
  368. last_span_box = line[-1][0]
  369. for span in line:
  370. if span[2] == "formula" and block_label != "formula":
  371. if len(line) > 1:
  372. span[1] = f"${span[1]}$"
  373. else:
  374. span[1] = f"\n${span[1]}$"
  375. line_text = ""
  376. for span in line:
  377. _, text, label = span
  378. line_text += text
  379. if len(text) > 0 and is_english_letter(line_text[-1]) or label == "formula":
  380. line_text += " "
  381. if text_direction == "horizontal":
  382. text_start_index = 0
  383. text_stop_index = 2
  384. else:
  385. text_start_index = 1
  386. text_stop_index = 3
  387. need_new_line = False
  388. if (
  389. len(line_text) > 0
  390. and not is_english_letter(line_text[-1])
  391. and not is_non_breaking_punctuation(line_text[-1])
  392. ):
  393. if (
  394. text_direction == "horizontal"
  395. and block_stop_coordinate - last_span_box[text_stop_index] > line_gap_limit
  396. ) or (
  397. text_direction == "vertical"
  398. and (
  399. block_stop_coordinate - last_span_box[text_stop_index] > line_gap_limit
  400. or first_span_box[1] - block_start_coordinate > line_gap_limit
  401. )
  402. ):
  403. need_new_line = True
  404. if line_text.endswith("-"):
  405. line_text = line_text[:-1]
  406. elif (
  407. len(line_text) > 0 and is_english_letter(line_text[-1])
  408. ) or line_text.endswith("$"):
  409. line_text += " "
  410. elif (
  411. len(line_text) > 0
  412. and not is_english_letter(line_text[-1])
  413. and not is_non_breaking_punctuation(line_text[-1])
  414. and not is_numeric(line_text[-1])
  415. ) or text_direction == "vertical":
  416. if block_stop_coordinate - last_span_box[text_stop_index] > block_width * 0.4:
  417. line_text += "\n"
  418. if (
  419. first_span_box[text_start_index] - block_start_coordinate
  420. > block_width * 0.4
  421. ):
  422. line_text = "\n" + line_text
  423. return line_text, need_new_line
  424. def split_boxes_by_projection(spans: List[List[int]], direction, offset=1e-5):
  425. """
  426. Check if there is any complete containment in the x-direction
  427. between the bounding boxes and split the containing box accordingly.
  428. Args:
  429. spans (list of lists): Each element is a list containing an ndarray of length 4, a text string, and a label.
  430. direction: 'horizontal' or 'vertical', indicating whether the spans are arranged horizontally or vertically.
  431. offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
  432. Returns:
  433. A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
  434. """
  435. def is_projection_contained(box_a, box_b, start_idx, end_idx):
  436. """Check if box_a completely contains box_b in the x-direction."""
  437. return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx]
  438. new_boxes = []
  439. if direction == "horizontal":
  440. projection_start_index, projection_end_index = 0, 2
  441. else:
  442. projection_start_index, projection_end_index = 1, 3
  443. for i in range(len(spans)):
  444. span = spans[i]
  445. is_split = False
  446. for j in range(i, len(spans)):
  447. box_b = spans[j][0]
  448. box_a, text, label = span
  449. if is_projection_contained(
  450. box_a, box_b, projection_start_index, projection_end_index
  451. ):
  452. is_split = True
  453. # Split box_a based on the x-coordinates of box_b
  454. if box_a[projection_start_index] < box_b[projection_start_index]:
  455. w = (
  456. box_b[projection_start_index]
  457. - offset
  458. - box_a[projection_start_index]
  459. )
  460. if w > 1:
  461. new_bbox = box_a.copy()
  462. new_bbox[projection_end_index] = (
  463. box_b[projection_start_index] - offset
  464. )
  465. new_boxes.append(
  466. [
  467. np.array(new_bbox),
  468. text,
  469. label,
  470. ]
  471. )
  472. if box_a[projection_end_index] > box_b[projection_end_index]:
  473. w = (
  474. box_a[projection_end_index]
  475. - box_b[projection_end_index]
  476. + offset
  477. )
  478. if w > 1:
  479. box_a[projection_start_index] = (
  480. box_b[projection_end_index] + offset
  481. )
  482. span = [
  483. np.array(box_a),
  484. text,
  485. label,
  486. ]
  487. if j == len(spans) - 1 and is_split:
  488. new_boxes.append(span)
  489. if not is_split:
  490. new_boxes.append(span)
  491. return new_boxes
  492. def remove_extra_space(input_text: str) -> str:
  493. """
  494. Process the input text to handle spaces.
  495. The function removes multiple consecutive spaces between Chinese characters and ensures that
  496. only a single space is retained between Chinese and non-Chinese characters.
  497. Args:
  498. input_text (str): The text to be processed.
  499. Returns:
  500. str: The processed text with properly formatted spaces.
  501. """
  502. # Remove spaces between Chinese characters
  503. text_without_spaces = re.sub(
  504. r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", input_text
  505. )
  506. # Ensure single space between Chinese and non-Chinese characters
  507. text_with_single_spaces = re.sub(
  508. r"(?<=[\u4e00-\u9fff])\s+(?=[^\u4e00-\u9fff])|(?<=[^\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])",
  509. " ",
  510. text_without_spaces,
  511. )
  512. # Reduce any remaining consecutive spaces to a single space
  513. final_text = re.sub(r"\s+", " ", text_with_single_spaces).strip()
  514. return final_text
  515. def gather_imgs(original_img, layout_det_objs):
  516. imgs_in_doc = []
  517. for det_obj in layout_det_objs:
  518. if det_obj["label"] in ("image", "chart", "seal", "formula", "table"):
  519. x_min, y_min, x_max, y_max = list(map(int, det_obj["coordinate"]))
  520. img_path = f"imgs/img_in_table_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg"
  521. img = Image.fromarray(original_img[y_min:y_max, x_min:x_max, ::-1])
  522. imgs_in_doc.append(
  523. {
  524. "path": img_path,
  525. "img": img,
  526. "coordinate": (x_min, y_min, x_max, y_max),
  527. "score": det_obj["score"],
  528. }
  529. )
  530. return imgs_in_doc
  531. def _get_minbox_if_overlap_by_ratio(
  532. bbox1: Union[List[int], Tuple[int, int, int, int]],
  533. bbox2: Union[List[int], Tuple[int, int, int, int]],
  534. ratio: float,
  535. smaller: bool = True,
  536. ) -> Optional[Union[List[int], Tuple[int, int, int, int]]]:
  537. """
  538. Determine if the overlap area between two bounding boxes exceeds a given ratio
  539. and return the smaller (or larger) bounding box based on the `smaller` flag.
  540. Args:
  541. bbox1 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  542. bbox2 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  543. ratio (float): The overlap ratio threshold.
  544. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  545. Returns:
  546. Optional[Union[List[int], Tuple[int, int, int, int]]]:
  547. The selected bounding box or None if the overlap ratio is not exceeded.
  548. """
  549. # Calculate the areas of both bounding boxes
  550. area1 = caculate_bbox_area(bbox1)
  551. area2 = caculate_bbox_area(bbox2)
  552. # Calculate the overlap ratio using a helper function
  553. overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small")
  554. # Check if the overlap ratio exceeds the threshold
  555. if overlap_ratio > ratio:
  556. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  557. return 1
  558. else:
  559. return 2
  560. return None
  561. def remove_overlap_blocks(
  562. blocks: List[Dict[str, List[int]]], threshold: float = 0.65, smaller: bool = True
  563. ) -> Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  564. """
  565. Remove overlapping blocks based on a specified overlap ratio threshold.
  566. Args:
  567. blocks (List[Dict[str, List[int]]]): List of block dictionaries, each containing a 'block_bbox' key.
  568. threshold (float): Ratio threshold to determine significant overlap.
  569. smaller (bool): If True, the smaller block in overlap is removed.
  570. Returns:
  571. Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  572. A tuple containing the updated list of blocks and a list of dropped blocks.
  573. """
  574. dropped_indexes = set()
  575. blocks = deepcopy(blocks)
  576. overlap_image_blocks = []
  577. # Iterate over each pair of blocks to find overlaps
  578. for i, block1 in enumerate(blocks["boxes"]):
  579. for j in range(i + 1, len(blocks["boxes"])):
  580. block2 = blocks["boxes"][j]
  581. # Skip blocks that are already marked for removal
  582. if i in dropped_indexes or j in dropped_indexes:
  583. continue
  584. # Check for overlap and determine which block to remove
  585. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  586. block1["coordinate"],
  587. block2["coordinate"],
  588. threshold,
  589. smaller=smaller,
  590. )
  591. if overlap_box_index is not None:
  592. is_block1_image = block1["label"] == "image"
  593. is_block2_image = block2["label"] == "image"
  594. if is_block1_image != is_block2_image:
  595. # 如果只有一个块在视觉标签中,删除在视觉标签中的那个块
  596. drop_index = i if is_block1_image else j
  597. overlap_image_blocks.append(blocks["boxes"][drop_index])
  598. else:
  599. # 如果两个块都在或都不在视觉标签中,根据 overlap_box_index 决定删除哪个块
  600. drop_index = i if overlap_box_index == 1 else j
  601. dropped_indexes.add(drop_index)
  602. # Remove marked blocks from the original list
  603. for index in sorted(dropped_indexes, reverse=True):
  604. del blocks["boxes"][index]
  605. return blocks
  606. def get_bbox_intersection(bbox1, bbox2, return_format="bbox"):
  607. """
  608. Compute the intersection of two bounding boxes, supporting both 4-coordinate and 8-coordinate formats.
  609. Args:
  610. bbox1 (tuple): The first bounding box, either in 4-coordinate format (x_min, y_min, x_max, y_max)
  611. or 8-coordinate format (x1, y1, x2, y2, x3, y3, x4, y4).
  612. bbox2 (tuple): The second bounding box in the same format as bbox1.
  613. return_format (str): The format of the output intersection, either 'bbox' or 'poly'.
  614. Returns:
  615. tuple or None: The intersection bounding box in the specified format, or None if there is no intersection.
  616. """
  617. bbox1 = np.array(bbox1)
  618. bbox2 = np.array(bbox2)
  619. # Convert both bounding boxes to rectangles
  620. rect1 = bbox1 if len(bbox1.shape) == 1 else convert_points_to_boxes([bbox1])[0]
  621. rect2 = bbox2 if len(bbox2.shape) == 1 else convert_points_to_boxes([bbox2])[0]
  622. # Calculate the intersection rectangle
  623. x_min_inter = max(rect1[0], rect2[0])
  624. y_min_inter = max(rect1[1], rect2[1])
  625. x_max_inter = min(rect1[2], rect2[2])
  626. y_max_inter = min(rect1[3], rect2[3])
  627. # Check if there is an intersection
  628. if x_min_inter >= x_max_inter or y_min_inter >= y_max_inter:
  629. return None
  630. if return_format == "bbox":
  631. return np.array([x_min_inter, y_min_inter, x_max_inter, y_max_inter])
  632. elif return_format == "poly":
  633. return np.array(
  634. [
  635. [x_min_inter, y_min_inter],
  636. [x_max_inter, y_min_inter],
  637. [x_max_inter, y_max_inter],
  638. [x_min_inter, y_max_inter],
  639. ],
  640. dtype=np.int16,
  641. )
  642. else:
  643. raise ValueError("return_format must be either 'bbox' or 'poly'.")
def shrink_supplement_region_bbox(
    supplement_region_bbox,
    ref_region_bbox,
    image_width,
    image_height,
    block_idxes_set,
    block_bboxes,
) -> Tuple[List, List]:
    """
    Shrink a supplement region so it excludes the reference region, then refit
    it to the blocks it contains.

    Up to three attempts are made. Each attempt moves one edge of the
    supplement region onto the opposite edge of the reference region. Edges
    are tried in increasing order of the normalised distance between
    corresponding edges of the two regions.

    Blocks from ``block_idxes_set`` are classified by their overlap with the
    tentative region, measured relative to the smaller area:
      - "inner" when it exceeds
        REGION_SETTINGS["match_block_overlap_ratio_threshold"] (default 0.8);
      - "split" when it exceeds
        REGION_SETTINGS["split_block_overlap_ratio_threshold"] (default 0.4).
    The first attempt that yields inner blocks wins, and the region becomes
    the minimum enclosing bbox of those inner blocks.

    Args:
        supplement_region_bbox (list): The supplement region bbox [x1, y1, x2, y2].
        ref_region_bbox (list): The reference region bbox [x1, y1, x2, y2].
        image_width (int): Image width, used to normalise x distances.
        image_height (int): Image height, used to normalise y distances.
        block_idxes_set (iterable): Indexes of the candidate blocks.
        block_bboxes: Block bboxes, indexable by block index.

    Returns:
        tuple: (region_bbox, inner_block_idxes). If ``block_idxes_set`` is
        empty, or no attempt finds an inner block, the region is returned
        unchanged with an empty list.
    """
    x1, y1, x2, y2 = supplement_region_bbox
    x1_prime, y1_prime, x2_prime, y2_prime = ref_region_bbox
    # Maps each edge index to the opposite edge: x1 <-> x2, y1 <-> y2.
    index_conversion_map = {0: 2, 1: 3, 2: 0, 3: 1}
    # Normalised gaps between corresponding edges (left, top, right, bottom).
    edge_distance_list = [
        (x1_prime - x1) / image_width,
        (y1_prime - y1) / image_height,
        (x2 - x2_prime) / image_width,
        (y2 - y2_prime) / image_height,
    ]
    edge_distance_list_tmp = edge_distance_list[:]
    min_distance = min(edge_distance_list)
    src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    if len(block_idxes_set) == 0:
        return supplement_region_bbox, []
    for _ in range(3):
        # Move the supplement edge nearest the reference region onto the
        # reference region's opposite edge, cutting the reference region out.
        dst_index = index_conversion_map[src_index]
        # NOTE(review): `[:]` copies a list but is only a view for a NumPy
        # array, in which case the writes below would also modify
        # supplement_region_bbox. Confirm callers pass lists.
        tmp_region_bbox = supplement_region_bbox[:]
        tmp_region_bbox[dst_index] = ref_region_bbox[src_index]
        iner_block_idxes, split_block_idxes = [], []
        for block_idx in block_idxes_set:
            overlap_ratio = calculate_overlap_ratio(
                tmp_region_bbox, block_bboxes[block_idx], mode="small"
            )
            if overlap_ratio > REGION_SETTINGS.get(
                "match_block_overlap_ratio_threshold", 0.8
            ):
                iner_block_idxes.append(block_idx)
            elif overlap_ratio > REGION_SETTINGS.get(
                "split_block_overlap_ratio_threshold", 0.4
            ):
                split_block_idxes.append(block_idx)
        if len(iner_block_idxes) > 0:
            if len(split_block_idxes) > 0:
                for split_block_idx in split_block_idxes:
                    split_block_bbox = block_bboxes[split_block_idx]
                    x1, y1, x2, y2 = tmp_region_bbox
                    x1_prime, y1_prime, x2_prime, y2_prime = split_block_bbox
                    edge_distance_list = [
                        (x1_prime - x1) / image_width,
                        (y1_prime - y1) / image_height,
                        (x2 - x2_prime) / image_width,
                        (y2 - y2_prime) / image_height,
                    ]
                    # Move the opposite region edge onto the split block's
                    # edge with the largest gap. This excludes the split block
                    # while keeping the larger remaining side.
                    max_distance = max(edge_distance_list)
                    src_index = edge_distance_list.index(max_distance)
                    dst_index = index_conversion_map[src_index]
                    tmp_region_bbox[dst_index] = split_block_bbox[src_index]
                    # The recursive result only feeds later iterations of this
                    # loop. The final region below is built from
                    # iner_block_idxes, not from tmp_region_bbox.
                    tmp_region_bbox, iner_idxes = shrink_supplement_region_bbox(
                        tmp_region_bbox,
                        ref_region_bbox,
                        image_width,
                        image_height,
                        iner_block_idxes,
                        block_bboxes,
                    )
                    # NOTE(review): as the last statement of this loop body,
                    # `continue` has no effect. Confirm the intended nesting.
                    if len(iner_idxes) == 0:
                        continue
            matched_bboxes = [block_bboxes[idx] for idx in iner_block_idxes]
            supplement_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes)
            break
        else:
            # No inner blocks: try the next-nearest edge.
            # NOTE(review): this drops every distance equal to min_distance,
            # so ties can empty the list and make min() raise ValueError.
            edge_distance_list_tmp = [
                x for x in edge_distance_list_tmp if x != min_distance
            ]
            min_distance = min(edge_distance_list_tmp)
            src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    return supplement_region_bbox, iner_block_idxes
  731. def update_region_box(bbox, region_box):
  732. if region_box is None:
  733. return bbox
  734. x1, y1, x2, y2 = bbox
  735. x1_region, y1_region, x2_region, y2_region = region_box
  736. x1_region = int(min(x1, x1_region))
  737. y1_region = int(min(y1, y1_region))
  738. x2_region = int(max(x2, x2_region))
  739. y2_region = int(max(y2, y2_region))
  740. region_box = [x1_region, y1_region, x2_region, y2_region]
  741. return region_box
  742. def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
  743. for formula_res in formula_res_list:
  744. x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
  745. poly_points = [
  746. (x_min, y_min),
  747. (x_max, y_min),
  748. (x_max, y_max),
  749. (x_min, y_max),
  750. ]
  751. ocr_res["dt_polys"].append(poly_points)
  752. formula_res_text: str = formula_res["rec_formula"]
  753. if formula_res_text.startswith("$$") and formula_res_text.endswith("$$"):
  754. formula_res_text = formula_res_text[2:-2]
  755. elif formula_res_text.startswith("$") and formula_res_text.endswith("$"):
  756. formula_res_text = formula_res_text[1:-1]
  757. ocr_res["rec_texts"].append(formula_res_text)
  758. if ocr_res["rec_boxes"].size == 0:
  759. ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"])
  760. else:
  761. ocr_res["rec_boxes"] = np.vstack(
  762. (ocr_res["rec_boxes"], [formula_res["dt_polys"]])
  763. )
  764. ocr_res["rec_labels"].append("formula")
  765. ocr_res["rec_polys"].append(poly_points)
  766. ocr_res["rec_scores"].append(1)
  767. def caculate_bbox_area(bbox):
  768. x1, y1, x2, y2 = map(float, bbox)
  769. area = abs((x2 - x1) * (y2 - y1))
  770. return area
  771. def get_show_color(label: str, order_label=False) -> Tuple:
  772. if order_label:
  773. label_colors = {
  774. "doc_title": (255, 248, 220, 100), # Cornsilk
  775. "doc_title_text": (255, 239, 213, 100),
  776. "paragraph_title": (102, 102, 255, 100),
  777. "sub_paragraph_title": (102, 178, 255, 100),
  778. "vision": (153, 255, 51, 100),
  779. "vision_title": (144, 238, 144, 100), # Light Green
  780. "vision_footnote": (144, 238, 144, 100), # Light Green
  781. "normal_text": (153, 0, 76, 100),
  782. "cross_layout": (53, 218, 207, 100), # Thistle
  783. "cross_reference": (221, 160, 221, 100), # Floral White
  784. }
  785. else:
  786. label_colors = {
  787. # Medium Blue (from 'titles_list')
  788. "paragraph_title": (102, 102, 255, 100),
  789. "doc_title": (255, 248, 220, 100), # Cornsilk
  790. # Light Yellow (from 'tables_caption_list')
  791. "table_title": (255, 255, 102, 100),
  792. # Sky Blue (from 'imgs_caption_list')
  793. "figure_title": (102, 178, 255, 100),
  794. "chart_title": (221, 160, 221, 100), # Plum
  795. "vision_footnote": (144, 238, 144, 100), # Light Green
  796. # Deep Purple (from 'texts_list')
  797. "text": (153, 0, 76, 100),
  798. # Bright Green (from 'interequations_list')
  799. "formula": (0, 255, 0, 100),
  800. "abstract": (255, 239, 213, 100), # Papaya Whip
  801. # Medium Green (from 'lists_list' and 'indexs_list')
  802. "content": (40, 169, 92, 100),
  803. # Neutral Gray (from 'dropped_bbox_list')
  804. "seal": (158, 158, 158, 100),
  805. # Olive Yellow (from 'tables_body_list')
  806. "table": (204, 204, 0, 100),
  807. # Bright Green (from 'imgs_body_list')
  808. "image": (153, 255, 51, 100),
  809. # Bright Green (from 'imgs_body_list')
  810. "figure": (153, 255, 51, 100),
  811. "chart": (216, 191, 216, 100), # Thistle
  812. # Pale Yellow-Green (from 'tables_footnote_list')
  813. "reference": (229, 255, 204, 100),
  814. # "reference_content": (229, 255, 204, 100),
  815. "algorithm": (255, 250, 240, 100), # Floral White
  816. }
  817. default_color = (158, 158, 158, 100)
  818. return label_colors.get(label, default_color)