utils.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Names exported by a wildcard import (`from <module> import *`) of this module.
__all__ = [
    "get_sub_regions_ocr_res",
    "get_show_color",
    "sorted_layout_boxes",
]
  19. import re
  20. from copy import deepcopy
  21. from typing import Dict, List, Optional, Tuple, Union
  22. import numpy as np
  23. from PIL import Image
  24. from ..components import convert_points_to_boxes
  25. from ..ocr.result import OCRResult
  26. from .setting import BLOCK_LABEL_MAP, REGION_SETTINGS
  27. def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
  28. """
  29. Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
  30. Args:
  31. src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
  32. ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
  33. Returns:
  34. match_idx_list (list): A list of indices of source boxes that overlap with reference boxes.
  35. """
  36. match_idx_list = []
  37. src_boxes_num = len(src_boxes)
  38. if src_boxes_num > 0 and len(ref_boxes) > 0:
  39. for rno in range(len(ref_boxes)):
  40. ref_box = ref_boxes[rno]
  41. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  42. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  43. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  44. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  45. pub_w = x2 - x1
  46. pub_h = y2 - y1
  47. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  48. match_idx_list.extend(match_idx)
  49. return match_idx_list
  50. def get_sub_regions_ocr_res(
  51. overall_ocr_res: OCRResult,
  52. object_boxes: List,
  53. flag_within: bool = True,
  54. return_match_idx: bool = False,
  55. ) -> OCRResult:
  56. """
  57. Filters OCR results to only include text boxes within specified object boxes based on a flag.
  58. Args:
  59. overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
  60. object_boxes (list): A list of bounding boxes for the objects of interest.
  61. flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
  62. return_match_idx (bool): If True, return the list of matching indices.
  63. Returns:
  64. OCRResult: A filtered OCR result containing only the relevant text boxes.
  65. """
  66. sub_regions_ocr_res = {}
  67. sub_regions_ocr_res["rec_polys"] = []
  68. sub_regions_ocr_res["rec_texts"] = []
  69. sub_regions_ocr_res["rec_scores"] = []
  70. sub_regions_ocr_res["rec_boxes"] = []
  71. overall_text_boxes = overall_ocr_res["rec_boxes"]
  72. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  73. match_idx_list = list(set(match_idx_list))
  74. for box_no in range(len(overall_text_boxes)):
  75. if flag_within:
  76. if box_no in match_idx_list:
  77. flag_match = True
  78. else:
  79. flag_match = False
  80. else:
  81. if box_no not in match_idx_list:
  82. flag_match = True
  83. else:
  84. flag_match = False
  85. if flag_match:
  86. sub_regions_ocr_res["rec_polys"].append(
  87. overall_ocr_res["rec_polys"][box_no]
  88. )
  89. sub_regions_ocr_res["rec_texts"].append(
  90. overall_ocr_res["rec_texts"][box_no]
  91. )
  92. sub_regions_ocr_res["rec_scores"].append(
  93. overall_ocr_res["rec_scores"][box_no]
  94. )
  95. sub_regions_ocr_res["rec_boxes"].append(
  96. overall_ocr_res["rec_boxes"][box_no]
  97. )
  98. for key in ["rec_polys", "rec_scores", "rec_boxes"]:
  99. sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])
  100. return (
  101. (sub_regions_ocr_res, match_idx_list)
  102. if return_match_idx
  103. else sub_regions_ocr_res
  104. )
  105. def sorted_layout_boxes(res, w):
  106. """
  107. Sort text boxes in order from top to bottom, left to right
  108. Args:
  109. res: List of dictionaries containing layout information.
  110. w: Width of image.
  111. Returns:
  112. List of dictionaries containing sorted layout information.
  113. """
  114. num_boxes = len(res)
  115. if num_boxes == 1:
  116. return res
  117. # Sort on the y axis first or sort it on the x axis
  118. sorted_boxes = sorted(res, key=lambda x: (x["block_bbox"][1], x["block_bbox"][0]))
  119. _boxes = list(sorted_boxes)
  120. new_res = []
  121. res_left = []
  122. res_right = []
  123. i = 0
  124. while True:
  125. if i >= num_boxes:
  126. break
  127. # Check that the bbox is on the left
  128. elif (
  129. _boxes[i]["block_bbox"][0] < w / 4
  130. and _boxes[i]["block_bbox"][2] < 3 * w / 5
  131. ):
  132. res_left.append(_boxes[i])
  133. i += 1
  134. elif _boxes[i]["block_bbox"][0] > 2 * w / 5:
  135. res_right.append(_boxes[i])
  136. i += 1
  137. else:
  138. new_res += res_left
  139. new_res += res_right
  140. new_res.append(_boxes[i])
  141. res_left = []
  142. res_right = []
  143. i += 1
  144. res_left = sorted(res_left, key=lambda x: (x["block_bbox"][1]))
  145. res_right = sorted(res_right, key=lambda x: (x["block_bbox"][1]))
  146. if res_left:
  147. new_res += res_left
  148. if res_right:
  149. new_res += res_right
  150. return new_res
  151. def calculate_projection_overlap_ratio(
  152. bbox1: List[float],
  153. bbox2: List[float],
  154. direction: str = "horizontal",
  155. mode="union",
  156. ) -> float:
  157. """
  158. Calculate the IoU of lines between two bounding boxes.
  159. Args:
  160. bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
  161. bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
  162. direction (str): direction of the projection, "horizontal" or "vertical".
  163. Returns:
  164. float: Line overlap ratio. Returns 0 if there is no overlap.
  165. """
  166. start_index, end_index = 1, 3
  167. if direction == "horizontal":
  168. start_index, end_index = 0, 2
  169. intersection_start = max(bbox1[start_index], bbox2[start_index])
  170. intersection_end = min(bbox1[end_index], bbox2[end_index])
  171. overlap = intersection_end - intersection_start
  172. if overlap <= 0:
  173. return 0
  174. if mode == "union":
  175. ref_width = max(bbox1[end_index], bbox2[end_index]) - min(
  176. bbox1[start_index], bbox2[start_index]
  177. )
  178. elif mode == "small":
  179. ref_width = min(
  180. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  181. )
  182. elif mode == "large":
  183. ref_width = max(
  184. bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
  185. )
  186. else:
  187. raise ValueError(
  188. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  189. )
  190. return overlap / ref_width if ref_width > 0 else 0.0
  191. def calculate_overlap_ratio(
  192. bbox1: Union[np.ndarray, list, tuple],
  193. bbox2: Union[np.ndarray, list, tuple],
  194. mode="union",
  195. ) -> float:
  196. """
  197. Calculate the overlap ratio between two bounding boxes using NumPy.
  198. Args:
  199. bbox1 (np.ndarray, list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max]
  200. bbox2 (np.ndarray, list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max]
  201. mode (str): The mode of calculation, either 'union', 'small', or 'large'.
  202. Returns:
  203. float: The overlap ratio value between the two bounding boxes
  204. """
  205. bbox1 = np.array(bbox1)
  206. bbox2 = np.array(bbox2)
  207. x_min_inter = np.maximum(bbox1[0], bbox2[0])
  208. y_min_inter = np.maximum(bbox1[1], bbox2[1])
  209. x_max_inter = np.minimum(bbox1[2], bbox2[2])
  210. y_max_inter = np.minimum(bbox1[3], bbox2[3])
  211. inter_width = np.maximum(0, x_max_inter - x_min_inter)
  212. inter_height = np.maximum(0, y_max_inter - y_min_inter)
  213. inter_area = inter_width * inter_height
  214. bbox1_area = calculate_bbox_area(bbox1)
  215. bbox2_area = calculate_bbox_area(bbox2)
  216. if mode == "union":
  217. ref_area = bbox1_area + bbox2_area - inter_area
  218. elif mode == "small":
  219. ref_area = np.minimum(bbox1_area, bbox2_area)
  220. elif mode == "large":
  221. ref_area = np.maximum(bbox1_area, bbox2_area)
  222. else:
  223. raise ValueError(
  224. f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
  225. )
  226. if ref_area == 0:
  227. return 0.0
  228. return inter_area / ref_area
  229. def calculate_minimum_enclosing_bbox(bboxes):
  230. """
  231. Calculate the minimum enclosing bounding box for a list of bounding boxes.
  232. Args:
  233. bboxes (list): A list of bounding boxes represented as lists of four integers [x1, y1, x2, y2].
  234. Returns:
  235. list: The minimum enclosing bounding box represented as a list of four integers [x1, y1, x2, y2].
  236. """
  237. if not bboxes:
  238. raise ValueError("The list of bounding boxes is empty.")
  239. # Convert the list of bounding boxes to a NumPy array
  240. bboxes_array = np.array(bboxes)
  241. # Compute the minimum and maximum values along the respective axes
  242. min_x = np.min(bboxes_array[:, 0])
  243. min_y = np.min(bboxes_array[:, 1])
  244. max_x = np.max(bboxes_array[:, 2])
  245. max_y = np.max(bboxes_array[:, 3])
  246. # Return the minimum enclosing bounding box
  247. return np.array([min_x, min_y, max_x, max_y])
  248. def is_english_letter(char):
  249. """check if the char is english letter"""
  250. return bool(re.match(r"^[A-Za-z]$", char))
  251. def is_numeric(char):
  252. """check if the char is numeric"""
  253. return bool(re.match(r"^[\d]+$", char))
  254. def is_non_breaking_punctuation(char):
  255. """
  256. check if the char is non-breaking punctuation
  257. Args:
  258. char (str): character to check
  259. Returns:
  260. bool: True if the char is non-breaking punctuation
  261. """
  262. non_breaking_punctuations = {
  263. ",",
  264. ",",
  265. "、",
  266. ";",
  267. ";",
  268. ":",
  269. ":",
  270. "-",
  271. "'",
  272. '"',
  273. "“",
  274. }
  275. return char in non_breaking_punctuations
  276. def gather_imgs(original_img, layout_det_objs):
  277. imgs_in_doc = []
  278. for det_obj in layout_det_objs:
  279. if det_obj["label"] in BLOCK_LABEL_MAP["image_labels"]:
  280. label = det_obj["label"]
  281. x_min, y_min, x_max, y_max = list(map(int, det_obj["coordinate"]))
  282. img_path = f"imgs/img_in_{label}_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg"
  283. img = Image.fromarray(original_img[y_min:y_max, x_min:x_max, ::-1])
  284. imgs_in_doc.append(
  285. {
  286. "path": img_path,
  287. "img": img,
  288. "coordinate": (x_min, y_min, x_max, y_max),
  289. "score": det_obj["score"],
  290. }
  291. )
  292. return imgs_in_doc
  293. def _get_minbox_if_overlap_by_ratio(
  294. bbox1: Union[List[int], Tuple[int, int, int, int]],
  295. bbox2: Union[List[int], Tuple[int, int, int, int]],
  296. ratio: float,
  297. smaller: bool = True,
  298. ) -> Optional[Union[List[int], Tuple[int, int, int, int]]]:
  299. """
  300. Determine if the overlap area between two bounding boxes exceeds a given ratio
  301. and return the smaller (or larger) bounding box based on the `smaller` flag.
  302. Args:
  303. bbox1 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
  304. bbox2 (Union[List[int], Tuple[int, int, int, int]]): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
  305. ratio (float): The overlap ratio threshold.
  306. smaller (bool): If True, return the smaller bounding box; otherwise, return the larger one.
  307. Returns:
  308. Optional[Union[List[int], Tuple[int, int, int, int]]]:
  309. The selected bounding box or None if the overlap ratio is not exceeded.
  310. """
  311. # Calculate the areas of both bounding boxes
  312. area1 = calculate_bbox_area(bbox1)
  313. area2 = calculate_bbox_area(bbox2)
  314. # Calculate the overlap ratio using a helper function
  315. overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small")
  316. # Check if the overlap ratio exceeds the threshold
  317. if overlap_ratio > ratio:
  318. if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
  319. return 1
  320. else:
  321. return 2
  322. return None
  323. def remove_overlap_blocks(
  324. blocks: List[Dict[str, List[int]]], threshold: float = 0.65, smaller: bool = True
  325. ) -> Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  326. """
  327. Remove overlapping blocks based on a specified overlap ratio threshold.
  328. Args:
  329. blocks (List[Dict[str, List[int]]]): List of block dictionaries, each containing a 'block_bbox' key.
  330. threshold (float): Ratio threshold to determine significant overlap.
  331. smaller (bool): If True, the smaller block in overlap is removed.
  332. Returns:
  333. Tuple[List[Dict[str, List[int]]], List[Dict[str, List[int]]]]:
  334. A tuple containing the updated list of blocks and a list of dropped blocks.
  335. """
  336. dropped_indexes = set()
  337. blocks = deepcopy(blocks)
  338. overlap_image_blocks = []
  339. # Iterate over each pair of blocks to find overlaps
  340. for i, block1 in enumerate(blocks["boxes"]):
  341. for j in range(i + 1, len(blocks["boxes"])):
  342. block2 = blocks["boxes"][j]
  343. # Skip blocks that are already marked for removal
  344. if i in dropped_indexes or j in dropped_indexes:
  345. continue
  346. # Check for overlap and determine which block to remove
  347. overlap_box_index = _get_minbox_if_overlap_by_ratio(
  348. block1["coordinate"],
  349. block2["coordinate"],
  350. threshold,
  351. smaller=smaller,
  352. )
  353. if overlap_box_index is not None:
  354. is_block1_image = block1["label"] == "image"
  355. is_block2_image = block2["label"] == "image"
  356. if is_block1_image != is_block2_image:
  357. drop_index = i if is_block1_image else j
  358. overlap_image_blocks.append(blocks["boxes"][drop_index])
  359. else:
  360. drop_index = i if overlap_box_index == 1 else j
  361. dropped_indexes.add(drop_index)
  362. # Remove marked blocks from the original list
  363. for index in sorted(dropped_indexes, reverse=True):
  364. del blocks["boxes"][index]
  365. return blocks
  366. def get_bbox_intersection(bbox1, bbox2, return_format="bbox"):
  367. """
  368. Compute the intersection of two bounding boxes, supporting both 4-coordinate and 8-coordinate formats.
  369. Args:
  370. bbox1 (tuple): The first bounding box, either in 4-coordinate format (x_min, y_min, x_max, y_max)
  371. or 8-coordinate format (x1, y1, x2, y2, x3, y3, x4, y4).
  372. bbox2 (tuple): The second bounding box in the same format as bbox1.
  373. return_format (str): The format of the output intersection, either 'bbox' or 'poly'.
  374. Returns:
  375. tuple or None: The intersection bounding box in the specified format, or None if there is no intersection.
  376. """
  377. bbox1 = np.array(bbox1)
  378. bbox2 = np.array(bbox2)
  379. # Convert both bounding boxes to rectangles
  380. rect1 = bbox1 if len(bbox1.shape) == 1 else convert_points_to_boxes([bbox1])[0]
  381. rect2 = bbox2 if len(bbox2.shape) == 1 else convert_points_to_boxes([bbox2])[0]
  382. # Calculate the intersection rectangle
  383. x_min_inter = max(rect1[0], rect2[0])
  384. y_min_inter = max(rect1[1], rect2[1])
  385. x_max_inter = min(rect1[2], rect2[2])
  386. y_max_inter = min(rect1[3], rect2[3])
  387. # Check if there is an intersection
  388. if x_min_inter >= x_max_inter or y_min_inter >= y_max_inter:
  389. return None
  390. if return_format == "bbox":
  391. return np.array([x_min_inter, y_min_inter, x_max_inter, y_max_inter])
  392. elif return_format == "poly":
  393. return np.array(
  394. [
  395. [x_min_inter, y_min_inter],
  396. [x_max_inter, y_min_inter],
  397. [x_max_inter, y_max_inter],
  398. [x_min_inter, y_max_inter],
  399. ],
  400. dtype=np.int16,
  401. )
  402. else:
  403. raise ValueError("return_format must be either 'bbox' or 'poly'.")
