cal_ocr_word_box.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. __all__ = ["cal_ocr_word_box"]
  15. import numpy as np
  16. # from .convert_points_and_boxes import convert_points_to_boxes
  17. def cal_ocr_word_box(rec_str, box, rec_word_info):
  18. """Calculate the detection frame for each word based on the results of recognition and detection of ocr"""
  19. col_num, word_list, word_col_list, state_list = rec_word_info
  20. box = box.tolist()
  21. bbox_x_start = box[0][0]
  22. bbox_x_end = box[1][0]
  23. bbox_y_start = box[0][1]
  24. bbox_y_end = box[2][1]
  25. cell_width = (bbox_x_end - bbox_x_start) / col_num
  26. word_box_list = []
  27. word_box_content_list = []
  28. cn_width_list = []
  29. cn_col_list = []
  30. for word, word_col, state in zip(word_list, word_col_list, state_list):
  31. if state == "cn":
  32. if len(word_col) != 1:
  33. char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
  34. char_width = char_seq_length / (len(word_col) - 1)
  35. cn_width_list.append(char_width)
  36. cn_col_list += word_col
  37. word_box_content_list += word
  38. else:
  39. cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
  40. cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width)
  41. cell = (
  42. (cell_x_start, bbox_y_start),
  43. (cell_x_end, bbox_y_start),
  44. (cell_x_end, bbox_y_end),
  45. (cell_x_start, bbox_y_end),
  46. )
  47. word_box_list.append(cell)
  48. word_box_content_list.append("".join(word))
  49. if len(cn_col_list) != 0:
  50. if len(cn_width_list) != 0:
  51. avg_char_width = np.mean(cn_width_list)
  52. else:
  53. avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_str)
  54. for center_idx in cn_col_list:
  55. center_x = (center_idx + 0.5) * cell_width
  56. cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start
  57. cell_x_end = (
  58. min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
  59. + bbox_x_start
  60. )
  61. cell = (
  62. (cell_x_start, bbox_y_start),
  63. (cell_x_end, bbox_y_start),
  64. (cell_x_end, bbox_y_end),
  65. (cell_x_start, bbox_y_end),
  66. )
  67. word_box_list.append(cell)
  68. word_box_list = sort_boxes(word_box_list, y_thresh=12)
  69. return word_box_content_list, word_box_list
  70. def sort_boxes(boxes, y_thresh=10):
  71. box_centers = [np.mean(box, axis=0) for box in boxes]
  72. items = list(zip(boxes, box_centers))
  73. items.sort(key=lambda x: x[1][1])
  74. lines = []
  75. current_line = []
  76. last_y = None
  77. for box, center in items:
  78. if last_y is None or abs(center[1] - last_y) < y_thresh:
  79. current_line.append((box, center))
  80. else:
  81. lines.append(current_line)
  82. current_line = [(box, center)]
  83. last_y = center[1]
  84. if current_line:
  85. lines.append(current_line)
  86. final_box = []
  87. for line in lines:
  88. line = sorted(line, key=lambda x: x[1][0])
  89. final_box.extend(box for box, center in line)
  90. return final_box