utils.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. __all__ = ["convert_points_to_boxes", "get_sub_regions_ocr_res"]
  15. import numpy as np
  16. import copy
  17. def convert_points_to_boxes(dt_polys):
  18. if len(dt_polys) > 0:
  19. dt_polys_tmp = dt_polys.copy()
  20. dt_polys_tmp = np.array(dt_polys_tmp)
  21. boxes_left = np.min(dt_polys_tmp[:, :, 0], axis=1)
  22. boxes_right = np.max(dt_polys_tmp[:, :, 0], axis=1)
  23. boxes_top = np.min(dt_polys_tmp[:, :, 1], axis=1)
  24. boxes_bottom = np.max(dt_polys_tmp[:, :, 1], axis=1)
  25. dt_boxes = np.array([boxes_left, boxes_top, boxes_right, boxes_bottom])
  26. dt_boxes = dt_boxes.T
  27. else:
  28. dt_boxes = np.array([])
  29. return dt_boxes
  30. def get_overlap_boxes_idx(src_boxes, ref_boxes):
  31. """get overlap boxes idx"""
  32. match_idx_list = []
  33. src_boxes_num = len(src_boxes)
  34. if src_boxes_num > 0 and len(ref_boxes) > 0:
  35. for rno in range(len(ref_boxes)):
  36. ref_box = ref_boxes[rno]
  37. x1 = np.maximum(ref_box[0], src_boxes[:, 0])
  38. y1 = np.maximum(ref_box[1], src_boxes[:, 1])
  39. x2 = np.minimum(ref_box[2], src_boxes[:, 2])
  40. y2 = np.minimum(ref_box[3], src_boxes[:, 3])
  41. pub_w = x2 - x1
  42. pub_h = y2 - y1
  43. match_idx = np.where((pub_w > 3) & (pub_h > 3))[0]
  44. match_idx_list.extend(match_idx)
  45. return match_idx_list
  46. def get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=True):
  47. """
  48. :param flag_within: True (within the object regions), False (outside the object regions)
  49. :return:
  50. """
  51. sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
  52. sub_regions_ocr_res["input_img"] = overall_ocr_res["input_img"]
  53. sub_regions_ocr_res["img_id"] = -1
  54. sub_regions_ocr_res["dt_polys"] = []
  55. sub_regions_ocr_res["rec_text"] = []
  56. sub_regions_ocr_res["rec_score"] = []
  57. sub_regions_ocr_res["dt_boxes"] = []
  58. overall_text_boxes = overall_ocr_res["dt_boxes"]
  59. match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
  60. match_idx_list = list(set(match_idx_list))
  61. for box_no in range(len(overall_text_boxes)):
  62. if flag_within:
  63. if box_no in match_idx_list:
  64. flag_match = True
  65. else:
  66. flag_match = False
  67. else:
  68. if box_no not in match_idx_list:
  69. flag_match = True
  70. else:
  71. flag_match = False
  72. if flag_match:
  73. sub_regions_ocr_res["dt_polys"].append(overall_ocr_res["dt_polys"][box_no])
  74. sub_regions_ocr_res["rec_text"].append(overall_ocr_res["rec_text"][box_no])
  75. sub_regions_ocr_res["rec_score"].append(
  76. overall_ocr_res["rec_score"][box_no]
  77. )
  78. sub_regions_ocr_res["dt_boxes"].append(overall_ocr_res["dt_boxes"][box_no])
  79. return sub_regions_ocr_res