"""
OCR utility functions - fully migrated from MinerU.

Provides text-detection-box processing (sorting, line merging, splitting
around formulas), image preprocessing and related helpers.
"""
import copy
import cv2
import numpy as np
from typing import List, Tuple, Union


# ==================== Configuration constants ====================

class OcrConfidence:
    """Thresholds used to filter OCR results."""

    # Recognition results whose score is below this value are dropped.
    min_confidence = 0.5
    # Boxes whose width (p3.x - p1.x) is below this value are dropped.
    min_width = 3


# As a rule of thumb, a line is a normal horizontal text block only when its
# width exceeds 4x its height; narrower "lines" are not merged.
LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD = 4


# ==================== Basic image processing ====================

def img_decode(content: bytes):
    """
    Decode an encoded image byte string into an array.

    Decoding uses ``cv2.IMREAD_UNCHANGED``, so an alpha channel (if present)
    is preserved.

    Args:
        content: Encoded image bytes.

    Returns:
        np.ndarray | None: The decoded image, or ``None`` if OpenCV could not
        decode the data.
    """
    np_arr = np.frombuffer(content, dtype=np.uint8)
    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)


def check_img(img):
    """
    Normalize an image input.

    * ``bytes`` are decoded with :func:`img_decode`.
    * 2-D (grayscale) arrays are expanded to 3-channel BGR.

    Anything else is returned unchanged. In particular a 4-channel image stays
    4-channel (see :func:`preprocess_image`), and a failed decode yields ``None``.

    Args:
        img: Image as ``bytes`` or ``np.ndarray``.

    Returns:
        np.ndarray | None: The normalized image.
    """
    if isinstance(img, bytes):
        img = img_decode(img)
    if isinstance(img, np.ndarray) and len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return img


def alpha_to_color(img, alpha_color=(255, 255, 255)):
    """
    Flatten a 4-channel image onto a solid background colour.

    The channels are unpacked as B, G, R, A (OpenCV order). Each colour
    channel is alpha-blended with the background, and a 3-channel image in
    B, G, R order is returned. Images that are not 3-D with exactly 4 channels
    are returned unchanged.

    Args:
        img: Input image.
        alpha_color: Background colour in (R, G, B) order. Index 0 is blended
            into the R channel and index 2 into the B channel.

    Returns:
        np.ndarray: uint8 BGR image if the input had alpha, otherwise ``img``.
    """
    if len(img.shape) == 3 and img.shape[2] == 4:
        B, G, R, A = cv2.split(img)
        alpha = A / 255
        R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
        G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
        B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
        img = cv2.merge((B, G, R))
    return img


def preprocess_image(_image):
    """
    Drop the alpha channel by compositing the image onto a white background.

    Args:
        _image: Input image.

    Returns:
        np.ndarray: The image without alpha (unchanged if it had none).
    """
    alpha_color = (255, 255, 255)
    _image = alpha_to_color(_image, alpha_color)
    return _image


# ==================== Coordinate conversion utilities ====================

def bbox_to_points(bbox):
    """
    Convert an ``[x0, y0, x1, y1]`` bbox into its four corner points.

    Args:
        bbox: ``[x0, y0, x1, y1]``.

    Returns:
        np.ndarray: float32 array ``[[x0, y0], [x1, y0], [x1, y1], [x0, y1]]``
        (top-left, top-right, bottom-right, bottom-left).
    """
    x0, y0, x1, y1 = bbox
    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')


def points_to_bbox(points):
    """
    Convert four corner points back into an ``[x0, y0, x1, y1]`` bbox.

    Only three corners are read: x0/y0 from point 0, x1 from point 1 and y1
    from point 2. This assumes the point order produced by
    :func:`bbox_to_points`; it is not a true min/max bounding box.

    Args:
        points: ``[[x0, y0], [x1, y1], [x2, y2], [x3, y3]]``.

    Returns:
        list: ``[x0, y0, x1, y1]``.
    """
    x0, y0 = points[0]
    x1, _ = points[1]
    _, y1 = points[2]
    return [x0, y0, x1, y1]


