import sys

import numpy as np
import time
import torch

from ..pytorchocr.base_ocr_v20 import BaseOCRV20
from . import pytorchocr_utility as utility
from ..pytorchocr.data import create_operators, transform
from ..pytorchocr.postprocess import build_post_process


class TextDetector(BaseOCRV20):
    """Text detector wrapping a PyTorch detection network with its pre/post-processing.

    The algorithm is selected by ``args.det_algorithm``; supported values are
    DB, DB++, EAST, SAST, PSE and FCE.
    """

    def __init__(self, args, **kwargs):
        """Build the pre/post-processing pipelines and load the detection network.

        Args:
            args: configuration namespace. The attributes read depend on
                ``args.det_algorithm`` (see the branches below); ``det_model_path``,
                ``det_yaml_path``, ``device``, ``det_limit_side_len`` and
                ``det_limit_type`` are always read.
            **kwargs: forwarded to ``BaseOCRV20.__init__``.

        An unknown ``det_algorithm`` prints a message and terminates the process
        via ``sys.exit`` with a non-zero status.
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.device = args.device
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm in ("DB", "DB++"):
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
            if self.det_algorithm == "DB++":
                # DB++ shares DB's post-processing but uses its own input
                # normalization (unit std, different mean).
                pre_process_list[1] = {
                    'NormalizeImage': {
                        'std': [1.0, 1.0, 1.0],
                        'mean': [0.48109378172549, 0.45752457890196, 0.40787054090196],
                        'scale': '1./255.',
                        'order': 'hwc'
                    }
                }
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_pse_box_type
            postprocess_params["scale"] = args.det_pse_scale
            self.det_pse_box_type = args.det_pse_box_type
        elif self.det_algorithm == "FCE":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params['name'] = 'FCEPostProcess'
            postprocess_params["scales"] = args.scales
            postprocess_params["alpha"] = args.alpha
            postprocess_params["beta"] = args.beta
            postprocess_params["fourier_degree"] = args.fourier_degree
            postprocess_params["box_type"] = args.det_fce_box_type
        else:
            print("unknown det_algorithm:{}".format(self.det_algorithm))
            # A fatal configuration error must not report success to the shell.
            sys.exit(1)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)

        self.weights_path = args.det_model_path
        self.yaml_path = args.det_yaml_path
        network_config = utility.get_arch_config(self.weights_path)
        super(TextDetector, self).__init__(network_config, **kwargs)
        self.load_pytorch_weights(self.weights_path)
        self.net.eval()
        self.net.to(self.device)
        # Modules exposing ``rep()`` are re-parameterized in place for inference
        # (presumably structural re-parameterization / branch fusion — confirm).
        for module in self.net.modules():
            if hasattr(module, 'rep'):
                module.rep()

    def _extract_preds(self, outputs):
        """Convert the network output dict into the numpy dict the post-processor consumes.

        Args:
            outputs: dict of torch tensors returned by ``self.net``.

        Returns:
            dict mapping prediction names to numpy arrays (moved to CPU).

        Raises:
            NotImplementedError: if ``self.det_algorithm`` has no known output layout.
        """
        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs['f_geo'].cpu().numpy()
            preds['f_score'] = outputs['f_score'].cpu().numpy()
        elif self.det_algorithm == 'SAST':
            for key in ('f_border', 'f_score', 'f_tco', 'f_tvo'):
                preds[key] = outputs[key].cpu().numpy()
        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
            preds['maps'] = outputs['maps'].cpu().numpy()
        elif self.det_algorithm == 'FCE':
            # One entry per output, renamed level_0, level_1, ... in output order.
            for i, output in enumerate(outputs.values()):
                preds['level_{}'.format(i)] = output.cpu().numpy()
        else:
            raise NotImplementedError
        return preds

    def _filter_boxes(self, dt_boxes, ori_shape):
        """Clip post-processed boxes to the original image; quadrilaterals are also filtered.

        Polygon outputs (SAST in polygon mode, PSE/FCE with ``box_type == 'poly'``)
        are only clipped; all other outputs are re-ordered, clipped and filtered
        by ``filter_tag_det_res``.
        """
        if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
                self.det_algorithm in ["PSE", "FCE"] and self.postprocess_op.box_type == 'poly'):
            return self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        return self.filter_tag_det_res(dt_boxes, ori_shape)

    def _batch_process_same_size(self, img_list):
        """Run detection on several images with a single batched forward pass.

        Args:
            img_list: list of images. Their preprocessed tensors must share one
                shape to be stacked; otherwise each image is processed on its own.

        Returns:
            batch_results: list of ``(dt_boxes, elapse)`` tuples, one per image;
                ``elapse`` is the total time divided evenly across the batch.
                If preprocessing of any image yields ``None``, every entry is
                ``(None, 0)``.
            total_elapse: wall time of the whole call in seconds (0 on
                preprocessing failure).
        """
        starttime = time.time()

        # Preprocess all images. Only the original shape is needed later (for
        # clipping), so record it instead of copying each image.
        batch_data = []
        batch_shapes = []
        ori_shapes = []
        for img in img_list:
            ori_shapes.append(img.shape)
            data = transform({'image': img}, self.preprocess_op)
            if data is None:
                # Preprocessing failed: return empty results for the whole batch.
                return [(None, 0) for _ in img_list], 0
            img_processed, shape_list = data
            batch_data.append(img_processed)
            batch_shapes.append(shape_list)

        # Stack into one batch tensor.
        try:
            batch_tensor = np.stack(batch_data, axis=0)
            batch_shapes = np.stack(batch_shapes, axis=0)
        except ValueError:
            # np.stack rejects mismatched (or zero) inputs: fall back to
            # running each image through the single-image path.
            batch_results = [self(img) for img in img_list]
            return batch_results, time.time() - starttime

        # Batched inference.
        with torch.no_grad():
            inp = torch.from_numpy(batch_tensor).to(self.device)
            outputs = self.net(inp)
        preds = self._extract_preds(outputs)

        # Post-process each image separately.
        total_elapse = time.time() - starttime
        per_image_elapse = total_elapse / len(img_list)
        batch_results = []
        for i, ori_shape in enumerate(ori_shapes):
            # Slice every array to image i while keeping the batch dimension.
            single_preds = {
                key: value[i:i + 1] if isinstance(value, np.ndarray) else value
                for key, value in preds.items()
            }
            post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
            dt_boxes = self._filter_boxes(post_result[0]['points'], ori_shape)
            batch_results.append((dt_boxes, per_image_elapse))

        return batch_results, total_elapse

    def batch_predict(self, img_list, max_batch_size=8):
        """Detect text in several images, running up to ``max_batch_size`` per forward pass.

        Args:
            img_list: list of images.
            max_batch_size: maximum number of images per batched forward pass.

        Returns:
            list of ``(dt_boxes, elapse)`` tuples, one per input image, in input order.

        Raises:
            ValueError: if ``max_batch_size`` is less than 1.
        """
        if not img_list:
            return []
        if max_batch_size < 1:
            raise ValueError("max_batch_size must be >= 1, got {}".format(max_batch_size))

        batch_results = []
        for start in range(0, len(img_list), max_batch_size):
            # Preprocessed tensors within a chunk must share a shape to be stacked;
            # _batch_process_same_size falls back to per-image inference otherwise.
            chunk_results, _ = self._batch_process_same_size(
                img_list[start:start + max_batch_size])
            batch_results.extend(chunk_results)
        return batch_results

    def order_points_clockwise(self, pts):
        """Order four corner points as [top-left, top-right, bottom-right, bottom-left].

        Reference: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py

        Args:
            pts: array of four ``(x, y)`` points (exactly two are treated as the
                left pair and two as the right pair).

        Returns:
            float32 array of the same four points in the order above.
        """
        # Sort the points by x-coordinate.
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # Grab the left-most and right-most points from the x-sorted points.
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # Sort each pair by y-coordinate to separate top from bottom.
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp point coordinates into the image, truncating them to integers.

        ``points`` is modified in place and also returned. Column 0 is clamped to
        ``[0, img_width - 1]`` and column 1 to ``[0, img_height - 1]``.
        """
        points[:, 0] = np.clip(points[:, 0], 0, img_width - 1).astype(int)
        points[:, 1] = np.clip(points[:, 1], 0, img_height - 1).astype(int)
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and filter quadrilateral boxes.

        Boxes whose width or height (after ordering and clipping) is at most 3
        pixels are dropped.

        Args:
            dt_boxes: iterable of 4-point boxes.
            image_shape: shape of the original image; only ``[0:2]`` (height,
                width) is used.

        Returns:
            numpy array of the kept boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(np.asarray(box))
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip polygon boxes to the image without re-ordering or filtering them.

        Args:
            dt_boxes: iterable of polygons (arrays or nested lists of points);
                polygons may have different numbers of points.
            image_shape: shape of the original image; only ``[0:2]`` is used.

        Returns:
            numpy array of clipped polygons. If polygons differ in point count,
            a 1-D object array holding each polygon is returned.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = [
            self.clip_det_res(np.asarray(box), img_height, img_width)
            for box in dt_boxes
        ]
        try:
            return np.array(dt_boxes_new)
        except ValueError:
            # Ragged polygons: NumPy >= 1.24 refuses to infer an object array,
            # so build one explicitly.
            ragged = np.empty(len(dt_boxes_new), dtype=object)
            for i, box in enumerate(dt_boxes_new):
                ragged[i] = box
            return ragged

    def __call__(self, img):
        """Detect text boxes in a single image.

        Args:
            img: input image (its ``.shape[0:2]`` is used as height, width when
                clipping boxes).

        Returns:
            ``(dt_boxes, elapse)``: the filtered boxes and the time in seconds
            spent on inference and post-processing (preprocessing excluded).
            Returns ``(None, 0)`` if preprocessing yields no image.
        """
        ori_shape = img.shape
        data = transform({'image': img}, self.preprocess_op)
        # Check before unpacking: transform may return None.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        starttime = time.time()
        with torch.no_grad():
            inp = torch.from_numpy(img).to(self.device)
            outputs = self.net(inp)
        preds = self._extract_preds(outputs)

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = self._filter_boxes(post_result[0]['points'], ori_shape)
        elapse = time.time() - starttime
        return dt_boxes, elapse