predict_det.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import sys
  2. import numpy as np
  3. import time
  4. import torch
  5. from ...pytorchocr.base_ocr_v20 import BaseOCRV20
  6. from . import pytorchocr_utility as utility
  7. from ...pytorchocr.data import create_operators, transform
  8. from ...pytorchocr.postprocess import build_post_process
  class TextDetector(BaseOCRV20):
      """Text-detection predictor built on ``BaseOCRV20``.

      Configures the preprocessing pipeline and post-processor for
      ``args.det_algorithm`` (DB, DB++, EAST, SAST, PSE or FCE), loads the
      network weights and runs detection on a single image (``__call__``) or on
      a list of images (``batch_predict``).
      """

      def __init__(self, args, **kwargs):
          """Build pre/post-processing for ``args.det_algorithm`` and load the network.

          Args:
              args: Namespace-like config. Reads ``det_algorithm``, ``device``,
                  ``det_limit_side_len``, ``det_limit_type``, ``det_model_path``,
                  ``det_yaml_path`` and the algorithm-specific options
                  (``det_db_*``, ``det_east_*``, ``det_sast_*``, ``det_pse_*``,
                  ``det_fce_box_type``, ``scales``, ``alpha``, ``beta``,
                  ``fourier_degree``, ``use_dilation``).
              **kwargs: Forwarded to ``BaseOCRV20.__init__``.

          Raises:
              ValueError: If ``args.det_algorithm`` is not a supported algorithm.
          """
          self.args = args
          self.det_algorithm = args.det_algorithm
          self.device = args.device
          pre_process_list = [{
              'DetResizeForTest': {
                  'limit_side_len': args.det_limit_side_len,
                  'limit_type': args.det_limit_type,
              }
          }, {
              'NormalizeImage': {
                  'std': [0.229, 0.224, 0.225],
                  'mean': [0.485, 0.456, 0.406],
                  'scale': '1./255.',
                  'order': 'hwc'
              }
          }, {
              'ToCHWImage': None
          }, {
              'KeepKeys': {
                  'keep_keys': ['image', 'shape']
              }
          }]
          postprocess_params = {}
          if self.det_algorithm in ("DB", "DB++"):
              # DB and DB++ share the same post-processor and thresholds.
              postprocess_params['name'] = 'DBPostProcess'
              postprocess_params["thresh"] = args.det_db_thresh
              postprocess_params["box_thresh"] = args.det_db_box_thresh
              postprocess_params["max_candidates"] = 1000
              postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
              postprocess_params["use_dilation"] = args.use_dilation
              postprocess_params["score_mode"] = args.det_db_score_mode
              if self.det_algorithm == "DB++":
                  # DB++ normalizes with its own mean and a unit std.
                  pre_process_list[1] = {
                      'NormalizeImage': {
                          'std': [1.0, 1.0, 1.0],
                          'mean':
                          [0.48109378172549, 0.45752457890196, 0.40787054090196],
                          'scale': '1./255.',
                          'order': 'hwc'
                      }
                  }
          elif self.det_algorithm == "EAST":
              postprocess_params['name'] = 'EASTPostProcess'
              postprocess_params["score_thresh"] = args.det_east_score_thresh
              postprocess_params["cover_thresh"] = args.det_east_cover_thresh
              postprocess_params["nms_thresh"] = args.det_east_nms_thresh
          elif self.det_algorithm == "SAST":
              # SAST resizes by the long side instead of the side-length limit.
              pre_process_list[0] = {
                  'DetResizeForTest': {
                      'resize_long': args.det_limit_side_len
                  }
              }
              postprocess_params['name'] = 'SASTPostProcess'
              postprocess_params["score_thresh"] = args.det_sast_score_thresh
              postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
              self.det_sast_polygon = args.det_sast_polygon
              if self.det_sast_polygon:
                  postprocess_params["sample_pts_num"] = 6
                  postprocess_params["expand_scale"] = 1.2
                  postprocess_params["shrink_ratio_of_width"] = 0.2
              else:
                  postprocess_params["sample_pts_num"] = 2
                  postprocess_params["expand_scale"] = 1.0
                  postprocess_params["shrink_ratio_of_width"] = 0.3
          elif self.det_algorithm == "PSE":
              postprocess_params['name'] = 'PSEPostProcess'
              postprocess_params["thresh"] = args.det_pse_thresh
              postprocess_params["box_thresh"] = args.det_pse_box_thresh
              postprocess_params["min_area"] = args.det_pse_min_area
              postprocess_params["box_type"] = args.det_pse_box_type
              postprocess_params["scale"] = args.det_pse_scale
              self.det_pse_box_type = args.det_pse_box_type
          elif self.det_algorithm == "FCE":
              # FCE rescales to a fixed size rather than using the side limit.
              pre_process_list[0] = {
                  'DetResizeForTest': {
                      'rescale_img': [1080, 736]
                  }
              }
              postprocess_params['name'] = 'FCEPostProcess'
              postprocess_params["scales"] = args.scales
              postprocess_params["alpha"] = args.alpha
              postprocess_params["beta"] = args.beta
              postprocess_params["fourier_degree"] = args.fourier_degree
              postprocess_params["box_type"] = args.det_fce_box_type
          else:
              # Raise instead of sys.exit(0): a library must not terminate the
              # host process, and certainly not with a "success" status.
              raise ValueError(
                  "unknown det_algorithm:{}".format(self.det_algorithm))
          self.preprocess_op = create_operators(pre_process_list)
          self.postprocess_op = build_post_process(postprocess_params)
          self.weights_path = args.det_model_path
          self.yaml_path = args.det_yaml_path
          network_config = utility.get_arch_config(self.weights_path)
          super(TextDetector, self).__init__(network_config, **kwargs)
          self.load_pytorch_weights(self.weights_path)
          self.net.eval()
          self.net.to(self.device)
          # Call rep() on every submodule that defines it (hook for
          # re-parameterizable layers).
          for module in self.net.modules():
              if hasattr(module, 'rep'):
                  module.rep()

      def _extract_preds(self, outputs):
          """Convert raw network outputs into the numpy dict fed to the post-processor.

          Every tensor is moved to the CPU before conversion, so this works for
          any ``self.device``.

          Args:
              outputs: Mapping of output name to tensor, as returned by ``self.net``.

          Returns:
              Dict of numpy arrays keyed as the post-processor expects.

          Raises:
              NotImplementedError: If ``self.det_algorithm`` has no output mapping.
          """
          if self.det_algorithm == "EAST":
              keys = ('f_geo', 'f_score')
          elif self.det_algorithm == 'SAST':
              keys = ('f_border', 'f_score', 'f_tco', 'f_tvo')
          elif self.det_algorithm in ('DB', 'PSE', 'DB++'):
              keys = ('maps',)
          elif self.det_algorithm == 'FCE':
              # Re-key the outputs as level_0, level_1, ... in output order.
              return {
                  'level_{}'.format(i): output.cpu().numpy()
                  for i, output in enumerate(outputs.values())
              }
          else:
              raise NotImplementedError
          return {key: outputs[key].cpu().numpy() for key in keys}

      def _filter_boxes(self, dt_boxes, image_shape):
          """Clip detected boxes to the original image.

          Polygon outputs (SAST with ``det_sast_polygon``, or PSE/FCE with a
          ``'poly'`` box type) are only clipped. All other outputs are also
          reordered clockwise and size-filtered by ``filter_tag_det_res``.
          """
          keep_polygons = (
              (self.det_algorithm == "SAST" and self.det_sast_polygon)
              or (self.det_algorithm in ("PSE", "FCE")
                  and self.postprocess_op.box_type == 'poly'))
          if keep_polygons:
              return self.filter_tag_det_res_only_clip(dt_boxes, image_shape)
          return self.filter_tag_det_res(dt_boxes, image_shape)

      def _predict_individually(self, img_list, starttime):
          """Run single-image detection on every image (fallback when no batch can be formed).

          Returns:
              ``(batch_results, total_elapse)``, with ``total_elapse`` measured
              from ``starttime``.
          """
          batch_results = [self.__call__(img) for img in img_list]
          return batch_results, time.time() - starttime

      def _batch_process_same_size(self, img_list):
          """Detect text in a list of images with one batched forward pass.

          The preprocessed images must share one shape so they can be stacked.
          If any image fails preprocessing, or the shapes differ, every image is
          instead processed individually through ``__call__``.

          Args:
              img_list: List of images with identical sizes.

          Returns:
              ``(batch_results, total_elapse)``. ``batch_results`` holds one
              ``(dt_boxes, elapse)`` pair per input image. ``total_elapse`` is
              the wall-clock seconds for the whole batch, including
              preprocessing and post-processing.
          """
          starttime = time.time()
          batch_data = []
          batch_shapes = []
          # Only the original shapes are needed later (for clipping), so the
          # images themselves are not copied.
          ori_shapes = []
          for img in img_list:
              ori_shapes.append(img.shape)
              data = transform({'image': img}, self.preprocess_op)
              if data is None or data[0] is None:
                  # One image failed to preprocess: handle images one by one so
                  # the others still get results (the failed one gets (None, 0)).
                  return self._predict_individually(img_list, starttime)
              img_processed, shape_list = data
              batch_data.append(img_processed)
              batch_shapes.append(shape_list)

          try:
              batch_tensor = np.stack(batch_data, axis=0)
              batch_shapes = np.stack(batch_shapes, axis=0)
          except ValueError:
              # np.stack raises ValueError when shapes differ (or the list is empty).
              return self._predict_individually(img_list, starttime)

          with torch.no_grad():
              inp = torch.from_numpy(batch_tensor).to(self.device)
              outputs = self.net(inp)
          preds = self._extract_preds(outputs)

          all_boxes = []
          for i, ori_shape in enumerate(ori_shapes):
              # Slice out this image's predictions, keeping a batch dimension of 1.
              single_preds = {
                  key: value[i:i + 1] if isinstance(value, np.ndarray) else value
                  for key, value in preds.items()
              }
              post_result = self.postprocess_op(single_preds,
                                                batch_shapes[i:i + 1])
              all_boxes.append(
                  self._filter_boxes(post_result[0]['points'], ori_shape))

          total_elapse = time.time() - starttime
          per_image_elapse = total_elapse / len(img_list)
          batch_results = [(dt_boxes, per_image_elapse) for dt_boxes in all_boxes]
          return batch_results, total_elapse

      def batch_predict(self, img_list, max_batch_size=8):
          """Detect text in many images, processing them in chunks.

          Args:
              img_list: List of images.
              max_batch_size: Maximum number of images per forward pass.

          Returns:
              List with one ``(dt_boxes, elapse)`` pair per input image, in
              input order.

          Raises:
              ValueError: If ``max_batch_size`` is less than 1.
          """
          if not img_list:
              return []
          if max_batch_size < 1:
              raise ValueError(
                  "max_batch_size must be >= 1, got {}".format(max_batch_size))
          batch_results = []
          for start in range(0, len(img_list), max_batch_size):
              # Chunks whose images differ in size fall back to per-image
              # processing inside _batch_process_same_size.
              chunk_results, _ = self._batch_process_same_size(
                  img_list[start:start + max_batch_size])
              batch_results.extend(chunk_results)
          return batch_results

      def order_points_clockwise(self, pts):
          """Order four corner points as top-left, top-right, bottom-right, bottom-left.

          Reference: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py

          Args:
              pts: Array with four rows of (x, y) coordinates.

          Returns:
              float32 array of the points in [tl, tr, br, bl] order.
          """
          # Sort the points by x-coordinate.
          xSorted = pts[np.argsort(pts[:, 0]), :]
          # The two left-most and the two right-most points.
          leftMost = xSorted[:2, :]
          rightMost = xSorted[2:, :]
          # Sort each pair by y-coordinate: the top point comes first.
          leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
          (tl, bl) = leftMost
          rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
          (tr, br) = rightMost
          rect = np.array([tl, tr, br, bl], dtype="float32")
          return rect

      def clip_det_res(self, points, img_height, img_width):
          """Clamp points to the image bounds and truncate them to whole pixels, in place.

          Args:
              points: 2-D array whose columns 0 and 1 are x and y.
              img_height: Image height in pixels.
              img_width: Image width in pixels.

          Returns:
              The same ``points`` array, modified in place.
          """
          # After clamping to >= 0, floor equals int() truncation.
          points[:, 0] = np.floor(np.clip(points[:, 0], 0, img_width - 1))
          points[:, 1] = np.floor(np.clip(points[:, 1], 0, img_height - 1))
          return points

      def filter_tag_det_res(self, dt_boxes, image_shape):
          """Reorder, clip and drop tiny quadrilateral boxes.

          Boxes whose width or height (after clipping) is 3 px or less are
          discarded.

          Args:
              dt_boxes: Iterable of four-point boxes.
              image_shape: Original image shape; ``[0:2]`` is (height, width).

          Returns:
              numpy array of the kept boxes.
          """
          img_height, img_width = image_shape[0:2]
          dt_boxes_new = []
          for box in dt_boxes:
              box = self.order_points_clockwise(box)
              box = self.clip_det_res(box, img_height, img_width)
              rect_width = int(np.linalg.norm(box[0] - box[1]))
              rect_height = int(np.linalg.norm(box[0] - box[3]))
              if rect_width <= 3 or rect_height <= 3:
                  continue
              dt_boxes_new.append(box)
          return np.array(dt_boxes_new)

      def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
          """Clip polygon boxes to the image without reordering or size filtering.

          Args:
              dt_boxes: Iterable of point arrays.
              image_shape: Original image shape; ``[0:2]`` is (height, width).

          Returns:
              numpy array of the clipped boxes.
          """
          img_height, img_width = image_shape[0:2]
          return np.array([
              self.clip_det_res(box, img_height, img_width) for box in dt_boxes
          ])

      def __call__(self, img):
          """Detect text boxes in a single image.

          Args:
              img: Input image array; its ``shape[0:2]`` (height, width) is used
                  to clip the detected boxes.

          Returns:
              ``(dt_boxes, elapse)``: the filtered boxes and the seconds spent on
              inference plus post-processing. Returns ``(None, 0)`` if
              preprocessing fails.
          """
          ori_shape = img.shape
          data = transform({'image': img}, self.preprocess_op)
          # transform() may return None; check before unpacking.
          if data is None:
              return None, 0
          img, shape_list = data
          if img is None:
              return None, 0
          img = np.expand_dims(img, axis=0)
          shape_list = np.expand_dims(shape_list, axis=0)
          img = img.copy()
          starttime = time.time()
          with torch.no_grad():
              inp = torch.from_numpy(img).to(self.device)
              outputs = self.net(inp)
          preds = self._extract_preds(outputs)
          post_result = self.postprocess_op(preds, shape_list)
          dt_boxes = self._filter_boxes(post_result[0]['points'], ori_shape)
          elapse = time.time() - starttime
          return dt_boxes, elapse