utils.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. from scipy.ndimage import rotate
  16. def get_ocr_res(pipeline, input):
  17. """get ocr res"""
  18. ocr_res_list = []
  19. if isinstance(input, list):
  20. img = [im["img"] for im in input]
  21. elif isinstance(input, dict):
  22. img = input["img"]
  23. else:
  24. img = input
  25. for ocr_res in pipeline(img):
  26. ocr_res_list.append(ocr_res)
  27. if len(ocr_res_list) == 1:
  28. return ocr_res_list[0]
  29. else:
  30. return ocr_res_list
  31. def get_oriclas_results(inputs, predictor):
  32. results = []
  33. img_list = [img_info["img"] for img_info in inputs]
  34. for input, pred in zip(inputs, predictor(img_list)):
  35. results.append(pred)
  36. angle = int(pred["label_names"][0])
  37. input["img"] = rotate_image(input["img"], angle)
  38. return results
  39. def get_unwarp_results(inputs, predictor):
  40. results = []
  41. img_list = [img_info["img"] for img_info in inputs]
  42. for input, pred in zip(inputs, predictor(img_list)):
  43. results.append(pred)
  44. input["img"] = pred["doctr_img"]
  45. return results
  46. def get_predictor_res(predictor, input):
  47. """get ocr res"""
  48. result_list = []
  49. if isinstance(input, list):
  50. img = [im["img"] for im in input]
  51. elif isinstance(input, dict):
  52. img = input["img"]
  53. else:
  54. img = input
  55. for res in predictor(img):
  56. result_list.append(res)
  57. if len(result_list) == 1:
  58. return result_list[0]
  59. else:
  60. return result_list
  61. def rotate_image(image_array, rotate_angle):
  62. """rotate image"""
  63. assert (
  64. rotate_angle >= 0 and rotate_angle < 360
  65. ), "rotate_angle must in [0-360), but get {rotate_angle}."
  66. return rotate(image_array, rotate_angle, reshape=True)
  67. def get_table_text_from_html(all_table_res):
  68. all_table_ocr_res = []
  69. structure_res = []
  70. for table_res in all_table_res:
  71. table_list = []
  72. table_lines = re.findall("<tr>(.*?)</tr>", table_res["html"])
  73. single_table_ocr_res = []
  74. for td_line in table_lines:
  75. table_list.extend(re.findall("<td.*?>(.*?)</td>", td_line))
  76. for text in table_list:
  77. text = text.replace(" ", "")
  78. single_table_ocr_res.append(text)
  79. all_table_ocr_res.append(" ".join(single_table_ocr_res))
  80. structure_res.append(
  81. {
  82. "layout_bbox": table_res["layout_bbox"],
  83. "table": table_res["html"],
  84. }
  85. )
  86. return structure_res, all_table_ocr_res
  87. def format_key(key_list):
  88. """format key"""
  89. if key_list == "":
  90. return "未内置默认字段,请输入确定的key"
  91. if isinstance(key_list, list):
  92. return key_list
  93. key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
  94. key_list = key_list.replace(",", ",").split(",")
  95. return key_list
  96. def sorted_layout_boxes(res, w):
  97. """
  98. Sort text boxes in order from top to bottom, left to right
  99. args:
  100. res(list):ppstructure results
  101. w(int):image width
  102. return:
  103. sorted results(list)
  104. """
  105. num_boxes = len(res)
  106. if num_boxes == 1:
  107. res[0]["layout"] = "single"
  108. return res
  109. # Sort on the y axis first or sort it on the x axis
  110. sorted_boxes = sorted(res, key=lambda x: (x["layout_bbox"][1], x["layout_bbox"][0]))
  111. _boxes = list(sorted_boxes)
  112. new_res = []
  113. res_left = []
  114. res_right = []
  115. i = 0
  116. while True:
  117. if i >= num_boxes:
  118. break
  119. # Check that the bbox is on the left
  120. elif (
  121. _boxes[i]["layout_bbox"][0] < w / 4
  122. and _boxes[i]["layout_bbox"][2] < 3 * w / 5
  123. ):
  124. _boxes[i]["layout"] = "double"
  125. res_left.append(_boxes[i])
  126. i += 1
  127. elif _boxes[i]["layout_bbox"][0] > 2 * w / 5:
  128. _boxes[i]["layout"] = "double"
  129. res_right.append(_boxes[i])
  130. i += 1
  131. else:
  132. new_res += res_left
  133. new_res += res_right
  134. _boxes[i]["layout"] = "single"
  135. new_res.append(_boxes[i])
  136. res_left = []
  137. res_right = []
  138. i += 1
  139. res_left = sorted(res_left, key=lambda x: (x["layout_bbox"][1]))
  140. res_right = sorted(res_right, key=lambda x: (x["layout_bbox"][1]))
  141. if res_left:
  142. new_res += res_left
  143. if res_right:
  144. new_res += res_right
  145. return new_res