utils.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import re
  16. import numpy as np
  17. from pathlib import Path
  18. from scipy.ndimage import rotate
  19. def get_ocr_res(pipeline, input):
  20. """get ocr res"""
  21. ocr_res_list = []
  22. if isinstance(input, list):
  23. img = [im["img"] for im in input]
  24. elif isinstance(input, dict):
  25. img = input["img"]
  26. else:
  27. img = input
  28. for ocr_res in pipeline(img):
  29. ocr_res_list.append(ocr_res)
  30. if len(ocr_res_list) == 1:
  31. return ocr_res_list[0]
  32. else:
  33. return ocr_res_list
  34. def get_oriclas_results(inputs, predictor):
  35. results = []
  36. img_list = [img_info["img"] for img_info in inputs]
  37. for input, pred in zip(inputs, predictor(img_list)):
  38. results.append(pred)
  39. angle = int(pred["label_names"][0])
  40. input["img"] = rotate_image(input["img"], angle)
  41. return results
  42. def get_uvdoc_results(inputs, predictor):
  43. results = []
  44. img_list = [img_info["img"] for img_info in inputs]
  45. for input, pred in zip(inputs, predictor(img_list)):
  46. results.append(pred)
  47. img = np.array(pred["doctr_img"], dtype=np.uint8)
  48. input["img"] = img
  49. return results
  50. def get_predictor_res(predictor, input):
  51. """get ocr res"""
  52. result_list = []
  53. if isinstance(input, list):
  54. img = [im["img"] for im in input]
  55. elif isinstance(input, dict):
  56. img = input["img"]
  57. else:
  58. img = input
  59. for res in predictor(img):
  60. result_list.append(res)
  61. if len(result_list) == 1:
  62. return result_list[0]
  63. else:
  64. return result_list
  65. def rotate_image(image_array, rotate_angle):
  66. """rotate image"""
  67. assert (
  68. rotate_angle >= 0 and rotate_angle < 360
  69. ), "rotate_angle must in [0-360), but get {rotate_angle}."
  70. return rotate(image_array, rotate_angle, reshape=True)
  71. def get_table_text_from_html(all_table_res):
  72. all_table_ocr_res = []
  73. structure_res = []
  74. for table_res in all_table_res:
  75. table_list = []
  76. table_lines = re.findall("<tr>(.*?)</tr>", table_res["html"])
  77. single_table_ocr_res = []
  78. for td_line in table_lines:
  79. table_list.extend(re.findall("<td.*?>(.*?)</td>", td_line))
  80. for text in table_list:
  81. text = text.replace(" ", "")
  82. single_table_ocr_res.append(text)
  83. all_table_ocr_res.append(" ".join(single_table_ocr_res))
  84. structure_res.append(
  85. {
  86. "layout_bbox": table_res["layout_bbox"],
  87. "table": table_res["html"],
  88. }
  89. )
  90. return structure_res, all_table_ocr_res
  91. def format_key(key_list):
  92. """format key"""
  93. if key_list == "":
  94. return "未内置默认字段,请输入确定的key"
  95. if isinstance(key_list, list):
  96. return key_list
  97. key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
  98. key_list = key_list.replace(",", ",").split(",")
  99. return key_list
  100. def sorted_layout_boxes(res, w):
  101. """
  102. Sort text boxes in order from top to bottom, left to right
  103. args:
  104. res(list):ppstructure results
  105. w(int):image width
  106. return:
  107. sorted results(list)
  108. """
  109. num_boxes = len(res)
  110. if num_boxes == 1:
  111. res[0]["layout"] = "single"
  112. return res
  113. # Sort on the y axis first or sort it on the x axis
  114. sorted_boxes = sorted(res, key=lambda x: (x["layout_bbox"][1], x["layout_bbox"][0]))
  115. _boxes = list(sorted_boxes)
  116. new_res = []
  117. res_left = []
  118. res_mid = []
  119. res_right = []
  120. i = 0
  121. while True:
  122. if i >= num_boxes:
  123. break
  124. # Check if there are three columns of pictures
  125. if (
  126. _boxes[i]["layout_bbox"][0] > w / 4
  127. and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 4
  128. ):
  129. _boxes[i]["layout"] = "double"
  130. res_mid.append(_boxes[i])
  131. i += 1
  132. # Check that the bbox is on the left
  133. elif (
  134. _boxes[i]["layout_bbox"][0] < w / 4
  135. and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < 3 * w / 5
  136. ):
  137. _boxes[i]["layout"] = "double"
  138. res_left.append(_boxes[i])
  139. i += 1
  140. elif (
  141. _boxes[i]["layout_bbox"][0] > 2 * w / 5
  142. and _boxes[i]["layout_bbox"][0] + _boxes[i]["layout_bbox"][2] < w
  143. ):
  144. _boxes[i]["layout"] = "double"
  145. res_right.append(_boxes[i])
  146. i += 1
  147. else:
  148. new_res += res_left
  149. new_res += res_right
  150. _boxes[i]["layout"] = "single"
  151. new_res.append(_boxes[i])
  152. res_left = []
  153. res_right = []
  154. i += 1
  155. res_left = sorted(res_left, key=lambda x: (x["layout_bbox"][1]))
  156. res_mid = sorted(res_mid, key=lambda x: (x["layout_bbox"][1]))
  157. res_right = sorted(res_right, key=lambda x: (x["layout_bbox"][1]))
  158. if res_left:
  159. new_res += res_left
  160. if res_mid:
  161. new_res += res_mid
  162. if res_right:
  163. new_res += res_right
  164. return new_res