utils.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import re
  16. import numpy as np
  17. from pathlib import Path
  18. from scipy.ndimage import rotate
  19. def get_ocr_res(pipeline, input):
  20. """get ocr res"""
  21. ocr_res_list = []
  22. if isinstance(input, list):
  23. img = [im["img"] for im in input]
  24. elif isinstance(input, dict):
  25. img = input["img"]
  26. else:
  27. img = input
  28. for ocr_res in pipeline(img):
  29. ocr_res_list.append(ocr_res)
  30. if len(ocr_res_list) == 1:
  31. return ocr_res_list[0]
  32. else:
  33. return ocr_res_list
  34. def get_oriclas_results(inputs, predictor):
  35. results = []
  36. img_list = [img_info["img"] for img_info in inputs]
  37. for input, pred in zip(inputs, predictor(img_list)):
  38. results.append(pred)
  39. angle = int(pred["label_names"][0])
  40. input["img"] = rotate_image(input["img"], angle)
  41. return results
  42. def get_uvdoc_results(inputs, predictor):
  43. results = []
  44. img_list = [img_info["img"] for img_info in inputs]
  45. for input, pred in zip(inputs, predictor(img_list)):
  46. results.append(pred)
  47. input["img"] = pred["doctr_img"]
  48. return results
  49. def get_predictor_res(predictor, input):
  50. """get ocr res"""
  51. result_list = []
  52. if isinstance(input, list):
  53. img = [im["img"] for im in input]
  54. elif isinstance(input, dict):
  55. img = input["img"]
  56. else:
  57. img = input
  58. for res in predictor(img):
  59. result_list.append(res)
  60. if len(result_list) == 1:
  61. return result_list[0]
  62. else:
  63. return result_list
  64. def rotate_image(image_array, rotate_angle):
  65. """rotate image"""
  66. assert (
  67. rotate_angle >= 0 and rotate_angle < 360
  68. ), "rotate_angle must in [0-360), but get {rotate_angle}."
  69. return rotate(image_array, rotate_angle, reshape=True)
  70. def get_table_text_from_html(all_table_res):
  71. all_table_ocr_res = []
  72. structure_res = []
  73. for table_res in all_table_res:
  74. table_list = []
  75. table_lines = re.findall("<tr>(.*?)</tr>", table_res["html"])
  76. single_table_ocr_res = []
  77. for td_line in table_lines:
  78. table_list.extend(re.findall("<td.*?>(.*?)</td>", td_line))
  79. for text in table_list:
  80. text = text.replace(" ", "")
  81. single_table_ocr_res.append(text)
  82. all_table_ocr_res.append(" ".join(single_table_ocr_res))
  83. structure_res.append(
  84. {
  85. "layout_bbox": table_res["layout_bbox"],
  86. "table": table_res["html"],
  87. }
  88. )
  89. return structure_res, all_table_ocr_res
  90. def format_key(key_list):
  91. """format key"""
  92. if key_list == "":
  93. return "未内置默认字段,请输入确定的key"
  94. if isinstance(key_list, list):
  95. return key_list
  96. key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
  97. key_list = key_list.replace(",", ",").split(",")
  98. return key_list
  99. def sorted_layout_boxes(res, w):
  100. """
  101. Sort text boxes in order from top to bottom, left to right
  102. args:
  103. res(list):ppstructure results
  104. w(int):image width
  105. return:
  106. sorted results(list)
  107. """
  108. num_boxes = len(res)
  109. if num_boxes == 1:
  110. res[0]["layout"] = "single"
  111. return res
  112. # Sort on the y axis first or sort it on the x axis
  113. sorted_boxes = sorted(res, key=lambda x: (x["layout_bbox"][1], x["layout_bbox"][0]))
  114. _boxes = list(sorted_boxes)
  115. new_res = []
  116. res_left = []
  117. res_mid = []
  118. res_right = []
  119. i = 0
  120. while True:
  121. if i >= num_boxes:
  122. break
  123. # Check if there are three columns of pictures
  124. if (
  125. _boxes[i]["layout_bbox"][0] > w / 4
  126. and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 4
  127. ):
  128. _boxes[i]["layout"] = "double"
  129. res_mid.append(_boxes[i])
  130. i += 1
  131. # Check that the bbox is on the left
  132. elif (
  133. _boxes[i]["layout_bbox"][0] < w / 4
  134. and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 5
  135. ):
  136. _boxes[i]["layout"] = "double"
  137. res_left.append(_boxes[i])
  138. i += 1
  139. elif (
  140. _boxes[i]["layout_bbox"][0] > 2 * w / 5
  141. and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < w
  142. ):
  143. _boxes[i]["layout"] = "double"
  144. res_right.append(_boxes[i])
  145. i += 1
  146. else:
  147. new_res += res_left
  148. new_res += res_right
  149. _boxes[i]["layout"] = "single"
  150. new_res.append(_boxes[i])
  151. res_left = []
  152. res_right = []
  153. i += 1
  154. res_left = sorted(res_left, key=lambda x: (x["layout_bbox"][1]))
  155. res_mid = sorted(res_mid, key=lambda x: (x["layout_bbox"][1]))
  156. res_right = sorted(res_right, key=lambda x: (x["layout_bbox"][1]))
  157. if res_left:
  158. new_res += res_left
  159. if res_mid:
  160. new_res += res_mid
  161. if res_right:
  162. new_res += res_right
  163. return new_res