# ==================== Detection box sorting and merging ====================

def sorted_boxes(dt_boxes, y_threshold=10):
    """
    Sort text boxes top-to-bottom, then left-to-right.

    Boxes are first sorted by the (y, x) of their first point. Then an
    insertion-style pass swaps adjacent boxes whose first-point y values differ
    by less than ``y_threshold`` (i.e. the same visual row) when the later box
    lies further left.

    Args:
        dt_boxes: Detected boxes, shape ``[num, 4, 2]``. Either an array or a
            sequence of boxes.
        y_threshold: Max y difference for two adjacent boxes to be treated as
            the same row. Defaults to 10.

    Returns:
        list: The sorted boxes.
    """
    num_boxes = len(dt_boxes)
    boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            same_row = abs(boxes[j + 1][0][1] - boxes[j][0][1]) < y_threshold
            if same_row and boxes[j + 1][0][0] < boxes[j][0][0]:
                boxes[j], boxes[j + 1] = boxes[j + 1], boxes[j]
            else:
                break
    return boxes


def _overlap_ratio_exceeds(start1, end1, start2, end2, overlap_ratio_threshold):
    """Return whether the 1-D overlap of two intervals exceeds the threshold share of the shorter one."""
    overlap = max(0, min(end1, end2) - max(start1, start2))
    min_length = min(end1 - start1, end2 - start2)
    # A zero/negative-length interval can never satisfy the ratio.
    return (overlap / min_length) > overlap_ratio_threshold if min_length > 0 else False


def _is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
    """
    Check whether two bboxes overlap on the y axis by more than a ratio.

    The ratio is the overlap height divided by the smaller of the two bbox
    heights.

    Args:
        bbox1: ``[x0, y0, x1, y1]``.
        bbox2: ``[x0, y0, x1, y1]``.
        overlap_ratio_threshold: Ratio that must be strictly exceeded.

    Returns:
        bool: True if the ratio exceeds the threshold. Always False when either
        height is <= 0.
    """
    _, y0_1, _, y1_1 = bbox1
    _, y0_2, _, y1_2 = bbox2
    return _overlap_ratio_exceeds(y0_1, y1_1, y0_2, y1_2, overlap_ratio_threshold)


def _is_overlaps_x_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
    """
    Check whether two bboxes overlap on the x axis by more than a ratio.

    The ratio is the overlap width divided by the smaller of the two bbox
    widths.

    Args:
        bbox1: ``[x0, y0, x1, y1]``.
        bbox2: ``[x0, y0, x1, y1]``.
        overlap_ratio_threshold: Ratio that must be strictly exceeded.

    Returns:
        bool: True if the ratio exceeds the threshold. Always False when either
        width is <= 0.
    """
    x0_1, _, x1_1, _ = bbox1
    x0_2, _, x1_2, _ = bbox2
    return _overlap_ratio_exceeds(x0_1, x1_1, x0_2, x1_2, overlap_ratio_threshold)


def merge_spans_to_line(spans, threshold=0.6):
    """
    Group spans into lines.

    Spans are ordered by y0; the caller's list is not reordered. A span joins
    the current line when it overlaps the last span of that line on the y axis
    by more than ``threshold``. Otherwise it starts a new line.

    Args:
        spans: List of dicts, each with a ``'bbox'`` ``[x0, y0, x1, y1]``.
        threshold: y-overlap ratio threshold.

    Returns:
        list: A list of lines, each a list of span dicts.
    """
    if not spans:
        return []

    ordered = sorted(spans, key=lambda span: span['bbox'][1])

    lines = []
    current_line = [ordered[0]]
    for span in ordered[1:]:
        if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
            current_line.append(span)
        else:
            lines.append(current_line)
            current_line = [span]

    lines.append(current_line)
    return lines


