predict_rec.py

from PIL import Image
import cv2
import numpy as np
import math
import time
import torch
from tqdm import tqdm
from ...pytorchocr.base_ocr_v20 import BaseOCRV20
from . import pytorchocr_utility as utility
from ...pytorchocr.postprocess import build_post_process
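
# TextRecognizer wraps a converted PaddleOCR recognition model for PyTorch
# inference: it builds the network from the checkpoint, picks the label
# decoder matching rec_algorithm (CTC by default; SRN/RARE/NRTR/SAR/ViTSTR/
# CAN/RFL otherwise), and exposes batched prediction via __call__.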
class TextRecognizer(BaseOCRV20):
    def __init__(self, args, **kwargs):
        self.device = args.device
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.max_text_length = args.max_text_length
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        if self.rec_algorithm == "SRN":
            postprocess_params = {
                'name': 'SRNLabelDecode',
                "character_type": args.rec_char_type,
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "RARE":
            postprocess_params = {
                'name': 'AttnLabelDecode',
                "character_type": args.rec_char_type,
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'NRTR':
            postprocess_params = {
                'name': 'NRTRLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "SAR":
            postprocess_params = {
                'name': 'SARLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'ViTSTR':
            postprocess_params = {
                'name': 'ViTSTRLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "CAN":
            self.inverse = args.rec_image_inverse
            postprocess_params = {
                'name': 'CANLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == 'RFL':
            postprocess_params = {
                'name': 'RFLLabelDecode',
                "character_dict_path": None,
                "use_space_char": args.use_space_char
            }
        self.postprocess_op = build_post_process(postprocess_params)
        self.limited_max_width = args.limited_max_width
        self.limited_min_width = args.limited_min_width

        self.weights_path = args.rec_model_path
        self.yaml_path = args.rec_yaml_path
        network_config = utility.get_arch_config(self.weights_path)
        weights = self.read_pytorch_weights(self.weights_path)

        self.out_channels = self.get_out_channels(weights)
        if self.rec_algorithm == 'NRTR':
            self.out_channels = list(weights.values())[-1].numpy().shape[0]
        elif self.rec_algorithm == 'SAR':
            self.out_channels = list(weights.values())[-3].numpy().shape[0]
        kwargs['out_channels'] = self.out_channels
        super(TextRecognizer, self).__init__(network_config, **kwargs)
        self.load_state_dict(weights)
        self.net.eval()
        self.net.to(self.device)
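
    # NOTE: the classifier width (out_channels) is read from the checkpoint
    # itself rather than from a config file; NRTR and SAR keep the relevant
    # classifier weight at different positions in the state dict, hence the
    # [-1] / [-3] index offsets above.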

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            image_pil = Image.fromarray(np.uint8(img))
            if self.rec_algorithm == 'ViTSTR':
                img = image_pil.resize([imgW, imgH], Image.BICUBIC)
            else:
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
                # same filter under its current name.
                img = image_pil.resize([imgW, imgH], Image.LANCZOS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            if self.rec_algorithm == 'ViTSTR':
                norm_img = norm_img.astype(np.float32) / 255.
            else:
                norm_img = norm_img.astype(np.float32) / 128. - 1.
            return norm_img
        elif self.rec_algorithm == 'RFL':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            resized_image = cv2.resize(
                img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
            resized_image = resized_image.astype('float32')
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
            resized_image -= 0.5
            resized_image /= 0.5
            return resized_image

        assert imgC == img.shape[2]
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        imgW = int(imgH * max_wh_ratio)
        imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
        h, w = img.shape[:2]
        ratio = w / float(h)
        ratio_imgH = math.ceil(imgH * ratio)
        ratio_imgH = max(ratio_imgH, self.limited_min_width)
        if ratio_imgH > imgW:
            resized_w = imgW
        else:
            resized_w = int(ratio_imgH)
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im
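
    # Worked example for the CTC resize path above, assuming the default
    # width limits and rec_image_shape "3,32,320": a single 100x30 crop has
    # ratio = 100/30 ≈ 3.33, so it is resized to ceil(32 * 3.33) = 107 pixels
    # wide at height 32, then right-padded with zeros to the batch width of
    # int(32 * max_wh_ratio) = 320.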

    def resize_norm_img_svtr(self, img, image_shape):
        imgC, imgH, imgW = image_shape
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image

    def resize_norm_img_srn(self, img, image_shape):
        imgC, imgH, imgW = image_shape
        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]
        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))
        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]
        row, col, c = img_black.shape
        c = 1
        return np.reshape(img_black, (c, row, col)).astype(np.float32)
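
    # SRN resizes crops into width buckets of 1x, 2x, or 3x the target height
    # (wider crops go straight to imgW) and left-aligns the grayscale result
    # on a black canvas, so the encoder always sees a fixed imgH x imgW input.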

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))
        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')
        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]
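
    # The two GSRM biases are complementary causal masks: bias1 puts -1e9 on
    # the strict upper triangle (each position may only attend to earlier
    # ones), while bias2 masks the strict lower triangle for the reverse
    # direction. Added to the attention logits, -1e9 drives the masked
    # positions' softmax weights to effectively zero.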

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # Make sure the new width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # Resize.
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # Normalize.
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape
        return padding_im, resize_shape, pad_shape, valid_ratio
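
    # valid_ratio records what fraction of the padded width holds real pixels
    # (the rest is -1 padding); SAR-style decoders can use it to ignore the
    # padded region when attending over the feature map.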

    def norm_img_can(self, img, image_shape):
        # image_shape is unused here; the argument is kept so all resize
        # helpers share a common call shape.
        img = cv2.cvtColor(
            img, cv2.COLOR_BGR2GRAY)  # CAN only predicts grayscale images
        if self.inverse:
            img = 255 - img
        if self.rec_image_shape[0] == 1:
            h, w = img.shape
            _, imgH, imgW = self.rec_image_shape
            if h < imgH or w < imgW:
                padding_h = max(imgH - h, 0)
                padding_w = max(imgW - w, 0)
                img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
                                    'constant',
                                    constant_values=(255))
                img = img_padded
        img = np.expand_dims(img, 0) / 255.0  # h,w -> 1,h,w
        img = img.astype('float32')
        return img
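
    # CAN (handwritten math expression recognition) appears to expect
    # white-background, dark-ink input: padding uses 255 (white), and
    # rec_image_inverse flips the polarity when the source crops are inverted.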

    def __call__(self, img_list, tqdm_enable=False, tqdm_desc="OCR-rec Predict"):
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars.
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting by aspect ratio can speed up the recognition process.
        indices = np.argsort(np.array(width_list))

        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        elapse = 0
        with tqdm(total=img_num, desc=tqdm_desc, disable=not tqdm_enable) as pbar:
            index = 0
            for beg_img_no in range(0, img_num, batch_num):
                end_img_no = min(img_num, beg_img_no + batch_num)
                norm_img_batch = []
                # Per-algorithm side inputs, initialized once per batch so
                # every image's entry is kept; resetting them inside the
                # per-image loop (as the original code did) silently dropped
                # all but the last image's entry.
                valid_ratios = []
                encoder_word_pos_list = []
                gsrm_word_pos_list = []
                gsrm_slf_attn_bias1_list = []
                gsrm_slf_attn_bias2_list = []
                norm_img_mask_batch = []
                word_label_list = []
                max_wh_ratio = 0
                for ino in range(beg_img_no, end_img_no):
                    h, w = img_list[indices[ino]].shape[0:2]
                    wh_ratio = w * 1.0 / h
                    max_wh_ratio = max(max_wh_ratio, wh_ratio)
                for ino in range(beg_img_no, end_img_no):
                    if self.rec_algorithm == "SAR":
                        norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                            img_list[indices[ino]], self.rec_image_shape)
                        norm_img = norm_img[np.newaxis, :]
                        valid_ratio = np.expand_dims(valid_ratio, axis=0)
                        valid_ratios.append(valid_ratio)
                        norm_img_batch.append(norm_img)
                    elif self.rec_algorithm == "SVTR":
                        norm_img = self.resize_norm_img_svtr(
                            img_list[indices[ino]], self.rec_image_shape)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                    elif self.rec_algorithm == "SRN":
                        norm_img = self.process_image_srn(
                            img_list[indices[ino]], self.rec_image_shape, 8,
                            self.max_text_length)
                        encoder_word_pos_list.append(norm_img[1])
                        gsrm_word_pos_list.append(norm_img[2])
                        gsrm_slf_attn_bias1_list.append(norm_img[3])
                        gsrm_slf_attn_bias2_list.append(norm_img[4])
                        norm_img_batch.append(norm_img[0])
                    elif self.rec_algorithm == "CAN":
                        norm_img = self.norm_img_can(img_list[indices[ino]],
                                                     max_wh_ratio)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                        norm_image_mask = np.ones(norm_img.shape, dtype='float32')
                        word_label = np.ones([1, 36], dtype='int64')
                        norm_img_mask_batch.append(norm_image_mask)
                        word_label_list.append(word_label)
                    else:
                        norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                        max_wh_ratio)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                norm_img_batch = np.concatenate(norm_img_batch)
                norm_img_batch = norm_img_batch.copy()

                if self.rec_algorithm == "SRN":
                    starttime = time.time()
                    encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                    gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                    gsrm_slf_attn_bias1_list = np.concatenate(
                        gsrm_slf_attn_bias1_list)
                    gsrm_slf_attn_bias2_list = np.concatenate(
                        gsrm_slf_attn_bias2_list)

                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch).to(self.device)
                        encoder_word_pos_inp = torch.from_numpy(
                            encoder_word_pos_list).to(self.device)
                        gsrm_word_pos_inp = torch.from_numpy(
                            gsrm_word_pos_list).to(self.device)
                        gsrm_slf_attn_bias1_inp = torch.from_numpy(
                            gsrm_slf_attn_bias1_list).to(self.device)
                        gsrm_slf_attn_bias2_inp = torch.from_numpy(
                            gsrm_slf_attn_bias2_list).to(self.device)

                        backbone_out = self.net.backbone(inp)
                        prob_out = self.net.head(backbone_out, [
                            encoder_word_pos_inp, gsrm_word_pos_inp,
                            gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp
                        ])
                    preds = {"predict": prob_out["predict"]}
                elif self.rec_algorithm == "SAR":
                    starttime = time.time()
                    # NOTE: valid_ratios is collected above but not consumed
                    # by this torch port of the SAR head.
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch).to(self.device)
                        preds = self.net(inp)
                elif self.rec_algorithm == "CAN":
                    starttime = time.time()
                    norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
                    word_label_list = np.concatenate(word_label_list)
                    inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
                    inp = [torch.from_numpy(e_i).to(self.device) for e_i in inputs]
                    with torch.no_grad():
                        outputs = self.net(inp)
                    preds = [v.cpu().numpy() for v in outputs]
                else:
                    starttime = time.time()
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch).to(self.device)
                        prob_out = self.net(inp)
                    if isinstance(prob_out, list):
                        preds = [v.cpu().numpy() for v in prob_out]
                    else:
                        preds = prob_out.cpu().numpy()

                rec_result = self.postprocess_op(preds)
                for rno in range(len(rec_result)):
                    rec_res[indices[beg_img_no + rno]] = rec_result[rno]
                elapse += time.time() - starttime

                # Advance the progress bar by one batch; the last batch may
                # hold fewer than batch_num images.
                current_batch_size = min(batch_num, img_num - index * batch_num)
                index += 1
                pbar.update(current_batch_size)

        # Fix NaN scores in recognition results.
        for i in range(len(rec_res)):
            text, score = rec_res[i]
            if isinstance(score, float) and math.isnan(score):
                rec_res[i] = (text, 0.0)
        return rec_res, elapse
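
# Usage sketch (a minimal, hedged example): assumes this module is imported
# as part of the package and that `args` carries the same fields as the
# upstream PaddleOCR2Pytorch CLI (rec_model_path, rec_image_shape, device,
# ...); utility.parse_args and the image path are assumptions.
#
#   args = utility.parse_args()
#   recognizer = TextRecognizer(args)
#   img = cv2.imread('word_1.jpg')       # one cropped text line (BGR)
#   rec_res, elapse = recognizer([img])
#   print(rec_res)                       # e.g. [('JOINT', 0.9862)]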