ocr_utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import copy
  3. import cv2
  4. import numpy as np
  5. class OcrConfidence:
  6. min_confidence = 0.5
  7. min_width = 3
  8. LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD = 4 # 一般情况下,行宽度超过高度4倍时才是一个正常的横向文本块
  9. def merge_spans_to_line(spans, threshold=0.6):
  10. if len(spans) == 0:
  11. return []
  12. else:
  13. # 按照y0坐标排序
  14. spans.sort(key=lambda span: span['bbox'][1])
  15. lines = []
  16. current_line = [spans[0]]
  17. for span in spans[1:]:
  18. # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
  19. if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
  20. current_line.append(span)
  21. else:
  22. # 否则,开始新行
  23. lines.append(current_line)
  24. current_line = [span]
  25. # 添加最后一行
  26. if current_line:
  27. lines.append(current_line)
  28. return lines
  29. def _is_overlaps_y_exceeds_threshold(bbox1,
  30. bbox2,
  31. overlap_ratio_threshold=0.8):
  32. """检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
  33. _, y0_1, _, y1_1 = bbox1
  34. _, y0_2, _, y1_2 = bbox2
  35. overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
  36. height1, height2 = y1_1 - y0_1, y1_2 - y0_2
  37. # max_height = max(height1, height2)
  38. min_height = min(height1, height2)
  39. return (overlap / min_height) > overlap_ratio_threshold if min_height > 0 else False
  40. def _is_overlaps_x_exceeds_threshold(bbox1,
  41. bbox2,
  42. overlap_ratio_threshold=0.8):
  43. """检查两个bbox在x轴上是否有重叠,并且该重叠区域的宽度占两个bbox宽度更低的那个超过指定阈值"""
  44. x0_1, _, x1_1, _ = bbox1
  45. x0_2, _, x1_2, _ = bbox2
  46. overlap = max(0, min(x1_1, x1_2) - max(x0_1, x0_2))
  47. width1, width2 = x1_1 - x0_1, x1_2 - x0_2
  48. min_width = min(width1, width2)
  49. return (overlap / min_width) > overlap_ratio_threshold if min_width > 0 else False
  50. def img_decode(content: bytes):
  51. np_arr = np.frombuffer(content, dtype=np.uint8)
  52. return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
  53. def check_img(img):
  54. if isinstance(img, bytes):
  55. img = img_decode(img)
  56. if isinstance(img, np.ndarray) and len(img.shape) == 2:
  57. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  58. return img
  59. def alpha_to_color(img, alpha_color=(255, 255, 255)):
  60. if len(img.shape) == 3 and img.shape[2] == 4:
  61. B, G, R, A = cv2.split(img)
  62. alpha = A / 255
  63. R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
  64. G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
  65. B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
  66. img = cv2.merge((B, G, R))
  67. return img
  68. def preprocess_image(_image):
  69. alpha_color = (255, 255, 255)
  70. _image = alpha_to_color(_image, alpha_color)
  71. return _image
  72. def sorted_boxes(dt_boxes):
  73. """
  74. Sort text boxes in order from top to bottom, left to right
  75. args:
  76. dt_boxes(array):detected text boxes with shape [4, 2]
  77. return:
  78. sorted boxes(array) with shape [4, 2]
  79. """
  80. num_boxes = dt_boxes.shape[0]
  81. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  82. _boxes = list(sorted_boxes)
  83. for i in range(num_boxes - 1):
  84. for j in range(i, -1, -1):
  85. if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
  86. (_boxes[j + 1][0][0] < _boxes[j][0][0]):
  87. tmp = _boxes[j]
  88. _boxes[j] = _boxes[j + 1]
  89. _boxes[j + 1] = tmp
  90. else:
  91. break
  92. return _boxes
  93. def bbox_to_points(bbox):
  94. """ 将bbox格式转换为四个顶点的数组 """
  95. x0, y0, x1, y1 = bbox
  96. return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
  97. def points_to_bbox(points):
  98. """ 将四个顶点的数组转换为bbox格式 """
  99. x0, y0 = points[0]
  100. x1, _ = points[1]
  101. _, y1 = points[2]
  102. return [x0, y0, x1, y1]
  103. def merge_intervals(intervals):
  104. # Sort the intervals based on the start value
  105. intervals.sort(key=lambda x: x[0])
  106. merged = []
  107. for interval in intervals:
  108. # If the list of merged intervals is empty or if the current
  109. # interval does not overlap with the previous, simply append it.
  110. if not merged or merged[-1][1] < interval[0]:
  111. merged.append(interval)
  112. else:
  113. # Otherwise, there is overlap, so we merge the current and previous intervals.
  114. merged[-1][1] = max(merged[-1][1], interval[1])
  115. return merged
  116. def remove_intervals(original, masks):
  117. # Merge all mask intervals
  118. merged_masks = merge_intervals(masks)
  119. result = []
  120. original_start, original_end = original
  121. for mask in merged_masks:
  122. mask_start, mask_end = mask
  123. # If the mask starts after the original range, ignore it
  124. if mask_start > original_end:
  125. continue
  126. # If the mask ends before the original range starts, ignore it
  127. if mask_end < original_start:
  128. continue
  129. # Remove the masked part from the original range
  130. if original_start < mask_start:
  131. result.append([original_start, mask_start - 1])
  132. original_start = max(mask_end + 1, original_start)
  133. # Add the remaining part of the original range, if any
  134. if original_start <= original_end:
  135. result.append([original_start, original_end])
  136. return result
  137. def update_det_boxes(dt_boxes, mfd_res):
  138. new_dt_boxes = []
  139. angle_boxes_list = []
  140. for text_box in dt_boxes:
  141. if calculate_is_angle(text_box):
  142. angle_boxes_list.append(text_box)
  143. continue
  144. text_bbox = points_to_bbox(text_box)
  145. masks_list = []
  146. for mf_box in mfd_res:
  147. mf_bbox = mf_box['bbox']
  148. if _is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
  149. masks_list.append([mf_bbox[0], mf_bbox[2]])
  150. text_x_range = [text_bbox[0], text_bbox[2]]
  151. text_remove_mask_range = remove_intervals(text_x_range, masks_list)
  152. temp_dt_box = []
  153. for text_remove_mask in text_remove_mask_range:
  154. temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
  155. if len(temp_dt_box) > 0:
  156. new_dt_boxes.extend(temp_dt_box)
  157. new_dt_boxes.extend(angle_boxes_list)
  158. return new_dt_boxes
  159. def merge_overlapping_spans(spans):
  160. """
  161. Merges overlapping spans on the same line.
  162. :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
  163. :return: A list of merged spans
  164. """
  165. # Return an empty list if the input spans list is empty
  166. if not spans:
  167. return []
  168. # Sort spans by their starting x-coordinate
  169. spans.sort(key=lambda x: x[0])
  170. # Initialize the list of merged spans
  171. merged = []
  172. for span in spans:
  173. # Unpack span coordinates
  174. x1, y1, x2, y2 = span
  175. # If the merged list is empty or there's no horizontal overlap, add the span directly
  176. if not merged or merged[-1][2] < x1:
  177. merged.append(span)
  178. else:
  179. # If there is horizontal overlap, merge the current span with the previous one
  180. last_span = merged.pop()
  181. # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
  182. x1 = min(last_span[0], x1)
  183. y1 = min(last_span[1], y1)
  184. x2 = max(last_span[2], x2)
  185. y2 = max(last_span[3], y2)
  186. # Add the merged span back to the list
  187. merged.append((x1, y1, x2, y2))
  188. # Return the list of merged spans
  189. return merged
  190. def merge_det_boxes(dt_boxes):
  191. """
  192. Merge detection boxes.
  193. This function takes a list of detected bounding boxes, each represented by four corner points.
  194. The goal is to merge these bounding boxes into larger text regions.
  195. Parameters:
  196. dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
  197. Returns:
  198. list: A list containing the merged text regions, where each region is represented by four corner points.
  199. """
  200. # Convert the detection boxes into a dictionary format with bounding boxes and type
  201. dt_boxes_dict_list = []
  202. angle_boxes_list = []
  203. for text_box in dt_boxes:
  204. text_bbox = points_to_bbox(text_box)
  205. if calculate_is_angle(text_box):
  206. angle_boxes_list.append(text_box)
  207. continue
  208. text_box_dict = {'bbox': text_bbox}
  209. dt_boxes_dict_list.append(text_box_dict)
  210. # Merge adjacent text regions into lines
  211. lines = merge_spans_to_line(dt_boxes_dict_list)
  212. # Initialize a new list for storing the merged text regions
  213. new_dt_boxes = []
  214. for line in lines:
  215. line_bbox_list = []
  216. for span in line:
  217. line_bbox_list.append(span['bbox'])
  218. # 计算整行的宽度和高度
  219. min_x = min(bbox[0] for bbox in line_bbox_list)
  220. max_x = max(bbox[2] for bbox in line_bbox_list)
  221. min_y = min(bbox[1] for bbox in line_bbox_list)
  222. max_y = max(bbox[3] for bbox in line_bbox_list)
  223. line_width = max_x - min_x
  224. line_height = max_y - min_y
  225. # 只有当行宽度超过高度4倍时才进行合并
  226. if line_width > line_height * LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD:
  227. # Merge overlapping text regions within the same line
  228. merged_spans = merge_overlapping_spans(line_bbox_list)
  229. # Convert the merged text regions back to point format and add them to the new detection box list
  230. for span in merged_spans:
  231. new_dt_boxes.append(bbox_to_points(span))
  232. else:
  233. # 不进行合并,直接添加原始区域
  234. for bbox in line_bbox_list:
  235. new_dt_boxes.append(bbox_to_points(bbox))
  236. new_dt_boxes.extend(angle_boxes_list)
  237. return new_dt_boxes
  238. def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
  239. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  240. # Adjust the coordinates of the formula area
  241. adjusted_mfdetrec_res = []
  242. for mf_res in single_page_mfdetrec_res:
  243. mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
  244. # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
  245. x0 = mf_xmin - xmin + paste_x
  246. y0 = mf_ymin - ymin + paste_y
  247. x1 = mf_xmax - xmin + paste_x
  248. y1 = mf_ymax - ymin + paste_y
  249. # Filter formula blocks outside the graph
  250. if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
  251. continue
  252. else:
  253. adjusted_mfdetrec_res.append({
  254. "bbox": [x0, y0, x1, y1],
  255. })
  256. return adjusted_mfdetrec_res
  257. def get_ocr_result_list(ocr_res, useful_list, ocr_enable, bgr_image, lang):
  258. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  259. ocr_result_list = []
  260. ori_im = bgr_image.copy()
  261. for box_ocr_res in ocr_res:
  262. if len(box_ocr_res) == 2:
  263. p1, p2, p3, p4 = box_ocr_res[0]
  264. text, score = box_ocr_res[1]
  265. # logger.info(f"text: {text}, score: {score}")
  266. if score < OcrConfidence.min_confidence: # 过滤低置信度的结果
  267. continue
  268. else:
  269. p1, p2, p3, p4 = box_ocr_res
  270. text, score = "", 1
  271. if ocr_enable:
  272. tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
  273. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  274. # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
  275. # if average_angle_degrees > 0.5:
  276. poly = [p1, p2, p3, p4]
  277. if (p3[0] - p1[0]) < OcrConfidence.min_width:
  278. # logger.info(f"width too small: {p3[0] - p1[0]}, text: {text}")
  279. continue
  280. if calculate_is_angle(poly):
  281. # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
  282. # 与x轴的夹角超过0.5度,对边界做一下矫正
  283. # 计算几何中心
  284. x_center = sum(point[0] for point in poly) / 4
  285. y_center = sum(point[1] for point in poly) / 4
  286. new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
  287. new_width = p3[0] - p1[0]
  288. p1 = [x_center - new_width / 2, y_center - new_height / 2]
  289. p2 = [x_center + new_width / 2, y_center - new_height / 2]
  290. p3 = [x_center + new_width / 2, y_center + new_height / 2]
  291. p4 = [x_center - new_width / 2, y_center + new_height / 2]
  292. # Convert the coordinates back to the original coordinate system
  293. p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
  294. p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
  295. p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
  296. p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
  297. if ocr_enable:
  298. ocr_result_list.append({
  299. 'category_id': 15,
  300. 'poly': p1 + p2 + p3 + p4,
  301. 'score': 1,
  302. 'text': text,
  303. 'np_img': img_crop,
  304. 'lang': lang,
  305. })
  306. else:
  307. ocr_result_list.append({
  308. 'category_id': 15,
  309. 'poly': p1 + p2 + p3 + p4,
  310. 'score': float(round(score, 2)),
  311. 'text': text,
  312. })
  313. return ocr_result_list
  314. def calculate_is_angle(poly):
  315. p1, p2, p3, p4 = poly
  316. height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
  317. if 0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height:
  318. return False
  319. else:
  320. # logger.info((p3[1] - p1[1])/height)
  321. return True
  322. def get_rotate_crop_image(img, points):
  323. '''
  324. img_height, img_width = img.shape[0:2]
  325. left = int(np.min(points[:, 0]))
  326. right = int(np.max(points[:, 0]))
  327. top = int(np.min(points[:, 1]))
  328. bottom = int(np.max(points[:, 1]))
  329. img_crop = img[top:bottom, left:right, :].copy()
  330. points[:, 0] = points[:, 0] - left
  331. points[:, 1] = points[:, 1] - top
  332. '''
  333. assert len(points) == 4, "shape of points must be 4*2"
  334. img_crop_width = int(
  335. max(
  336. np.linalg.norm(points[0] - points[1]),
  337. np.linalg.norm(points[2] - points[3])))
  338. img_crop_height = int(
  339. max(
  340. np.linalg.norm(points[0] - points[3]),
  341. np.linalg.norm(points[1] - points[2])))
  342. pts_std = np.float32([[0, 0], [img_crop_width, 0],
  343. [img_crop_width, img_crop_height],
  344. [0, img_crop_height]])
  345. M = cv2.getPerspectiveTransform(points, pts_std)
  346. dst_img = cv2.warpPerspective(
  347. img,
  348. M, (img_crop_width, img_crop_height),
  349. borderMode=cv2.BORDER_REPLICATE,
  350. flags=cv2.INTER_CUBIC)
  351. dst_img_height, dst_img_width = dst_img.shape[0:2]
  352. rotate_radio = 2
  353. if dst_img_height * 1.0 / dst_img_width >= rotate_radio:
  354. dst_img = np.rot90(dst_img)
  355. return dst_img