def merge_overlapping_spans(spans):
    """
    Merge horizontally overlapping spans that lie on the same line.

    Spans are ordered by x0; the caller's list is not reordered. A span that
    starts at or before the right edge of the previous merged span is unioned
    with it, so touching spans are merged too.

    Args:
        spans: Span coordinates ``[(x1, y1, x2, y2), ...]``.

    Returns:
        list: Merged spans. Unmerged spans keep their original objects; merged
        ones are new tuples.
    """
    if not spans:
        return []

    merged = []
    for span in sorted(spans, key=lambda s: s[0]):
        x1, y1, x2, y2 = span
        if not merged or merged[-1][2] < x1:
            merged.append(span)
        else:
            last_span = merged.pop()
            merged.append((
                min(last_span[0], x1),
                min(last_span[1], y1),
                max(last_span[2], x2),
                max(last_span[3], y2),
            ))
    return merged


def merge_det_boxes(dt_boxes):
    """
    Merge detection boxes into larger, line-level text regions.

    Skewed boxes (per ``calculate_is_angle``) are passed through untouched and
    appended after all other boxes. The remaining boxes are grouped into lines
    by y overlap. A line whose overall width exceeds
    ``LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD`` times its height has its
    horizontally overlapping boxes merged. Other lines keep their boxes as-is.

    Args:
        dt_boxes: Iterable of boxes, each defined by four corner points.

    Returns:
        list: Boxes as 4x2 float32 point arrays. Skewed boxes keep their
        original objects.
    """
    dt_boxes_dict_list = []
    angle_boxes_list = []
    for text_box in dt_boxes:
        # Skewed boxes cannot be represented as axis-aligned bboxes; keep them aside.
        if calculate_is_angle(text_box):
            angle_boxes_list.append(text_box)
            continue
        dt_boxes_dict_list.append({'bbox': points_to_bbox(text_box)})

    lines = merge_spans_to_line(dt_boxes_dict_list)

    new_dt_boxes = []
    for line in lines:
        line_bbox_list = [span['bbox'] for span in line]

        min_x = min(bbox[0] for bbox in line_bbox_list)
        max_x = max(bbox[2] for bbox in line_bbox_list)
        min_y = min(bbox[1] for bbox in line_bbox_list)
        max_y = max(bbox[3] for bbox in line_bbox_list)
        line_width = max_x - min_x
        line_height = max_y - min_y

        # Only genuinely horizontal lines get their overlapping boxes merged.
        if line_width > line_height * LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD:
            boxes_to_emit = merge_overlapping_spans(line_bbox_list)
        else:
            boxes_to_emit = line_bbox_list
        new_dt_boxes.extend(bbox_to_points(bbox) for bbox in boxes_to_emit)

    new_dt_boxes.extend(angle_boxes_list)
    return new_dt_boxes
# ==================== Interval utilities ====================

def merge_intervals(intervals):
    """
    Merge overlapping (or touching) closed intervals.

    The input is left untouched. Intervals are sorted by start into a new
    list, and every emitted interval is a fresh ``[start, end]`` list, so tuple
    intervals are accepted. An interval is merged into the previous one when
    it starts at or before that interval's end.

    Args:
        intervals: Intervals ``[[start, end], ...]``.

    Returns:
        list: Merged intervals ``[[start, end], ...]`` sorted by start.
    """
    merged = []
    for interval in sorted(intervals, key=lambda iv: iv[0]):
        start, end = interval[0], interval[1]
        if not merged or merged[-1][1] < start:
            merged.append([start, end])
        else:
            merged[-1][1] = max(merged[-1][1], end)
    return merged


