utils.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. __all__ = [
  15. "get_sub_regions_ocr_res",
  16. "get_show_color",
  17. "sorted_layout_boxes",
  18. ]
  19. import re
  20. from copy import deepcopy
  21. from typing import Dict, List, Optional, Tuple, Union
  22. import numpy as np
  23. from PIL import Image
  24. from ..components import convert_points_to_boxes
  25. from ..ocr.result import OCRResult
  26. from .setting import BLOCK_LABEL_MAP, REGION_SETTINGS
  def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
      """
      Collect the indices of source boxes that intersect any reference box.

      A source box is reported when its intersection with a reference box is
      more than 3 units wide and more than 3 units high.

      Args:
          src_boxes (np.ndarray): 2D array of source boxes; columns 0-3 are read
              as x_min, y_min, x_max, y_max.
          ref_boxes (np.ndarray): Reference boxes using the same column layout.

      Returns:
          list: Indices into ``src_boxes``. An index is appended once per
          reference box it overlaps, so duplicates are possible.
      """
      overlapping_indices = []
      if len(src_boxes) > 0 and len(ref_boxes) > 0:
          for ref_box in ref_boxes:
              # Vectorised intersection of this reference box with every source box.
              inter_width = np.minimum(ref_box[2], src_boxes[:, 2]) - np.maximum(
                  ref_box[0], src_boxes[:, 0]
              )
              inter_height = np.minimum(ref_box[3], src_boxes[:, 3]) - np.maximum(
                  ref_box[1], src_boxes[:, 1]
              )
              hits = np.where((inter_width > 3) & (inter_height > 3))[0]
              overlapping_indices.extend(hits)
      return overlapping_indices
  def get_sub_regions_ocr_res(
      overall_ocr_res: OCRResult,
      object_boxes: List,
      flag_within: bool = True,
      return_match_idx: bool = False,
  ) -> OCRResult:
      """
      Filter OCR results by whether each text box overlaps the given object boxes.

      Args:
          overall_ocr_res (OCRResult): The original OCR result. The keys
              "rec_boxes", "rec_polys", "rec_texts" and "rec_scores" are read.
          object_boxes (list): Bounding boxes of the objects of interest.
          flag_within (bool): If True, keep only text boxes that overlap an object
              box. If False, keep only text boxes that overlap none of them.
          return_match_idx (bool): If True, also return the list of indices of
              text boxes that overlap an object box.

      Returns:
          A plain dict with the keys "rec_polys", "rec_texts", "rec_scores" and
          "rec_boxes" ("rec_texts" stays a list, the others become numpy arrays).
          When ``return_match_idx`` is True, a tuple ``(dict, match_idx_list)``
          is returned instead.
      """
      overall_text_boxes = overall_ocr_res["rec_boxes"]
      match_idx_list = list(
          set(get_overlap_boxes_idx(overall_text_boxes, object_boxes))
      )
      # A set makes the per-box membership test O(1); the list is kept only
      # because it is part of the return value.
      matched_indices = set(match_idx_list)
      keep_matched = bool(flag_within)

      sub_regions_ocr_res = {
          "rec_polys": [],
          "rec_texts": [],
          "rec_scores": [],
          "rec_boxes": [],
      }
      for box_no in range(len(overall_text_boxes)):
          # Keep the box when "is matched" agrees with the requested side.
          if (box_no in matched_indices) != keep_matched:
              continue
          for key in ("rec_polys", "rec_texts", "rec_scores", "rec_boxes"):
              sub_regions_ocr_res[key].append(overall_ocr_res[key][box_no])

      for key in ["rec_polys", "rec_scores", "rec_boxes"]:
          sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])

      if return_match_idx:
          return sub_regions_ocr_res, match_idx_list
      return sub_regions_ocr_res
  def sorted_layout_boxes(res, w):
      """
      Order layout blocks top-to-bottom, reading a left column before a right one.

      Args:
          res: List of dictionaries, each holding a "block_bbox" entry
              ([x_min, y_min, x_max, y_max]).
          w: Width of image.

      Returns:
          List of the same dictionaries in reading order. A single-element input
          is returned unchanged.
      """
      if len(res) == 1:
          return res

      # Primary key is the top edge, ties are broken by the left edge.
      ordered = sorted(
          res, key=lambda item: (item["block_bbox"][1], item["block_bbox"][0])
      )

      result = []
      left_column = []
      right_column = []
      for item in ordered:
          bbox = item["block_bbox"]
          if bbox[0] < w / 4 and bbox[2] < 3 * w / 5:
              # Starts in the left quarter and ends before 60% of the width.
              left_column.append(item)
          elif bbox[0] > 2 * w / 5:
              right_column.append(item)
          else:
              # A block that fits neither column flushes what was collected so far.
              result.extend(left_column)
              result.extend(right_column)
              result.append(item)
              left_column = []
              right_column = []

      result.extend(sorted(left_column, key=lambda item: item["block_bbox"][1]))
      result.extend(sorted(right_column, key=lambda item: item["block_bbox"][1]))
      return result
  def calculate_projection_overlap_ratio(
      bbox1: List[float],
      bbox2: List[float],
      direction: str = "horizontal",
      mode="union",
  ) -> float:
      """
      Compute how much the 1-D projections of two boxes overlap.

      Args:
          bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
          bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
          direction (str): "horizontal" projects onto the x axis (indices 0 and 2);
              any other value projects onto the y axis (indices 1 and 3).
          mode (str): Reference length used as the denominator: "union" (span
              covering both projections), "small" (shorter projection) or
              "large" (longer projection).

      Returns:
          float: Overlap length divided by the reference length. Returns 0 when
          the projections do not overlap or the reference length is not positive.

      Raises:
          ValueError: If the projections overlap and ``mode`` is not one of the
              supported values.
      """
      if direction == "horizontal":
          lo, hi = 0, 2
      else:
          lo, hi = 1, 3

      overlap = min(bbox1[hi], bbox2[hi]) - max(bbox1[lo], bbox2[lo])
      if overlap <= 0:
          return 0

      if mode == "union":
          ref_width = max(bbox1[hi], bbox2[hi]) - min(bbox1[lo], bbox2[lo])
      elif mode == "small":
          ref_width = min(bbox1[hi] - bbox1[lo], bbox2[hi] - bbox2[lo])
      elif mode == "large":
          ref_width = max(bbox1[hi] - bbox1[lo], bbox2[hi] - bbox2[lo])
      else:
          raise ValueError(
              f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
          )

      if ref_width > 0:
          return overlap / ref_width
      return 0.0
  def calculate_overlap_ratio(
      bbox1: Union[list, tuple], bbox2: Union[list, tuple], mode="union"
  ) -> float:
      """
      Compute the area overlap ratio of two axis-aligned bounding boxes.

      Args:
          bbox1 (list or tuple): First box, format [x_min, y_min, x_max, y_max].
          bbox2 (list or tuple): Second box, format [x_min, y_min, x_max, y_max].
          mode (str): Reference area used as the denominator: 'union',
              'small' (smaller box area) or 'large' (larger box area).

      Returns:
          float: Intersection area divided by the reference area, or 0.0 when
          the reference area is zero.

      Raises:
          ValueError: If ``mode`` is not one of the supported values.
      """
      # Negative extents mean the boxes are disjoint on that axis; clamp to 0.
      inter_width = max(0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]))
      inter_height = max(0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]))
      inter_area = inter_width * inter_height

      first_area = caculate_bbox_area(bbox1)
      second_area = caculate_bbox_area(bbox2)

      if mode == "union":
          ref_area = first_area + second_area - inter_area
      elif mode == "small":
          ref_area = min(first_area, second_area)
      elif mode == "large":
          ref_area = max(first_area, second_area)
      else:
          raise ValueError(
              f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
          )

      if ref_area == 0:
          return 0.0
      return inter_area / ref_area
  def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
      """
      Group recognised spans into text lines.

      Args:
          ocr_rec_res: Mapping with the keys "boxes", "rec_texts" and "rec_labels"
              (parallel sequences, one entry per span).
          line_height_iou_threshold: Minimum projection overlap (mode "small")
              between a span and the current line for the span to join that line.

      Returns:
          tuple: ``(lines, text_orientation, mean_line_height)`` where ``lines``
          is a list of lines, each a list of ``[box, text, label]`` spans. The
          mean is taken over all line heights, including lines filtered out of
          ``lines`` for vertical text.
      """
      boxes = ocr_rec_res["boxes"]
      texts = ocr_rec_res["rec_texts"]
      labels = ocr_rec_res["rec_labels"]

      # Orientation is estimated from spans labelled "text" only.
      text_orientation = calculate_text_orientation(
          [box for idx, box in enumerate(boxes) if labels[idx] == "text"]
      )

      if text_orientation == "horizontal":
          match_direction = "vertical"
          line_start_index, line_end_index = 1, 3
      else:
          match_direction = "horizontal"
          line_start_index, line_end_index = 0, 2

      # Vertical text is ordered by x descending, anything else by y ascending.
      if text_orientation == "vertical":
          sort_index, descending = 0, True
      else:
          sort_index, descending = 1, False
      spans = [
          list(item)
          for item in sorted(
              zip(boxes, texts, labels),
              key=lambda item: item[0][sort_index],
              reverse=descending,
          )
      ]

      lines = []
      line_heights = []
      current_line = [spans[0]]
      # Running extent of the current line; copied so the span's own box is not mutated.
      current_region = spans[0][0].copy()

      # merge line
      for span in spans[1:]:
          span_box = span[0]
          overlap = calculate_projection_overlap_ratio(
              current_region, span_box, match_direction, mode="small"
          )
          if overlap >= line_height_iou_threshold:
              current_line.append(span)
              current_region[line_start_index] = min(
                  current_region[line_start_index], span_box[line_start_index]
              )
              current_region[line_end_index] = max(
                  current_region[line_end_index], span_box[line_end_index]
              )
          else:
              line_heights.append(
                  current_region[line_end_index] - current_region[line_start_index]
              )
              lines.append(current_line)
              current_line = [span]
              current_region = span_box.copy()

      # Close the last open line.
      lines.append(current_line)
      line_heights.append(
          current_region[line_end_index] - current_region[line_start_index]
      )

      # line_heights holds at least one entry here, so min/max are always defined.
      min_height = min(line_heights)
      max_height = max(line_heights)

      if max_height > min_height * 2 and text_orientation == "vertical":
          line_heights = np.array(line_heights)
          min_height_num = np.sum(line_heights < min_height * 1.1)
          # When thin lines are a minority (< 40%), drop them.
          if min_height_num < len(lines) * 0.4:
              condition = line_heights > min_height * 1.1
              lines = [value for value, keep in zip(lines, condition) if keep]

      return lines, text_orientation, np.mean(line_heights)
  def calculate_minimum_enclosing_bbox(bboxes):
      """
      Calculate the minimum enclosing bounding box for a collection of boxes.

      Args:
          bboxes (list or np.ndarray): Bounding boxes, each given as
              [x1, y1, x2, y2].

      Returns:
          list: The minimum enclosing bounding box [x1, y1, x2, y2]. The
          elements are numpy scalars.

      Raises:
          ValueError: If ``bboxes`` is None or contains no boxes.
      """
      # A length check is used instead of truthiness so that numpy arrays,
      # whose truth value is ambiguous, are accepted as well as lists.
      if bboxes is None or len(bboxes) == 0:
          raise ValueError("The list of bounding boxes is empty.")

      bboxes_array = np.asarray(bboxes)

      # Column-wise minima of the top-left corners, maxima of the bottom-right ones.
      min_x, min_y = bboxes_array[:, :2].min(axis=0)
      max_x, max_y = bboxes_array[:, 2:4].max(axis=0)

      return [min_x, min_y, max_x, max_y]
  def calculate_text_orientation(
      bboxes: List[List[int]], orientation_ratio: float = 1.5
  ) -> str:
      """
      Calculate the orientation of the text based on the bounding boxes.

      A box votes "horizontal" when ``width * orientation_ratio >= height``.
      The result is "horizontal" when at least half of the boxes vote that way,
      which includes the case of an empty ``bboxes``.

      Args:
          bboxes (list): A list of bounding boxes, each [x1, y1, x2, y2].
          orientation_ratio (float): Ratio for determining orientation. Default is 1.5.

      Returns:
          str: "horizontal" or "vertical".

      Raises:
          ValueError: If any bounding box does not have exactly four elements.
      """
      horizontal_box_num = 0
      for bbox in bboxes:
          if len(bbox) != 4:
              raise ValueError(
                  "Invalid bounding box format. Expected a list of length 4."
              )
          x1, y1, x2, y2 = bbox
          if (x2 - x1) * orientation_ratio >= (y2 - y1):
              horizontal_box_num += 1
      return "horizontal" if horizontal_box_num >= len(bboxes) * 0.5 else "vertical"
  def is_english_letter(char):
      """Return True when ``char`` is a single ASCII letter (A-Z or a-z)."""
      matched = re.match(r"^[A-Za-z]$", char)
      return matched is not None
  def is_numeric(char):
      """Return True when ``char`` consists only of digits and dots (at least one)."""
      matched = re.match(r"^[\d.]+$", char)
      return matched is not None
  def is_non_breaking_punctuation(char):
      """
      Check whether a character is punctuation after which no line break is needed.

      Both the half-width (ASCII) and the full-width forms are recognised.

      Args:
          char (str): A single character.

      Returns:
          bool: True if the character is one of the non-breaking punctuation
          marks, False otherwise.
      """
      # Full-width forms are written as escapes so they cannot be silently
      # normalised to their ASCII look-alikes.
      non_breaking_punctuations = {
          ",",  # half-width comma
          "\uff0c",  # full-width comma
          "\u3001",  # ideographic (enumeration) comma
          ";",  # half-width semicolon
          "\uff1b",  # full-width semicolon
          ":",  # half-width colon
          "\uff1a",  # full-width colon
          "-",  # hyphen
      }
      return char in non_breaking_punctuations
  def format_line(
      line: List[List[Union[List[int], str]]],
      text_direction: str,
      block_width: int,
      block_start_coordinate: int,
      block_stop_coordinate: int,
      line_gap_limit: int = 10,
      block_label: str = "text",
  ) -> Tuple[str, bool]:
      """
      Build the text of one line from its spans and decide whether it ends a paragraph.

      Args:
          line (list): A list of spans, each a list ``[bbox, text, label]``.
              The text of a "formula" span may be rewritten in place (wrapped in ``$``).
          text_direction (str): "horizontal" reads the x coordinates of the span
              boxes (indices 0 and 2); any other value reads the y coordinates
              (indices 1 and 3).
          block_width (int): Size of the enclosing block; 40% of it is the gap
              threshold for inserting newline characters.
          block_start_coordinate (int): Start coordinate of the enclosing block
              along the text direction.
          block_stop_coordinate (int): Stop coordinate of the enclosing block
              along the text direction.
          line_gap_limit (int): Gap (in the units of the coordinates) beyond
              which the line is flagged as needing a new line. Default is 10.
          block_label (str): The label associated with the entire block. Default is 'text'.

      Returns:
          tuple: ``(line_text, need_new_line)`` - the assembled text and a flag
          telling whether a new line should follow it.
      """
      first_span_box = line[0][0]
      last_span_box = line[-1][0]

      # Wrap inline formulas in `$...$` unless the whole block is a formula or
      # the text already starts or ends with `$`. A formula that is alone on
      # its line additionally gets a leading newline.
      for span in line:
          if span[2] == "formula" and block_label != "formula":
              formula_rec = span[1]
              if not formula_rec.startswith("$") and not formula_rec.endswith("$"):
                  if len(line) > 1:
                      span[1] = f"${span[1]}$"
                  else:
                      span[1] = f"\n${span[1]}$"

      # Concatenate span texts; a space follows a span that ends with an
      # English letter and every formula span.
      line_text = ""
      for span in line:
          _, text, label = span
          line_text += text
          # NOTE: `and` binds tighter than `or`, so the formula check applies
          # even when `text` is empty.
          if len(text) > 0 and is_english_letter(line_text[-1]) or label == "formula":
              line_text += " "

      if text_direction == "horizontal":
          text_start_index = 0
          text_stop_index = 2
      else:
          text_start_index = 1
          text_stop_index = 3

      # A line that does not end with an English letter or non-breaking
      # punctuation and leaves a gap larger than `line_gap_limit` to the block
      # edge is flagged as needing a new line.
      need_new_line = False
      if (
          len(line_text) > 0
          and not is_english_letter(line_text[-1])
          and not is_non_breaking_punctuation(line_text[-1])
      ):
          if (
              text_direction == "horizontal"
              and block_stop_coordinate - last_span_box[text_stop_index] > line_gap_limit
          ) or (
              text_direction == "vertical"
              and (
                  block_stop_coordinate - last_span_box[text_stop_index] > line_gap_limit
                  # NOTE(review): index 1 is hard-coded here instead of
                  # text_start_index; both are 1 for vertical text.
                  or first_span_box[1] - block_start_coordinate > line_gap_limit
              )
          ):
              need_new_line = True

      if line_text.endswith("-"):
          # Drop a trailing hyphen (presumably a word split across lines).
          line_text = line_text[:-1]
      elif (
          len(line_text) > 0 and is_english_letter(line_text[-1])
      ) or line_text.endswith("$"):
          line_text += " "
      elif (
          len(line_text) > 0
          and not is_english_letter(line_text[-1])
          and not is_non_breaking_punctuation(line_text[-1])
          and not is_numeric(line_text[-1])
      ) or text_direction == "vertical":
          # Large trailing / leading gaps (> 40% of the block) become explicit newlines.
          if block_stop_coordinate - last_span_box[text_stop_index] > block_width * 0.4:
              line_text += "\n"
          if (
              first_span_box[text_start_index] - block_start_coordinate
              > block_width * 0.4
          ):
              line_text = "\n" + line_text

      return line_text, need_new_line
  def split_boxes_by_projection(spans: List[List[int]], direction, offset=1e-5):
      """
      Split spans whose projection completely contains the projection of a later span.

      Args:
          spans (list of lists): Each element is a list containing an ndarray of length 4, a text string, and a label.
          direction: 'horizontal' compares the x coordinates (indices 0 and 2);
              any other value compares the y coordinates (indices 1 and 3).
          offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.

      Returns:
          A new list of ``[box, text, label]`` spans, including split boxes,
          which reuse the text and label of the span they were cut from.
      """

      def is_projection_contained(box_a, box_b, start_idx, end_idx):
          """Check if box_a completely contains box_b along the chosen axis."""
          return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx]

      new_boxes = []
      if direction == "horizontal":
          projection_start_index, projection_end_index = 0, 2
      else:
          projection_start_index, projection_end_index = 1, 3

      for i in range(len(spans)):
          span = spans[i]
          is_split = False
          # j starts at i, so every span is first compared with itself; that
          # comparison always reports containment and sets is_split.
          for j in range(i, len(spans)):
              box_b = spans[j][0]
              box_a, text, label = span
              if is_projection_contained(
                  box_a, box_b, projection_start_index, projection_end_index
              ):
                  is_split = True
                  # Split box_a based on the coordinates of box_b
                  if box_a[projection_start_index] < box_b[projection_start_index]:
                      # Part of box_a that lies before box_b.
                      w = (
                          box_b[projection_start_index]
                          - offset
                          - box_a[projection_start_index]
                      )
                      if w > 1:
                          new_bbox = box_a.copy()
                          new_bbox[projection_end_index] = (
                              box_b[projection_start_index] - offset
                          )
                          new_boxes.append(
                              [
                                  np.array(new_bbox),
                                  text,
                                  label,
                              ]
                          )
                  if box_a[projection_end_index] > box_b[projection_end_index]:
                      # Part of box_a that lies after box_b: box_a itself is
                      # shrunk in place and carried on as the current span.
                      w = (
                          box_a[projection_end_index]
                          - box_b[projection_end_index]
                          + offset
                      )
                      if w > 1:
                          box_a[projection_start_index] = (
                              box_b[projection_end_index] + offset
                          )
                          span = [
                              np.array(box_a),
                              text,
                              label,
                          ]
              # After the last comparison, emit whatever remains of the span.
              if j == len(spans) - 1 and is_split:
                  new_boxes.append(span)
          # NOTE(review): is_split is always True after the self-comparison
          # above, so this branch looks unreachable - confirm before relying on it.
          if not is_split:
              new_boxes.append(span)

      return new_boxes
  def remove_extra_space(input_text: str) -> str:
      """
      Normalise whitespace in text that mixes Chinese and non-Chinese characters.

      Whitespace between two Chinese characters is removed, whitespace between a
      Chinese and a non-Chinese character becomes a single space, every other
      whitespace run is collapsed to one space, and the result is stripped.

      Args:
          input_text (str): The text to be processed.

      Returns:
          str: The processed text with properly formatted spaces.
      """
      # Step 1: drop whitespace sandwiched between Chinese characters.
      cleaned = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", input_text)
      # Step 2: one space at every Chinese / non-Chinese boundary.
      cleaned = re.sub(
          r"(?<=[\u4e00-\u9fff])\s+(?=[^\u4e00-\u9fff])|(?<=[^\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])",
          " ",
          cleaned,
      )
      # Step 3: collapse what is left and trim both ends.
      return re.sub(r"\s+", " ", cleaned).strip()
  def gather_imgs(original_img, layout_det_objs):
      """
      Crop every detected image-like region out of the page image.

      Args:
          original_img: Page image array indexed as [row, column, channel]; the
              channel axis is reversed before the crop is handed to PIL
              (presumably BGR -> RGB; verify against the caller).
          layout_det_objs: Iterable of detections, each with the keys "label",
              "coordinate" (x_min, y_min, x_max, y_max) and "score".

      Returns:
          list: One dict per detection whose label is listed in
          BLOCK_LABEL_MAP["image_labels"], with the keys "path", "img",
          "coordinate" and "score".
      """
      collected = []
      for det_obj in layout_det_objs:
          label = det_obj["label"]
          if label not in BLOCK_LABEL_MAP["image_labels"]:
              continue
          x_min, y_min, x_max, y_max = (int(value) for value in det_obj["coordinate"])
          crop = original_img[y_min:y_max, x_min:x_max, ::-1]
          collected.append(
              {
                  "path": f"imgs/img_in_{label}_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg",
                  "img": Image.fromarray(crop),
                  "coordinate": (x_min, y_min, x_max, y_max),
                  "score": det_obj["score"],
              }
          )
      return collected
  def _get_minbox_if_overlap_by_ratio(
      bbox1: Union[List[int], Tuple[int, int, int, int]],
      bbox2: Union[List[int], Tuple[int, int, int, int]],
      ratio: float,
      smaller: bool = True,
  ) -> Optional[int]:
      """
      Tell which of two boxes is selected when they overlap by more than ``ratio``.

      The overlap ratio is the intersection area divided by the area of the
      smaller box (mode "small" of ``calculate_overlap_ratio``).

      Args:
          bbox1 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
          bbox2 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
          ratio (float): The overlap ratio threshold (exclusive).
          smaller (bool): If True, select the smaller bounding box; otherwise, select the larger one.

      Returns:
          Optional[int]: 1 if ``bbox1`` is selected, 2 if ``bbox2`` is selected,
          or None if the overlap ratio does not exceed ``ratio``. When both
          areas are equal, 1 is returned.
      """
      area1 = caculate_bbox_area(bbox1)
      area2 = caculate_bbox_area(bbox2)
      overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small")

      if overlap_ratio > ratio:
          first_is_selected = (area1 <= area2) if smaller else (area1 >= area2)
          return 1 if first_is_selected else 2
      return None
  def remove_overlap_blocks(
      blocks: Dict[str, List[Dict]], threshold: float = 0.65, smaller: bool = True
  ) -> Dict[str, List[Dict]]:
      """
      Remove overlapping blocks based on a specified overlap ratio threshold.

      Args:
          blocks (dict): Layout result holding a "boxes" list. Every entry of
              that list is a dict with at least the keys "coordinate"
              ([x_min, y_min, x_max, y_max]) and "label".
          threshold (float): Ratio threshold to determine significant overlap.
          smaller (bool): If True, the smaller block of an overlapping pair is
              removed; otherwise the larger one is removed.

      Returns:
          dict: A deep copy of ``blocks`` whose "boxes" list no longer contains
          the dropped blocks. The input is left untouched.
      """
      dropped_indexes = set()
      blocks = deepcopy(blocks)
      boxes = blocks["boxes"]

      # Iterate over each pair of blocks to find overlaps
      for i, block1 in enumerate(boxes):
          for j in range(i + 1, len(boxes)):
              # Skip blocks that are already marked for removal
              if i in dropped_indexes or j in dropped_indexes:
                  continue
              block2 = boxes[j]
              # Check for overlap and determine which block to remove
              overlap_box_index = _get_minbox_if_overlap_by_ratio(
                  block1["coordinate"],
                  block2["coordinate"],
                  threshold,
                  smaller=smaller,
              )
              if overlap_box_index is None:
                  continue
              is_block1_image = block1["label"] == "image"
              is_block2_image = block2["label"] == "image"
              if is_block1_image != is_block2_image:
                  # Exactly one of the two blocks is an image: drop the image block.
                  drop_index = i if is_block1_image else j
              else:
                  # Both or neither are images: let overlap_box_index decide.
                  drop_index = i if overlap_box_index == 1 else j
              dropped_indexes.add(drop_index)

      # Delete from the back so that the remaining indexes stay valid.
      for index in sorted(dropped_indexes, reverse=True):
          del boxes[index]

      return blocks
  def get_bbox_intersection(bbox1, bbox2, return_format="bbox"):
      """
      Compute the rectangular intersection of two boxes.

      Args:
          bbox1: The first box. A 1-D sequence is used as a rectangle
              (x_min, y_min, x_max, y_max); anything with more dimensions is
              first converted to a rectangle with ``convert_points_to_boxes``.
          bbox2: The second box, handled the same way as ``bbox1``.
          return_format (str): The format of the output intersection, either 'bbox' or 'poly'.

      Returns:
          np.ndarray or None: For 'bbox', an array [x_min, y_min, x_max, y_max];
          for 'poly', a 4x2 int16 array of corner points starting at the
          top-left corner; None if the boxes do not intersect.

      Raises:
          ValueError: If the boxes intersect and ``return_format`` is neither
              'bbox' nor 'poly'.
      """
      first = np.array(bbox1)
      second = np.array(bbox2)

      # Reduce polygon-like inputs to axis-aligned rectangles.
      rect1 = first if first.ndim == 1 else convert_points_to_boxes([first])[0]
      rect2 = second if second.ndim == 1 else convert_points_to_boxes([second])[0]

      left = max(rect1[0], rect2[0])
      top = max(rect1[1], rect2[1])
      right = min(rect1[2], rect2[2])
      bottom = min(rect1[3], rect2[3])

      # Touching edges count as no intersection.
      if left >= right or top >= bottom:
          return None

      if return_format == "bbox":
          return np.array([left, top, right, bottom])
      if return_format == "poly":
          corners = [
              [left, top],
              [right, top],
              [right, bottom],
              [left, bottom],
          ]
          return np.array(corners, dtype=np.int16)
      raise ValueError("return_format must be either 'bbox' or 'poly'.")
  def shrink_supplement_region_bbox(
      supplement_region_bbox,
      ref_region_bbox,
      image_width,
      image_height,
      block_idxes_set,
      block_bboxes,
  ) -> Tuple[List, List]:
      """
      Shrink the supplement region bbox according to the reference region bbox and match the block bboxes.

      Args:
          supplement_region_bbox (list): The supplement region bbox.
          ref_region_bbox (list): The reference region bbox.
          image_width (int): The width of the image.
          image_height (int): The height of the image.
          block_idxes_set (set): The indexes of the blocks that intersect with the region bbox.
          block_bboxes (dict): The block bboxes, looked up by block index.

      Returns:
          tuple: The new region bbox and the matched block idxes. When
          ``block_idxes_set`` is empty the input bbox and an empty list are returned.
      """
      x1, y1, x2, y2 = supplement_region_bbox
      x1_prime, y1_prime, x2_prime, y2_prime = ref_region_bbox
      # Maps an edge index to the opposite edge (x_min <-> x_max, y_min <-> y_max).
      index_conversion_map = {0: 2, 1: 3, 2: 0, 3: 1}
      # Per-edge distance between the two regions, normalised by the image size.
      edge_distance_list = [
          (x1_prime - x1) / image_width,
          (y1_prime - y1) / image_height,
          (x2 - x2_prime) / image_width,
          (y2 - y2_prime) / image_height,
      ]
      edge_distance_list_tmp = edge_distance_list[:]
      min_distance = min(edge_distance_list)
      src_index = index_conversion_map[edge_distance_list.index(min_distance)]
      if len(block_idxes_set) == 0:
          return supplement_region_bbox, []
      # Try up to three candidate edges, starting from the smallest distance.
      for _ in range(3):
          dst_index = index_conversion_map[src_index]
          tmp_region_bbox = supplement_region_bbox[:]
          # Move one edge of the candidate region onto the opposite edge of the reference region.
          tmp_region_bbox[dst_index] = ref_region_bbox[src_index]
          iner_block_idxes, split_block_idxes = [], []
          # Classify blocks by their overlap with the candidate region: mostly
          # inside ("iner") or partially inside ("split").
          for block_idx in block_idxes_set:
              overlap_ratio = calculate_overlap_ratio(
                  tmp_region_bbox, block_bboxes[block_idx], mode="small"
              )
              if overlap_ratio > REGION_SETTINGS.get(
                  "match_block_overlap_ratio_threshold", 0.8
              ):
                  iner_block_idxes.append(block_idx)
              elif overlap_ratio > REGION_SETTINGS.get(
                  "split_block_overlap_ratio_threshold", 0.4
              ):
                  split_block_idxes.append(block_idx)

          if len(iner_block_idxes) > 0:
              if len(split_block_idxes) > 0:
                  for split_block_idx in split_block_idxes:
                      split_block_bbox = block_bboxes[split_block_idx]
                      x1, y1, x2, y2 = tmp_region_bbox
                      x1_prime, y1_prime, x2_prime, y2_prime = split_block_bbox
                      edge_distance_list = [
                          (x1_prime - x1) / image_width,
                          (y1_prime - y1) / image_height,
                          (x2 - x2_prime) / image_width,
                          (y2 - y2_prime) / image_height,
                      ]
                      # Cut the candidate region at the split block's edge with the largest distance.
                      max_distance = max(edge_distance_list)
                      src_index = edge_distance_list.index(max_distance)
                      dst_index = index_conversion_map[src_index]
                      tmp_region_bbox[dst_index] = split_block_bbox[src_index]
                      # Recurse with only the fully matched blocks.
                      tmp_region_bbox, iner_idxes = shrink_supplement_region_bbox(
                          tmp_region_bbox,
                          ref_region_bbox,
                          image_width,
                          image_height,
                          iner_block_idxes,
                          block_bboxes,
                      )
                      if len(iner_idxes) == 0:
                          continue
              # The final region is the tight box around all matched blocks.
              matched_bboxes = [block_bboxes[idx] for idx in iner_block_idxes]
              supplement_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes)
              break
          else:
              # Nothing matched: discard the current minimum and retry with the next-smallest distance.
              # NOTE(review): min() raises ValueError if this filter empties the list - confirm inputs.
              edge_distance_list_tmp = [
                  x for x in edge_distance_list_tmp if x != min_distance
              ]
              min_distance = min(edge_distance_list_tmp)
              src_index = index_conversion_map[edge_distance_list.index(min_distance)]
      return supplement_region_bbox, iner_block_idxes
  734. def update_region_box(bbox, region_box):
  735. if region_box is None:
  736. return bbox
  737. x1, y1, x2, y2 = bbox
  738. x1_region, y1_region, x2_region, y2_region = region_box
  739. x1_region = int(min(x1, x1_region))
  740. y1_region = int(min(y1, y1_region))
  741. x2_region = int(max(x2, x2_region))
  742. y2_region = int(max(y2, y2_region))
  743. region_box = [x1_region, y1_region, x2_region, y2_region]
  744. return region_box
  745. def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
  746. for formula_res in formula_res_list:
  747. x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
  748. poly_points = [
  749. (x_min, y_min),
  750. (x_max, y_min),
  751. (x_max, y_max),
  752. (x_min, y_max),
  753. ]
  754. ocr_res["dt_polys"].append(poly_points)
  755. formula_res_text: str = formula_res["rec_formula"]
  756. ocr_res["rec_texts"].append(formula_res_text)
  757. if ocr_res["rec_boxes"].size == 0:
  758. ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"])
  759. else:
  760. ocr_res["rec_boxes"] = np.vstack(
  761. (ocr_res["rec_boxes"], [formula_res["dt_polys"]])
  762. )
  763. ocr_res["rec_labels"].append("formula")
  764. ocr_res["rec_polys"].append(poly_points)
  765. ocr_res["rec_scores"].append(1)
  766. def caculate_bbox_area(bbox):
  767. x1, y1, x2, y2 = map(float, bbox)
  768. area = abs((x2 - x1) * (y2 - y1))
  769. return area
  770. def get_show_color(label: str, order_label=False) -> Tuple:
  771. if order_label:
  772. label_colors = {
  773. "doc_title": (255, 248, 220, 100), # Cornsilk
  774. "doc_title_text": (255, 239, 213, 100),
  775. "paragraph_title": (102, 102, 255, 100),
  776. "sub_paragraph_title": (102, 178, 255, 100),
  777. "vision": (153, 255, 51, 100),
  778. "vision_title": (144, 238, 144, 100), # Light Green
  779. "vision_footnote": (144, 238, 144, 100), # Light Green
  780. "normal_text": (153, 0, 76, 100),
  781. "cross_layout": (53, 218, 207, 100), # Thistle
  782. "cross_reference": (221, 160, 221, 100), # Floral White
  783. }
  784. else:
  785. label_colors = {
  786. # Medium Blue (from 'titles_list')
  787. "paragraph_title": (102, 102, 255, 100),
  788. "doc_title": (255, 248, 220, 100), # Cornsilk
  789. # Light Yellow (from 'tables_caption_list')
  790. "table_title": (255, 255, 102, 100),
  791. # Sky Blue (from 'imgs_caption_list')
  792. "figure_title": (102, 178, 255, 100),
  793. "chart_title": (221, 160, 221, 100), # Plum
  794. "vision_footnote": (144, 238, 144, 100), # Light Green
  795. # Deep Purple (from 'texts_list')
  796. "text": (153, 0, 76, 100),
  797. # Bright Green (from 'interequations_list')
  798. "formula": (0, 255, 0, 100),
  799. "abstract": (255, 239, 213, 100), # Papaya Whip
  800. # Medium Green (from 'lists_list' and 'indexs_list')
  801. "content": (40, 169, 92, 100),
  802. # Neutral Gray (from 'dropped_bbox_list')
  803. "seal": (158, 158, 158, 100),
  804. # Olive Yellow (from 'tables_body_list')
  805. "table": (204, 204, 0, 100),
  806. # Bright Green (from 'imgs_body_list')
  807. "image": (153, 255, 51, 100),
  808. # Bright Green (from 'imgs_body_list')
  809. "figure": (153, 255, 51, 100),
  810. "chart": (216, 191, 216, 100), # Thistle
  811. # Pale Yellow-Green (from 'tables_footnote_list')
  812. "reference": (229, 255, 204, 100),
  813. # "reference_content": (229, 255, 204, 100),
  814. "algorithm": (255, 250, 240, 100), # Floral White
  815. }
  816. default_color = (158, 158, 158, 100)
  817. return label_colors.get(label, default_color)