predict_system.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import cv2
  2. import copy
  3. import numpy as np
  4. from . import predict_rec
  5. from . import predict_det
  6. from . import predict_cls
  7. class TextSystem(object):
  8. def __init__(self, args, **kwargs):
  9. self.text_detector = predict_det.TextDetector(args, **kwargs)
  10. self.text_recognizer = predict_rec.TextRecognizer(args, **kwargs)
  11. self.use_angle_cls = args.use_angle_cls
  12. self.drop_score = args.drop_score
  13. if self.use_angle_cls:
  14. self.text_classifier = predict_cls.TextClassifier(args, **kwargs)
  15. def get_rotate_crop_image(self, img, points):
  16. '''
  17. img_height, img_width = img.shape[0:2]
  18. left = int(np.min(points[:, 0]))
  19. right = int(np.max(points[:, 0]))
  20. top = int(np.min(points[:, 1]))
  21. bottom = int(np.max(points[:, 1]))
  22. img_crop = img[top:bottom, left:right, :].copy()
  23. points[:, 0] = points[:, 0] - left
  24. points[:, 1] = points[:, 1] - top
  25. '''
  26. img_crop_width = int(
  27. max(
  28. np.linalg.norm(points[0] - points[1]),
  29. np.linalg.norm(points[2] - points[3])))
  30. img_crop_height = int(
  31. max(
  32. np.linalg.norm(points[0] - points[3]),
  33. np.linalg.norm(points[1] - points[2])))
  34. pts_std = np.float32([[0, 0], [img_crop_width, 0],
  35. [img_crop_width, img_crop_height],
  36. [0, img_crop_height]])
  37. M = cv2.getPerspectiveTransform(points, pts_std)
  38. dst_img = cv2.warpPerspective(
  39. img,
  40. M, (img_crop_width, img_crop_height),
  41. borderMode=cv2.BORDER_REPLICATE,
  42. flags=cv2.INTER_CUBIC)
  43. dst_img_height, dst_img_width = dst_img.shape[0:2]
  44. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  45. dst_img = np.rot90(dst_img)
  46. return dst_img
  47. def __call__(self, img):
  48. ori_im = img.copy()
  49. dt_boxes, elapse = self.text_detector(img)
  50. print("dt_boxes num : {}, elapse : {}".format(
  51. len(dt_boxes), elapse))
  52. if dt_boxes is None:
  53. return None, None
  54. img_crop_list = []
  55. dt_boxes = sorted_boxes(dt_boxes)
  56. for bno in range(len(dt_boxes)):
  57. tmp_box = copy.deepcopy(dt_boxes[bno])
  58. img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
  59. img_crop_list.append(img_crop)
  60. if self.use_angle_cls:
  61. img_crop_list, angle_list, elapse = self.text_classifier(
  62. img_crop_list)
  63. print("cls num : {}, elapse : {}".format(
  64. len(img_crop_list), elapse))
  65. rec_res, elapse = self.text_recognizer(img_crop_list)
  66. print("rec_res num : {}, elapse : {}".format(
  67. len(rec_res), elapse))
  68. # self.print_draw_crop_rec_res(img_crop_list, rec_res)
  69. filter_boxes, filter_rec_res = [], []
  70. for box, rec_reuslt in zip(dt_boxes, rec_res):
  71. text, score = rec_reuslt
  72. if score >= self.drop_score:
  73. filter_boxes.append(box)
  74. filter_rec_res.append(rec_reuslt)
  75. return filter_boxes, filter_rec_res
  76. def sorted_boxes(dt_boxes):
  77. """
  78. Sort text boxes in order from top to bottom, left to right
  79. args:
  80. dt_boxes(array):detected text boxes with shape [4, 2]
  81. return:
  82. sorted boxes(array) with shape [4, 2]
  83. """
  84. num_boxes = dt_boxes.shape[0]
  85. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  86. _boxes = list(sorted_boxes)
  87. for i in range(num_boxes - 1):
  88. if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
  89. (_boxes[i + 1][0][0] < _boxes[i][0][0]):
  90. tmp = _boxes[i]
  91. _boxes[i] = _boxes[i + 1]
  92. _boxes[i + 1] = tmp
  93. return _boxes