def remove_intervals(original, masks):
    """
    Remove the mask intervals from an original interval.

    Integer-pixel semantics are used: the pieces left on either side of a mask
    stop at ``mask_start - 1`` and resume at ``mask_end + 1``.

    Args:
        original: The interval ``[start, end]`` to cut.
        masks: Intervals to remove ``[[start, end], ...]``.

    Returns:
        list: The remaining pieces ``[[start, end], ...]``. Empty if fully masked.
    """
    merged_masks = merge_intervals(masks)

    result = []
    original_start, original_end = original
    for mask_start, mask_end in merged_masks:
        # Ignore masks that do not touch the (remaining) original interval.
        if mask_start > original_end or mask_end < original_start:
            continue
        # Keep the piece to the left of this mask, if any.
        if original_start < mask_start:
            result.append([original_start, mask_start - 1])
        original_start = max(mask_end + 1, original_start)

    # Keep whatever is left to the right of the last mask.
    if original_start <= original_end:
        result.append([original_start, original_end])
    return result


# ==================== Formula detection result processing ====================

def update_det_boxes(dt_boxes, mfd_res):
    """
    Cut formula regions out of the text detection boxes.

    For every non-skewed text box, each formula bbox whose y overlap with the
    box exceeds the default threshold removes its x range from the box. The
    box is then split into the remaining x pieces. A box fully covered by
    formulas disappears. Skewed boxes are kept untouched and appended at the
    end.

    Args:
        dt_boxes: Text detection boxes (four corner points each).
        mfd_res: Formula detections, each a dict with a ``'bbox'``
            ``[x0, y0, x1, y1]``.

    Returns:
        list: The updated boxes. ``dt_boxes`` itself if there are no formulas.
    """
    if mfd_res is None or len(mfd_res) == 0:
        return dt_boxes

    new_dt_boxes = []
    angle_boxes_list = []
    for text_box in dt_boxes:
        if calculate_is_angle(text_box):
            angle_boxes_list.append(text_box)
            continue

        text_bbox = points_to_bbox(text_box)

        # x ranges of formulas sharing (most of) this box's vertical extent.
        masks_list = [
            [mf_box['bbox'][0], mf_box['bbox'][2]]
            for mf_box in mfd_res
            if _is_overlaps_y_exceeds_threshold(text_bbox, mf_box['bbox'])
        ]

        remaining_x_ranges = remove_intervals([text_bbox[0], text_bbox[2]], masks_list)
        new_dt_boxes.extend(
            bbox_to_points([x_start, text_bbox[1], x_end, text_bbox[3]])
            for x_start, x_end in remaining_x_ranges
        )

    new_dt_boxes.extend(angle_boxes_list)
    return new_dt_boxes