def shrink_supplement_region_bbox(
    supplement_region_bbox,
    ref_region_bbox,
    image_width,
    image_height,
    block_idxes_set,
    block_bboxes,
) -> List:
    """
    Shrink the supplement region bbox according to the reference region bbox and match the block bboxes.

    Up to three candidate shrinks are tried. Each candidate replaces one edge of
    the supplement bbox with the opposite edge of the reference bbox, then the
    candidate blocks are classified by their overlap with the shrunken bbox
    (overlap measured relative to the smaller of the two boxes). The first
    candidate that fully contains at least one block wins, and the result is
    the minimum enclosing bbox of those blocks.

    Args:
        supplement_region_bbox (list): The supplement region bbox [x1, y1, x2, y2].
        ref_region_bbox (list): The reference region bbox [x1, y1, x2, y2].
        image_width (int): The width of the image.
        image_height (int): The height of the image.
        block_idxes_set (set): The indexes of the blocks that intersect with the region bbox.
        block_bboxes (dict): The dictionary of block bboxes.
    Returns:
        tuple: `(region_bbox, matched_block_idxes)`. When `block_idxes_set` is
            empty, or no candidate contains a block, the input
            `supplement_region_bbox` is returned together with an empty list.
            NOTE(review): the `-> List` annotation does not reflect the
            2-tuple that is actually returned.
    """
    x1, y1, x2, y2 = supplement_region_bbox
    x1_prime, y1_prime, x2_prime, y2_prime = ref_region_bbox
    # Maps an edge index to the index of the opposite edge (x1 <-> x2, y1 <-> y2).
    index_conversion_map = {0: 2, 1: 3, 2: 0, 3: 1}
    # Per-edge offsets between the reference and supplement bboxes, normalised
    # by the image size so the horizontal and vertical values are comparable.
    edge_distance_list = [
        (x1_prime - x1) / image_width,
        (y1_prime - y1) / image_height,
        (x2 - x2_prime) / image_width,
        (y2 - y2_prime) / image_height,
    ]
    # Working copy from which already-tried distances are removed.
    edge_distance_list_tmp = deepcopy(edge_distance_list)
    min_distance = min(edge_distance_list)
    # Start with the reference edge opposite to the edge with the smallest offset.
    src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    if len(block_idxes_set) == 0:
        return supplement_region_bbox, []
    for _ in range(3):
        dst_index = index_conversion_map[src_index]
        # NOTE(review): `[:]` copies a list but yields a view for a numpy array,
        # in which case the assignment below would also modify the input -
        # confirm which type callers pass.
        tmp_region_bbox = supplement_region_bbox[:]
        tmp_region_bbox[dst_index] = ref_region_bbox[src_index]
        iner_block_idxes, split_block_idxes = [], []
        for block_idx in block_idxes_set:
            overlap_ratio = calculate_overlap_ratio(
                tmp_region_bbox, block_bboxes[block_idx], mode="small"
            )
            # Above the match threshold the block is treated as inside the
            # candidate region; between the two thresholds it is treated as
            # split by the candidate's border.
            if overlap_ratio > REGION_SETTINGS.get(
                "match_block_overlap_ratio_threshold", 0.8
            ):
                iner_block_idxes.append(block_idx)
            elif overlap_ratio > REGION_SETTINGS.get(
                "split_block_overlap_ratio_threshold", 0.4
            ):
                split_block_idxes.append(block_idx)
        if len(iner_block_idxes) > 0:
            if len(split_block_idxes) > 0:
                for split_block_idx in split_block_idxes:
                    split_block_bbox = block_bboxes[split_block_idx]
                    x1, y1, x2, y2 = tmp_region_bbox
                    x1_prime, y1_prime, x2_prime, y2_prime = split_block_bbox
                    edge_distance_list = [
                        (x1_prime - x1) / image_width,
                        (y1_prime - y1) / image_height,
                        (x2 - x2_prime) / image_width,
                        (y2 - y2_prime) / image_height,
                    ]
                    # Move the candidate edge opposite to the largest offset
                    # onto the split block's edge, then recurse with only the
                    # inner blocks as candidates.
                    max_distance = max(edge_distance_list)
                    src_index = edge_distance_list.index(max_distance)
                    dst_index = index_conversion_map[src_index]
                    tmp_region_bbox[dst_index] = split_block_bbox[src_index]
                    tmp_region_bbox, iner_idxes = shrink_supplement_region_bbox(
                        tmp_region_bbox,
                        ref_region_bbox,
                        image_width,
                        image_height,
                        iner_block_idxes,
                        block_bboxes,
                    )
                    # This is the last statement of the loop body, so the
                    # `continue` does not skip anything.
                    if len(iner_idxes) == 0:
                        continue
            # The result is the tight bbox around the inner blocks found above.
            matched_bboxes = [block_bboxes[idx] for idx in iner_block_idxes]
            supplement_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes)
            break
        else:
            # No block fits this candidate: retry with the next-smallest offset.
            # `list.index` returns the first occurrence, so equal offsets
            # resolve to the same edge again.
            edge_distance_list_tmp.remove(min_distance)
            min_distance = min(edge_distance_list_tmp)
            src_index = index_conversion_map[edge_distance_list.index(min_distance)]
    return supplement_region_bbox, iner_block_idxes
  489. def update_region_box(bbox, region_box):
  490. """Update region box with bbox"""
  491. if region_box is None:
  492. return bbox
  493. x1, y1, x2, y2 = bbox
  494. x1_region, y1_region, x2_region, y2_region = region_box
  495. x1_region = int(min(x1, x1_region))
  496. y1_region = int(min(y1, y1_region))
  497. x2_region = int(max(x2, x2_region))
  498. y2_region = int(max(y2, y2_region))
  499. region_box = [x1_region, y1_region, x2_region, y2_region]
  500. return region_box
  501. def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
  502. """Convert formula result to OCR result format
  503. Args:
  504. formula_res_list (List): Formula results
  505. ocr_res (dict): OCR result
  506. Returns:
  507. ocr_res (dict): Updated OCR result
  508. """
  509. for formula_res in formula_res_list:
  510. x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
  511. poly_points = [
  512. (x_min, y_min),
  513. (x_max, y_min),
  514. (x_max, y_max),
  515. (x_min, y_max),
  516. ]
  517. ocr_res["dt_polys"].append(poly_points)
  518. formula_res_text: str = formula_res["rec_formula"]
  519. ocr_res["rec_texts"].append(formula_res_text)
  520. if ocr_res["rec_boxes"].size == 0:
  521. ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"])
  522. else:
  523. ocr_res["rec_boxes"] = np.vstack(
  524. (ocr_res["rec_boxes"], [formula_res["dt_polys"]])
  525. )
  526. ocr_res["rec_labels"].append("formula")
  527. ocr_res["rec_polys"].append(poly_points)
  528. ocr_res["rec_scores"].append(1)
  529. def calculate_bbox_area(bbox):
  530. """Calculate bounding box area"""
  531. x1, y1, x2, y2 = map(float, bbox)
  532. area = abs((x2 - x1) * (y2 - y1))
  533. return area
  534. def caculate_euclidean_dist(point1, point2):
  535. """Calculate euclidean distance between two points"""
  536. x1, y1 = point1
  537. x2, y2 = point2
  538. return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
  539. def get_seg_flag(block, prev_block):
  540. """Get segment start flag and end flag based on previous block
  541. Args:
  542. block (Block): Current block
  543. prev_block (Block): Previous block
  544. Returns:
  545. seg_start_flag (bool): Segment start flag
  546. seg_end_flag (bool): Segment end flag
  547. """
  548. seg_start_flag = True
  549. seg_end_flag = True
  550. context_left_coordinate = block.start_coordinate
  551. context_right_coordinate = block.end_coordinate
  552. seg_start_coordinate = block.seg_start_coordinate
  553. seg_end_coordinate = block.seg_end_coordinate
  554. if prev_block is not None:
  555. num_of_prev_lines = prev_block.num_of_lines
  556. pre_block_seg_end_coordinate = prev_block.seg_end_coordinate
  557. prev_end_space_small = (
  558. abs(prev_block.end_coordinate - pre_block_seg_end_coordinate) < 10
  559. )
  560. prev_lines_more_than_one = num_of_prev_lines > 1
  561. overlap_blocks = (
  562. context_left_coordinate < prev_block.end_coordinate
  563. and context_right_coordinate > prev_block.start_coordinate
  564. )
  565. # update context_left_coordinate and context_right_coordinate
  566. if overlap_blocks:
  567. context_left_coordinate = min(
  568. prev_block.start_coordinate, context_left_coordinate
  569. )
  570. context_right_coordinate = max(
  571. prev_block.end_coordinate, context_right_coordinate
  572. )
  573. prev_end_space_small = (
  574. abs(context_right_coordinate - pre_block_seg_end_coordinate) < 10
  575. )
  576. edge_distance = 0
  577. else:
  578. edge_distance = abs(block.start_coordinate - prev_block.end_coordinate)
  579. current_start_space_small = seg_start_coordinate - context_left_coordinate < 10
  580. if (
  581. prev_end_space_small
  582. and current_start_space_small
  583. and prev_lines_more_than_one
  584. and edge_distance < max(prev_block.width, block.width)
  585. ):
  586. seg_start_flag = False
  587. else:
  588. if seg_start_coordinate - context_left_coordinate < 10:
  589. seg_start_flag = False
  590. if context_right_coordinate - seg_end_coordinate < 10:
  591. seg_end_flag = False
  592. return seg_start_flag, seg_end_flag
  593. def get_show_color(label: str, order_label=False) -> Tuple:
  594. if order_label:
  595. label_colors = {
  596. "doc_title": (255, 248, 220, 100), # Cornsilk
  597. "doc_title_text": (255, 239, 213, 100),
  598. "paragraph_title": (102, 102, 255, 100),
  599. "sub_paragraph_title": (102, 178, 255, 100),
  600. "vision": (153, 255, 51, 100),
  601. "vision_title": (144, 238, 144, 100), # Light Green
  602. "vision_footnote": (144, 238, 144, 100), # Light Green
  603. "normal_text": (153, 0, 76, 100),
  604. "cross_layout": (53, 218, 207, 100), # Thistle
  605. "cross_reference": (221, 160, 221, 100), # Floral White
  606. }
  607. else:
  608. label_colors = {
  609. # Medium Blue (from 'titles_list')
  610. "paragraph_title": (102, 102, 255, 100),
  611. "doc_title": (255, 248, 220, 100), # Cornsilk
  612. # Light Yellow (from 'tables_caption_list')
  613. "table_title": (255, 255, 102, 100),
  614. # Sky Blue (from 'imgs_caption_list')
  615. "figure_title": (102, 178, 255, 100),
  616. "chart_title": (221, 160, 221, 100), # Plum
  617. "vision_footnote": (144, 238, 144, 100), # Light Green
  618. # Deep Purple (from 'texts_list')
  619. "text": (153, 0, 76, 100),
  620. "vertical_text": (153, 0, 76, 100),
  621. "inline_formula": (153, 0, 76, 100),
  622. # Bright Green (from 'interequations_list')
  623. "formula": (0, 255, 0, 100),
  624. "display_formula": (0, 255, 0, 100),
  625. "abstract": (255, 239, 213, 100), # Papaya Whip
  626. # Medium Green (from 'lists_list' and 'indexs_list')
  627. "content": (40, 169, 92, 100),
  628. # Neutral Gray (from 'dropped_bbox_list')
  629. "seal": (158, 158, 158, 100),
  630. # Olive Yellow (from 'tables_body_list')
  631. "table": (204, 204, 0, 100),
  632. # Bright Green (from 'imgs_body_list')
  633. "image": (153, 255, 51, 100),
  634. # Bright Green (from 'imgs_body_list')
  635. "figure": (153, 255, 51, 100),
  636. "chart": (216, 191, 216, 100), # Thistle
  637. # Pale Yellow-Green (from 'tables_footnote_list')
  638. "reference": (229, 255, 204, 100),
  639. "reference_content": (229, 255, 204, 100),
  640. "algorithm": (255, 250, 240, 100), # Floral White
  641. }
  642. default_color = (158, 158, 158, 100)
  643. return label_colors.get(label, default_color)