predict_cls.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import cv2
  2. import copy
  3. import numpy as np
  4. import math
  5. import time
  6. import torch
  7. from ...pytorchocr.base_ocr_v20 import BaseOCRV20
  8. from . import pytorchocr_utility as utility
  9. from ...pytorchocr.postprocess import build_post_process
  class TextClassifier(BaseOCRV20):
      """Batch classifier for text-line image crops.

      Each crop is resized and normalized into a fixed-size tensor, run
      through ``self.net`` in batches, and decoded by the ``ClsPostProcess``
      post-processor. Crops whose predicted label contains ``'180'`` with a
      score above ``cls_thresh`` are rotated by 180 degrees.
      """

      def __init__(self, args, **kwargs):
          """Configure the classifier from ``args`` and load its weights.

          Args:
              args: Object exposing ``device``, ``cls_image_shape`` (a
                  comma-separated ``"C,H,W"`` string), ``cls_batch_num``,
                  ``cls_thresh``, ``label_list``, ``cls_model_path``,
                  ``cls_yaml_path``, ``limited_max_width`` and
                  ``limited_min_width``.
              **kwargs: Forwarded unchanged to the base-class constructor.
          """
          self.device = args.device
          self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
          self.cls_batch_num = args.cls_batch_num
          self.cls_thresh = args.cls_thresh
          postprocess_params = {
              'name': 'ClsPostProcess',
              "label_list": args.label_list,
          }
          self.postprocess_op = build_post_process(postprocess_params)

          self.weights_path = args.cls_model_path
          self.yaml_path = args.cls_yaml_path
          network_config = utility.get_arch_config(self.weights_path)
          # NOTE(review): the base constructor is presumably what creates
          # ``self.net`` -- confirm against BaseOCRV20.
          super(TextClassifier, self).__init__(network_config, **kwargs)

          self.limited_max_width = args.limited_max_width
          self.limited_min_width = args.limited_min_width

          self.load_pytorch_weights(self.weights_path)
          self.net.eval()
          self.net.to(self.device)

      def resize_norm_img(self, img):
          """Resize ``img`` to the model height and pad it to the model width.

          The image is resized to height ``imgH`` keeping its aspect ratio,
          with the width capped at the clamped model width. Pixel values are
          scaled to ``[-1, 1]`` and the result is written into the left part
          of a zero-filled ``(imgC, imgH, imgW)`` float32 array.

          Args:
              img: Image array; ``img.shape[0]`` is read as height and
                  ``img.shape[1]`` as width.

          Returns:
              A float32 array of shape ``(imgC, imgH, imgW)``.
          """
          imgC, imgH, imgW = self.cls_image_shape
          src_h = img.shape[0]
          src_w = img.shape[1]
          ratio = src_w / float(src_h)

          # Clamp the target width into [limited_min_width, limited_max_width];
          # the lower bound takes precedence if the two bounds conflict.
          imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)

          # Aspect-preserving width, capped at the target width. The minimum
          # width only bounds imgW: narrow crops are padded, never stretched.
          resized_w = min(int(math.ceil(imgH * ratio)), imgW)

          resized_image = cv2.resize(img, (resized_w, imgH))
          resized_image = resized_image.astype('float32')
          if imgC == 1:
              # Single-channel model: add the leading channel axis.
              resized_image = resized_image / 255
              resized_image = resized_image[np.newaxis, :]
          else:
              # Move the channel axis first: (H, W, C) -> (C, H, W).
              resized_image = resized_image.transpose((2, 0, 1)) / 255
          # Map [0, 1] to [-1, 1].
          resized_image -= 0.5
          resized_image /= 0.5

          padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
          padding_im[:, :, 0:resized_w] = resized_image
          return padding_im

      def __call__(self, img_list):
          """Classify every image and rotate the ones detected as upside down.

          Args:
              img_list: Sequence of image arrays. It is deep-copied first, so
                  the caller's images are never modified.

          Returns:
              A tuple ``(img_list, cls_res, elapse)``: the copied images with
              rotations applied, one ``[label, score]`` pair per image in
              input order, and the accumulated inference plus post-processing
              time in seconds.
          """
          img_list = copy.deepcopy(img_list)
          img_num = len(img_list)

          # Sorting by aspect ratio can speed up the cls process; ``indices``
          # maps a sorted position back to the position in ``img_list``.
          aspect_ratios = [img.shape[1] / float(img.shape[0]) for img in img_list]
          indices = np.argsort(np.array(aspect_ratios))

          # One independent placeholder per image (no shared inner list).
          cls_res = [['', 0.0] for _ in range(img_num)]
          batch_num = self.cls_batch_num
          elapse = 0
          for beg_img_no in range(0, img_num, batch_num):
              end_img_no = min(img_num, beg_img_no + batch_num)
              # Every normalized image has the same shape, so they stack
              # directly into one (batch, C, H, W) array.
              norm_img_batch = np.stack([
                  self.resize_norm_img(img_list[img_idx])
                  for img_idx in indices[beg_img_no:end_img_no]
              ])

              starttime = time.time()
              with torch.no_grad():
                  inp = torch.from_numpy(norm_img_batch)
                  inp = inp.to(self.device)
                  prob_out = self.net(inp)
              prob_out = prob_out.cpu().numpy()
              cls_result = self.postprocess_op(prob_out)
              elapse += time.time() - starttime

              for rno, (label, score) in enumerate(cls_result):
                  img_idx = indices[beg_img_no + rno]
                  cls_res[img_idx] = [label, score]
                  if '180' in label and score > self.cls_thresh:
                      img_list[img_idx] = cv2.rotate(
                          img_list[img_idx], cv2.ROTATE_180)
          return img_list, cls_res, elapse