def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    """
    Translate formula boxes into the coordinate frame of the pasted crop.

    Each bbox is shifted by ``(-xmin + paste_x, -ymin + paste_y)``. Boxes that
    end left of or above 0, or start beyond ``new_width``/``new_height``, are
    dropped.

    Args:
        single_page_mfdetrec_res: Formula results, each with a ``'bbox'``.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height]``.

    Returns:
        list: ``[{'bbox': [x0, y0, x1, y1]}, ...]`` in the crop frame.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, new_width, new_height = useful_list

    adjusted_mfdetrec_res = []
    for mf_res in single_page_mfdetrec_res:
        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
        x0 = mf_xmin - xmin + paste_x
        y0 = mf_ymin - ymin + paste_y
        x1 = mf_xmax - xmin + paste_x
        y1 = mf_ymax - ymin + paste_y

        # Skip formula blocks that lie entirely outside the crop.
        if x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height:
            continue
        adjusted_mfdetrec_res.append({
            "bbox": [x0, y0, x1, y1],
        })
    return adjusted_mfdetrec_res


# ==================== Angle detection ====================

def calculate_is_angle(poly):
    """
    Decide whether a quadrilateral is skewed.

    The box height is the mean of the left and right edge heights. The box is
    considered upright when the y distance from point 1 to point 3 is within
    +/-20% of that height.

    Args:
        poly: Four vertices ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]``.

    Returns:
        bool: True if the box is skewed.
    """
    p1, p2, p3, p4 = poly
    height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
    return not (0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height)


def is_bbox_aligned_rect(points):
    """
    Heuristically check whether four points form an axis-aligned rectangle.

    The check passes when the points use exactly two distinct x values and two
    distinct y values.

    Args:
        points: Four vertices as a 4x2 array (or nested sequence).

    Returns:
        bool: True if the heuristic matches.
    """
    points = np.asarray(points)
    return len(np.unique(points[:, 0])) == 2 and len(np.unique(points[:, 1])) == 2


# ==================== Image cropping and rotation ====================

def get_rotate_crop_image(img, points):
    """
    Crop a text region and rectify it.

    Axis-aligned boxes are sliced directly. Other boxes, and axis-aligned
    boxes whose slice comes out empty, are rectified with a perspective warp
    into a rectangle whose sides are the longer of each pair of opposite edges.
    Results at least twice as tall as wide are rotated 90 degrees.

    Args:
        img: Source image.
        points: Four corners ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]``
            (array or nested sequence).

    Returns:
        np.ndarray: The cropped, rectified image.
    """
    assert len(points) == 4, "shape of points must be 4*2"
    points = np.asarray(points)

    # Axis-aligned rectangle: a plain slice is enough.
    if is_bbox_aligned_rect(points):
        xmin = int(np.min(points[:, 0]))
        xmax = int(np.max(points[:, 0]))
        ymin = int(np.min(points[:, 1]))
        ymax = int(np.max(points[:, 1]))
        new_img = img[ymin:ymax, xmin:xmax].copy()
        if new_img.shape[0] > 0 and new_img.shape[1] > 0:
            return new_img

    # Clamp to at least 1 px. A zero dimension is not a meaningful crop size
    # for warpPerspective, and the aspect-ratio check below divides by the width.
    img_crop_width = max(1, int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3]),
        )
    ))
    img_crop_height = max(1, int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2]),
        )
    ))

    pts_std = np.float32([
        [0, 0],
        [img_crop_width, 0],
        [img_crop_width, img_crop_height],
        [0, img_crop_height],
    ])

    # getPerspectiveTransform requires float32 point arrays.
    M = cv2.getPerspectiveTransform(points.astype(np.float32), pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M,
        (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC,
    )

    # Rotate regions that are much taller than wide (likely vertical text).
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    rotate_ratio = 2
    if dst_img_height * 1.0 / dst_img_width >= rotate_ratio:
        dst_img = np.rot90(dst_img)
    return dst_img


# ==================== OCR result processing ====================

def get_ocr_result_list(ocr_res, useful_list, ocr_enable, bgr_image, lang):
    """
    Convert raw OCR output into result dicts in page coordinates.

    Each entry of ``ocr_res`` is either ``(box, (text, score))`` or a bare
    4-point box. For ``(box, (text, score))`` entries, those scoring below
    ``OcrConfidence.min_confidence`` are dropped. Bare boxes get ``text=""``
    and ``score=1``. Boxes narrower than ``OcrConfidence.min_width`` are
    dropped. Skewed boxes are replaced by an upright box of the same centre,
    width and mean height. Points are mapped back with
    ``-paste + (xmin, ymin)``.

    Args:
        ocr_res: Raw OCR results.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height]``.
        ocr_enable: If True, also crop each box from ``bgr_image`` (``'np_img'``)
            for later recognition.
        bgr_image: The image the boxes refer to. Only read when ``ocr_enable``
            is True.
        lang: Language tag copied into results when ``ocr_enable`` is True.

    Returns:
        list: Result dicts with ``'category_id'`` (always 15; presumably
        MinerU's OCR-text category - confirm), ``'poly'`` (8 floats),
        ``'score'`` and ``'text'``. When ``ocr_enable`` is True they also
        carry ``'np_img'`` and ``'lang'``.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, _new_width, _new_height = useful_list

    # The source image is only read (sliced/warped), never modified, so a full
    # copy is unnecessary. ascontiguousarray still guarantees a C-contiguous
    # buffer for OpenCV and copies only when needed.
    src_image = np.ascontiguousarray(bgr_image) if ocr_enable else None

    ocr_result_list = []
    for box_ocr_res in ocr_res:
        if len(box_ocr_res) == 2:
            p1, p2, p3, p4 = box_ocr_res[0]
            text, score = box_ocr_res[1]
            if score < OcrConfidence.min_confidence:
                continue
        else:
            p1, p2, p3, p4 = box_ocr_res
            text, score = "", 1

        # Drop boxes that are too narrow before doing any cropping work.
        if (p3[0] - p1[0]) < OcrConfidence.min_width:
            continue

        img_crop = None
        if ocr_enable:
            # Crop using the original (possibly skewed) corners.
            img_crop = get_rotate_crop_image(
                src_image, np.array([p1, p2, p3, p4], dtype=np.float32)
            )

        poly = [p1, p2, p3, p4]
        if calculate_is_angle(poly):
            # Replace a skewed box with an upright one around its centroid.
            x_center = sum(point[0] for point in poly) / 4
            y_center = sum(point[1] for point in poly) / 4
            box_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            box_width = p3[0] - p1[0]
            half_w = box_width / 2
            half_h = box_height / 2
            p1 = [x_center - half_w, y_center - half_h]
            p2 = [x_center + half_w, y_center - half_h]
            p3 = [x_center + half_w, y_center + half_h]
            p4 = [x_center - half_w, y_center + half_h]

        # Map from the pasted-crop frame back to the original coordinates.
        p1, p2, p3, p4 = (
            [point[0] - paste_x + xmin, point[1] - paste_y + ymin]
            for point in (p1, p2, p3, p4)
        )
        flat_poly = p1 + p2 + p3 + p4

        if ocr_enable:
            ocr_result_list.append({
                'category_id': 15,
                'poly': flat_poly,
                'score': 1,
                'text': text,
                'np_img': img_crop,
                'lang': lang,
            })
        else:
            ocr_result_list.append({
                'category_id': 15,
                'poly': flat_poly,
                'score': float(round(score, 2)),
                'text': text,
            })

    return ocr_result_list


# ==================== Smoke tests ====================

if __name__ == "__main__":
    # Quick self-checks for the OCR utility functions.
    print("🧪 Testing OCR Utils...")

    print("\n1. Testing check_img")
    img = np.ones((100, 100, 3), dtype=np.uint8)
    assert check_img(img) is not None
    print("   ✅ check_img: PASS")

    print("\n2. Testing sorted_boxes")
    boxes = np.array([
        [[50, 50], [100, 50], [100, 100], [50, 100]],
        [[10, 10], [60, 10], [60, 60], [10, 60]],
    ])
    sorted_result = sorted_boxes(boxes)
    assert len(sorted_result) == 2 and sorted_result[0][0][1] == 10
    print(f"   ✅ sorted_boxes: PASS (got {len(sorted_result)} boxes)")

    print("\n3. Testing bbox_to_points / points_to_bbox")
    bbox = [10, 20, 100, 200]
    points = bbox_to_points(bbox)
    bbox_back = points_to_bbox(points)
    assert bbox == bbox_back
    print("   ✅ bbox conversion: PASS")

    print("\n4. Testing merge_intervals")
    intervals = [[1, 3], [2, 6], [8, 10], [15, 18]]
    merged = merge_intervals(intervals)
    assert merged == [[1, 6], [8, 10], [15, 18]]
    print(f"   ✅ merge_intervals: {merged}")

    print("\n5. Testing calculate_is_angle")
    poly_straight = [[0, 0], [100, 0], [100, 50], [0, 50]]
    poly_angle = [[0, 0], [100, 0], [100, 80], [0, 20]]
    assert calculate_is_angle(poly_straight) is False
    assert calculate_is_angle(poly_angle) is True
    print(f"   Straight poly is_angle: {calculate_is_angle(poly_straight)}")
    print(f"   Angled poly is_angle: {calculate_is_angle(poly_angle)}")

    print("\n✅ All tests passed!")