# predict_rec.py

from PIL import Image
import cv2
import numpy as np
import math
import time
import torch
from tqdm import tqdm
from ...pytorchocr.base_ocr_v20 import BaseOCRV20
from . import pytorchocr_utility as utility
from ...pytorchocr.postprocess import build_post_process
from ...pytorchocr.modeling.backbones.rec_hgnet import ConvBNAct


class TextRecognizer(BaseOCRV20):
    def __init__(self, args, **kwargs):
        self.device = args.device
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.max_text_length = args.max_text_length
        # CTC decoding is the default; algorithm-specific decoders override it below.
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        if self.rec_algorithm == "SRN":
            postprocess_params = {
                'name': 'SRNLabelDecode',
                "character_type": args.rec_char_type,
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "RARE":
            postprocess_params = {
                'name': 'AttnLabelDecode',
                "character_type": args.rec_char_type,
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'NRTR':
            postprocess_params = {
                'name': 'NRTRLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "SAR":
            postprocess_params = {
                'name': 'SARLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'ViTSTR':
            postprocess_params = {
                'name': 'ViTSTRLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "CAN":
            self.inverse = args.rec_image_inverse
            postprocess_params = {
                'name': 'CANLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'RFL':
            postprocess_params = {
                'name': 'RFLLabelDecode',
                "character_dict_path": None,
                "use_space_char": args.use_space_char
            }
        self.postprocess_op = build_post_process(postprocess_params)
        self.limited_max_width = args.limited_max_width
        self.limited_min_width = args.limited_min_width

        self.weights_path = args.rec_model_path
        self.yaml_path = args.rec_yaml_path
        network_config = utility.get_arch_config(self.weights_path)
        weights = self.read_pytorch_weights(self.weights_path)
        self.out_channels = self.get_out_channels(weights)
        if self.rec_algorithm == 'NRTR':
            self.out_channels = list(weights.values())[-1].numpy().shape[0]
        elif self.rec_algorithm == 'SAR':
            self.out_channels = list(weights.values())[-3].numpy().shape[0]
        kwargs['out_channels'] = self.out_channels
        super(TextRecognizer, self).__init__(network_config, **kwargs)
        self.load_state_dict(weights)
        self.net.eval()
        self.net.to(self.device)
        # Fuse Conv+BN(+activation) blocks in place to speed up inference.
        for module in self.net.modules():
            if isinstance(module, ConvBNAct):
                if module.use_act:
                    torch.quantization.fuse_modules(module, ['conv', 'bn', 'act'], inplace=True)
                else:
                    torch.quantization.fuse_modules(module, ['conv', 'bn'], inplace=True)
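
    # Note on the fusion above: for inference, torch.quantization.fuse_modules
    # folds each BatchNorm (and activation, when present) into its preceding
    # Conv. Conceptually (a sketch of the math, not the literal torch code):
    #     w_fused = w_conv * gamma / sqrt(running_var + eps)
    #     b_fused = beta + (b_conv - running_mean) * gamma / sqrt(running_var + eps)
    # so the fused Conv produces the same outputs with one layer fewer per block.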
    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            image_pil = Image.fromarray(np.uint8(img))
            if self.rec_algorithm == 'ViTSTR':
                img = image_pil.resize([imgW, imgH], Image.BICUBIC)
            else:
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its replacement.
                img = image_pil.resize([imgW, imgH], Image.LANCZOS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            if self.rec_algorithm == 'ViTSTR':
                norm_img = norm_img.astype(np.float32) / 255.
            else:
                norm_img = norm_img.astype(np.float32) / 128. - 1.
            return norm_img
        elif self.rec_algorithm == 'RFL':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            resized_image = cv2.resize(
                img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
            resized_image = resized_image.astype('float32')
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
            resized_image -= 0.5
            resized_image /= 0.5
            return resized_image

        assert imgC == img.shape[2]
        # Widen the target canvas to the batch's largest aspect ratio, within limits.
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        imgW = int(imgH * max_wh_ratio)
        imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
        h, w = img.shape[:2]
        ratio = w / float(h)
        ratio_imgW = max(math.ceil(imgH * ratio), self.limited_min_width)
        resized_w = min(imgW, int(ratio_imgW))
        # Normalize to [-1, 1] and right-pad with zeros up to the canvas width.
        resized_image = cv2.resize(img, (resized_w, imgH)) / 127.5 - 1
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image.transpose((2, 0, 1))
        return padding_im
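
    # Worked example for resize_norm_img (assuming rec_image_shape == [3, 32, 320]
    # and ignoring the limited_min_width / limited_max_width clamps): a 24x100
    # (HxW) crop has ratio = 100/24 ~= 4.17, so resized_w = ceil(32 * 4.17) = 134.
    # The crop is resized to (134, 32), scaled to [-1, 1], and right-padded with
    # zeros up to the canvas width int(32 * max_wh_ratio), i.e. 320 when this
    # crop is the widest in its batch.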
    def resize_norm_img_svtr(self, img, image_shape):
        imgC, imgH, imgW = image_shape
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image
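
    # The SVTR normalization (x / 255 - 0.5) / 0.5 maps pixel values from
    # [0, 255] to [-1, 1]; e.g. 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0.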
    def resize_norm_img_srn(self, img, image_shape):
        imgC, imgH, imgW = image_shape
        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]
        # Snap the width to 1x / 2x / 3x the target height, else to the full imgW.
        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))
        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]
        row, col, c = img_black.shape
        c = 1
        return np.reshape(img_black, (c, row, col)).astype(np.float32)
    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))
        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')
        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        # Upper-triangular mask blocks attention to future positions...
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
        # ...and the lower-triangular mask blocks attention to past positions.
        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]
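
    # Example of the GSRM attention biases for max_text_length = 3:
    #     np.triu(ones, 1)        np.tril(ones, -1)
    #     [[0, 1, 1],             [[0, 0, 0],
    #      [0, 0, 1],              [1, 0, 0],
    #      [0, 0, 0]]              [1, 1, 0]]
    # Scaled by -1e9 they act as additive masks: bias1 hides future positions
    # and bias2 hides past positions in the two GSRM decoding directions.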
    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]
        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)
        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)
        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)
    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape
        return padding_im, resize_shape, pad_shape, valid_ratio
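
    # Worked example for resize_norm_img_sar (e.g. image_shape == [3, 48, 48, 160],
    # so width_divisor == 4): a 32x98 (HxW) crop gives resize_w = ceil(48 * 98/32)
    # = 147, rounded to 148 (a multiple of 4); valid_ratio = 148/160 = 0.925 tells
    # the SAR head how much of the padded width holds real content.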
    def norm_img_can(self, img, image_shape):
        img = cv2.cvtColor(
            img, cv2.COLOR_BGR2GRAY)  # CAN only predicts on grayscale images
        if self.inverse:
            img = 255 - img
        if self.rec_image_shape[0] == 1:
            h, w = img.shape
            _, imgH, imgW = self.rec_image_shape
            if h < imgH or w < imgW:
                padding_h = max(imgH - h, 0)
                padding_w = max(imgW - w, 0)
                img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
                                    'constant',
                                    constant_values=(255))
                img = img_padded
        img = np.expand_dims(img, 0) / 255.0  # (h, w) -> (1, h, w), scaled to [0, 1]
        img = img.astype('float32')
        return img
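
    # Example for norm_img_can (assuming rec_image_shape == [1, 64, 128], a
    # hypothetical configuration): a 32x80 grayscale crop is padded with white
    # (255) on the bottom and right to 64x128 before scaling, since images
    # smaller than the configured canvas are padded rather than resized.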
    def __call__(self, img_list, tqdm_enable=False, tqdm_desc="OCR-rec Predict"):
        img_num = len(img_list)
        # Calculate the aspect ratio of all text line crops.
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting by aspect ratio can speed up the recognition process.
        indices = np.argsort(np.array(width_list))
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        elapse = 0
        with tqdm(total=img_num, desc=tqdm_desc, disable=not tqdm_enable) as pbar:
            for beg_img_no in range(0, img_num, batch_num):
                end_img_no = min(img_num, beg_img_no + batch_num)
                norm_img_batch = []
                # Accumulators for algorithm-specific extra inputs; they must be
                # created once per batch, not once per image.
                if self.rec_algorithm == "SRN":
                    encoder_word_pos_list = []
                    gsrm_word_pos_list = []
                    gsrm_slf_attn_bias1_list = []
                    gsrm_slf_attn_bias2_list = []
                elif self.rec_algorithm == "SAR":
                    valid_ratios = []
                elif self.rec_algorithm == "CAN":
                    norm_img_mask_batch = []
                    word_label_list = []
                # indices is sorted ascending, so the batch's last image is the widest.
                max_wh_ratio = width_list[indices[end_img_no - 1]]
                for ino in range(beg_img_no, end_img_no):
                    if self.rec_algorithm == "SAR":
                        norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                            img_list[indices[ino]], self.rec_image_shape)
                        norm_img = norm_img[np.newaxis, :]
                        valid_ratio = np.expand_dims(valid_ratio, axis=0)
                        valid_ratios.append(valid_ratio)
                        norm_img_batch.append(norm_img)
                    elif self.rec_algorithm == "SVTR":
                        norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
                                                             self.rec_image_shape)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                    elif self.rec_algorithm == "SRN":
                        norm_img = self.process_image_srn(img_list[indices[ino]],
                                                          self.rec_image_shape, 8,
                                                          self.max_text_length)
                        encoder_word_pos_list.append(norm_img[1])
                        gsrm_word_pos_list.append(norm_img[2])
                        gsrm_slf_attn_bias1_list.append(norm_img[3])
                        gsrm_slf_attn_bias2_list.append(norm_img[4])
                        norm_img_batch.append(norm_img[0])
                    elif self.rec_algorithm == "CAN":
                        norm_img = self.norm_img_can(img_list[indices[ino]],
                                                     max_wh_ratio)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                        norm_image_mask = np.ones(norm_img.shape, dtype='float32')
                        word_label = np.ones([1, 36], dtype='int64')
                        norm_img_mask_batch.append(norm_image_mask)
                        word_label_list.append(word_label)
                    else:
                        norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                        max_wh_ratio)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                norm_img_batch = np.concatenate(norm_img_batch)
                norm_img_batch = norm_img_batch.copy()

                if self.rec_algorithm == "SRN":
                    starttime = time.time()
                    encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                    gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                    gsrm_slf_attn_bias1_list = np.concatenate(
                        gsrm_slf_attn_bias1_list)
                    gsrm_slf_attn_bias2_list = np.concatenate(
                        gsrm_slf_attn_bias2_list)
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch)
                        encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
                        gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
                        gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
                        gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
                        inp = inp.to(self.device)
                        encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
                        gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
                        gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
                        gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
                        backbone_out = self.net.backbone(inp)  # backbone features
                        prob_out = self.net.head(
                            backbone_out,
                            [encoder_word_pos_inp, gsrm_word_pos_inp,
                             gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
                    preds = {"predict": prob_out["predict"]}
                elif self.rec_algorithm == "SAR":
                    starttime = time.time()
                    # Note: valid_ratios are collected above but the torch SAR
                    # network here is called with the image batch only.
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch)
                        inp = inp.to(self.device)
                        preds = self.net(inp)
                elif self.rec_algorithm == "CAN":
                    starttime = time.time()
                    norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
                    word_label_list = np.concatenate(word_label_list)
                    inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
                    inp = [torch.from_numpy(e_i) for e_i in inputs]
                    inp = [e_i.to(self.device) for e_i in inp]
                    with torch.no_grad():
                        outputs = self.net(inp)
                    outputs = [v.cpu().numpy() for v in outputs]
                    preds = outputs
                else:
                    starttime = time.time()
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch)
                        inp = inp.to(self.device)
                        preds = self.net(inp)
                with torch.no_grad():
                    rec_result = self.postprocess_op(preds)
                for rno in range(len(rec_result)):
                    rec_res[indices[beg_img_no + rno]] = rec_result[rno]
                elapse += time.time() - starttime
                # Advance the progress bar by the actual batch size; the last
                # batch may contain fewer than batch_num images.
                pbar.update(end_img_no - beg_img_no)
        # Replace NaN confidence scores so downstream consumers get valid floats.
        for i in range(len(rec_res)):
            text, score = rec_res[i]
            if isinstance(score, float) and math.isnan(score):
                rec_res[i] = (text, 0.0)
        return rec_res, elapse
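

# Minimal usage sketch (a non-authoritative example; it assumes an `args`
# namespace exposing the fields read in __init__, e.g. one built by an
# argparse helper such as utility.parse_args() if this repo provides it, and
# the image path below is hypothetical):
#
#     import cv2
#     args = utility.parse_args()            # assumed PaddleOCR-style args
#     recognizer = TextRecognizer(args)
#     img = cv2.imread("word_crop.jpg")      # hypothetical text-line crop
#     results, elapse = recognizer([img])
#     print(results[0])                      # -> (text, confidence score)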