| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710 |
- """
- OCR 工具函数 - 从 MinerU 完整迁移
- 提供文本检测框处理、图像预处理等功能
- """
- import copy
- import cv2
- import numpy as np
- from typing import List, Tuple, Union
- # ==================== 配置常量 ====================
class OcrConfidence:
    """Thresholds used when filtering OCR results."""

    # Recognition results whose score is below this value are discarded.
    min_confidence: float = 0.5
    # Boxes whose horizontal extent (p3.x - p1.x) is below this are discarded.
    min_width: int = 3
# A text line counts as a normal horizontal line only when its width is more
# than this many times its height; merge_det_boxes fuses overlapping boxes
# only on such lines.
LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD = 4
- # ==================== 图像基础处理 ====================
def img_decode(content: bytes):
    """Decode an encoded image byte string with OpenCV.

    The image is read with ``cv2.IMREAD_UNCHANGED``, so channel count
    (grayscale / BGR / BGRA) is preserved as stored.

    Args:
        content: raw encoded image bytes.

    Returns:
        np.ndarray: the decoded image as returned by ``cv2.imdecode``.
    """
    byte_view = np.frombuffer(content, dtype=np.uint8)
    decoded = cv2.imdecode(byte_view, cv2.IMREAD_UNCHANGED)
    return decoded
def check_img(img):
    """Normalise an image input to an ndarray, expanding grayscale to BGR.

    Byte strings are decoded with ``img_decode``. A 2-D ndarray is converted
    to 3-channel BGR. Anything else is returned unchanged.

    Args:
        img: encoded image ``bytes`` or an ``np.ndarray``.

    Returns:
        The (possibly decoded / converted) image.
    """
    if isinstance(img, bytes):
        img = img_decode(img)
    is_grayscale = isinstance(img, np.ndarray) and len(img.shape) == 2
    return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if is_grayscale else img
def alpha_to_color(img, alpha_color=(255, 255, 255)):
    """Composite a 4-channel image onto a solid background colour.

    Images that do not have exactly three dimensions with four channels are
    returned unchanged.

    Args:
        img: input image. A ``(H, W, 4)`` array is interpreted as B, G, R, A
            channels (OpenCV channel order).
        alpha_color: background colour in **(R, G, B)** order:
            ``alpha_color[0]`` is blended into the red channel and
            ``alpha_color[2]`` into the blue channel.

    Returns:
        np.ndarray: a ``(H, W, 3)`` uint8 image in B, G, R channel order, or
        ``img`` itself when it has no alpha channel.
    """
    if not (len(img.shape) == 3 and img.shape[2] == 4):
        return img

    blue, green, red = img[..., 0], img[..., 1], img[..., 2]
    alpha = img[..., 3] / 255
    background = 1 - alpha

    # astype(np.uint8) truncates toward zero rather than rounding.
    red = (alpha_color[0] * background + red * alpha).astype(np.uint8)
    green = (alpha_color[1] * background + green * alpha).astype(np.uint8)
    blue = (alpha_color[2] * background + blue * alpha).astype(np.uint8)
    return np.stack((blue, green, red), axis=-1)
def preprocess_image(_image):
    """Flatten any alpha channel of ``_image`` onto a white background.

    Args:
        _image: input image array.

    Returns:
        np.ndarray: the result of ``alpha_to_color`` with a white background;
        images without an alpha channel pass through unchanged.
    """
    white = (255, 255, 255)
    return alpha_to_color(_image, white)
- # ==================== 坐标转换工具 ====================
def bbox_to_points(bbox):
    """Expand an axis-aligned bbox into its four corners.

    Args:
        bbox: ``[x0, y0, x1, y1]``.

    Returns:
        np.ndarray: float32 array of shape (4, 2), ordered top-left,
        top-right, bottom-right, bottom-left.
    """
    left, top, right, bottom = bbox
    corners = [
        [left, top],
        [right, top],
        [right, bottom],
        [left, bottom],
    ]
    return np.array(corners).astype('float32')
def points_to_bbox(points):
    """Collapse four corner points back into ``[x0, y0, x1, y1]``.

    Only the first three points are read: x0/y0 come from point 0, x1 from
    point 1 and y1 from point 2. The input is assumed to be an axis-aligned
    box in top-left, top-right, bottom-right, bottom-left order.

    Args:
        points: sequence of four ``[x, y]`` points.

    Returns:
        list: ``[x0, y0, x1, y1]``.
    """
    top_left, top_right, bottom_right = points[0], points[1], points[2]
    left, top = top_left
    right, _ = top_right
    _, bottom = bottom_right
    return [left, top, right, bottom]
- # ==================== 检测框排序和合并 ====================
def sorted_boxes(dt_boxes):
    """Order text boxes top-to-bottom, then left-to-right within a row.

    Boxes are first sorted by the (y, x) of their first point. An insertion
    pass then moves a box in front of its predecessor whenever their first
    points are less than 10 units apart vertically and the box lies further
    left.

    Args:
        dt_boxes (np.ndarray): detected boxes of shape ``[num, 4, 2]``.

    Returns:
        list: the boxes in reading order.
    """
    count = dt_boxes.shape[0]
    ordered = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    for i in range(count - 1):
        j = i
        while j >= 0:
            current, following = ordered[j], ordered[j + 1]
            same_row = abs(following[0][1] - current[0][1]) < 10
            if not (same_row and following[0][0] < current[0][0]):
                break
            ordered[j], ordered[j + 1] = following, current
            j -= 1
    return ordered
def _is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
    """Tell whether two boxes overlap vertically by more than a given ratio.

    The vertical overlap is measured relative to the height of the shorter
    box. If the shorter height is not positive, the answer is ``False``.

    Args:
        bbox1: ``[x0, y0, x1, y1]``.
        bbox2: ``[x0, y0, x1, y1]``.
        overlap_ratio_threshold: ratio the overlap must strictly exceed.

    Returns:
        bool: whether ``overlap / min(height1, height2)`` exceeds the threshold.
    """
    _, top_a, _, bottom_a = bbox1
    _, top_b, _, bottom_b = bbox2
    shared = max(0, min(bottom_a, bottom_b) - max(top_a, top_b))
    shorter = min(bottom_a - top_a, bottom_b - top_b)
    if not shorter > 0:
        return False
    return shared / shorter > overlap_ratio_threshold
def _is_overlaps_x_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
    """Tell whether two boxes overlap horizontally by more than a given ratio.

    The horizontal overlap is measured relative to the width of the narrower
    box. If the narrower width is not positive, the answer is ``False``.

    Args:
        bbox1: ``[x0, y0, x1, y1]``.
        bbox2: ``[x0, y0, x1, y1]``.
        overlap_ratio_threshold: ratio the overlap must strictly exceed.

    Returns:
        bool: whether ``overlap / min(width1, width2)`` exceeds the threshold.
    """
    left_a, _, right_a, _ = bbox1
    left_b, _, right_b, _ = bbox2
    shared = max(0, min(right_a, right_b) - max(left_a, left_b))
    narrower = min(right_a - left_a, right_b - left_b)
    if not narrower > 0:
        return False
    return shared / narrower > overlap_ratio_threshold
def merge_spans_to_line(spans, threshold=0.6):
    """Group spans into text lines by vertical overlap.

    Spans are visited in order of their top edge (``bbox[1]``). A span joins
    the current line when its vertical overlap with the line's most recently
    added span exceeds ``threshold`` (see ``_is_overlaps_y_exceeds_threshold``).
    Otherwise it starts a new line. The caller's ``spans`` list is not
    reordered.

    Args:
        spans: list of dicts, each with a ``'bbox'`` of ``[x0, y0, x1, y1]``.
        threshold: vertical-overlap ratio required to stay on the same line.

    Returns:
        list: lines, each a list of the span dicts that belong to it.
    """
    if len(spans) == 0:
        return []

    # sorted() (stable, like list.sort) avoids mutating the caller's list.
    ordered = sorted(spans, key=lambda span: span['bbox'][1])
    lines = []
    current_line = [ordered[0]]
    for span in ordered[1:]:
        if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
            current_line.append(span)
        else:
            lines.append(current_line)
            current_line = [span]
    # current_line always holds at least one span at this point.
    lines.append(current_line)
    return lines
def merge_overlapping_spans(spans):
    """Fuse horizontally overlapping spans that lie on the same line.

    Spans are visited in order of their left edge. A span whose left edge is
    at or before the right edge of the previous (possibly already fused) span
    is merged into it by taking the union rectangle. The caller's ``spans``
    sequence is not reordered.

    Args:
        spans: sequence of ``(x1, y1, x2, y2)`` boxes (tuples or lists).

    Returns:
        list: spans ordered by left edge. Spans never merged are returned as
        the original objects; fused spans are new 4-tuples.
    """
    if not spans:
        return []

    merged = []
    for span in sorted(spans, key=lambda s: s[0]):
        x1, y1, x2, y2 = span
        if not merged or merged[-1][2] < x1:
            # No horizontal overlap with the previous span: keep as-is.
            merged.append(span)
            continue
        prev = merged.pop()
        merged.append((
            min(prev[0], x1),
            min(prev[1], y1),
            max(prev[2], x2),
            max(prev[3], y2),
        ))
    return merged
def merge_det_boxes(dt_boxes):
    """Merge detection boxes that sit on the same text line into wider boxes.

    Boxes that ``calculate_is_angle`` reports as tilted are not merged; they
    are appended unchanged after all other results. The remaining boxes are
    grouped into lines with ``merge_spans_to_line``. On a line wider than
    ``LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD`` times its height, horizontally
    overlapping boxes are fused with ``merge_overlapping_spans``. Narrower
    lines keep their boxes as they are.

    Args:
        dt_boxes (list): detection boxes, each given as four corner points.

    Returns:
        list: float32 corner-point arrays for the processed boxes, followed by
        the untouched tilted boxes.
    """
    straight_spans = []
    tilted_boxes = []
    for box in dt_boxes:
        bbox = points_to_bbox(box)
        if calculate_is_angle(box):
            tilted_boxes.append(box)
        else:
            straight_spans.append({'bbox': bbox})

    result = []
    for line in merge_spans_to_line(straight_spans):
        bboxes = [span['bbox'] for span in line]
        left = min(b[0] for b in bboxes)
        right = max(b[2] for b in bboxes)
        top = min(b[1] for b in bboxes)
        bottom = max(b[3] for b in bboxes)
        width = right - left
        height = bottom - top

        # Only clearly horizontal lines get their overlapping pieces fused.
        if width > height * LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD:
            pieces = merge_overlapping_spans(bboxes)
        else:
            pieces = bboxes
        result.extend(bbox_to_points(piece) for piece in pieces)

    result.extend(tilted_boxes)
    return result
- # ==================== 区间处理工具 ====================
def merge_intervals(intervals):
    """Merge overlapping ``[start, end]`` intervals.

    Two intervals are merged when the next one starts at or before the
    current merged end. The input sequence and its interval objects are left
    untouched, so tuples are accepted as well as lists.

    Args:
        intervals: sequence of ``[start, end]`` pairs (lists or tuples).

    Returns:
        list: new ``[start, end]`` lists, sorted by start, with no overlaps.
    """
    merged = []
    for interval in sorted(intervals, key=lambda iv: iv[0]):
        if not merged or merged[-1][1] < interval[0]:
            # Copy so later extension never writes into the caller's data.
            merged.append(list(interval))
        else:
            merged[-1][1] = max(merged[-1][1], interval[1])
    return merged
def remove_intervals(original, masks):
    """Subtract a set of mask intervals from one interval.

    Intervals are treated as closed ranges on an integer grid: a kept piece
    stops one unit before a mask starts and resumes one unit after it ends.

    Args:
        original: ``[start, end]`` interval to cut.
        masks: list of ``[start, end]`` intervals to remove. They are merged
            with ``merge_intervals`` first.

    Returns:
        list: the remaining ``[start, end]`` pieces in ascending order.
    """
    merged_masks = merge_intervals(masks)
    cursor, limit = original
    pieces = []
    for mask_start, mask_end in merged_masks:
        if mask_start > limit or mask_end < cursor:
            # Mask lies completely outside what is left of the interval.
            continue
        if cursor < mask_start:
            pieces.append([cursor, mask_start - 1])
        cursor = max(mask_end + 1, cursor)
    if cursor <= limit:
        pieces.append([cursor, limit])
    return pieces
- # ==================== 公式检测结果处理 ====================
def update_det_boxes(dt_boxes, mfd_res):
    """Cut formula regions out of text detection boxes.

    For each non-tilted text box, every formula box whose vertical overlap
    with it passes ``_is_overlaps_y_exceeds_threshold`` (default ratio) has
    its x-range removed. This may split the text box into several boxes or
    drop it entirely. Tilted boxes (per ``calculate_is_angle``) are passed
    through unchanged and appended at the end.

    Args:
        dt_boxes: text detection boxes as four corner points each.
        mfd_res: formula detections, each a dict with a ``'bbox'`` of
            ``[x0, y0, x1, y1]``.

    Returns:
        list: the updated boxes. When ``mfd_res`` is ``None`` or empty,
        ``dt_boxes`` itself is returned.
    """
    if mfd_res is None or len(mfd_res) == 0:
        return dt_boxes

    result = []
    tilted = []
    for box in dt_boxes:
        if calculate_is_angle(box):
            tilted.append(box)
            continue
        bbox = points_to_bbox(box)
        formula_x_ranges = [
            [formula['bbox'][0], formula['bbox'][2]]
            for formula in mfd_res
            if _is_overlaps_y_exceeds_threshold(bbox, formula['bbox'])
        ]
        for left, right in remove_intervals([bbox[0], bbox[2]], formula_x_ranges):
            result.append(bbox_to_points([left, bbox[1], right, bbox[3]]))

    result.extend(tilted)
    return result
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
    """Translate formula boxes into the coordinate system of a pasted crop.

    Each bbox is shifted by ``paste - (xmin, ymin)``. Boxes lying entirely
    left of / above the origin, or starting beyond ``new_width`` /
    ``new_height``, are dropped.

    Args:
        single_page_mfdetrec_res: formula detections, each a dict with a
            ``"bbox"`` of ``[x0, y0, x1, y1]``.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
            new_height]``; ``xmax`` and ``ymax`` are not used.

    Returns:
        list: new dicts containing only the adjusted ``"bbox"``.
    """
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, new_width, new_height = useful_list

    adjusted = []
    for formula in single_page_mfdetrec_res:
        left, top, right, bottom = formula["bbox"]
        x0 = left - xmin + paste_x
        y0 = top - ymin + paste_y
        x1 = right - xmin + paste_x
        y1 = bottom - ymin + paste_y

        outside = x1 < 0 or y1 < 0 or x0 > new_width or y0 > new_height
        if not outside:
            adjusted.append({"bbox": [x0, y0, x1, y1]})
    return adjusted
- # ==================== 角度检测 ====================
def calculate_is_angle(poly):
    """Decide whether a quadrilateral should be treated as tilted.

    The mean of the two vertical edge heights (p4.y - p1.y and p3.y - p2.y)
    is compared with the vertical distance from p1 to p3. The box is
    considered straight when that distance lies within [0.8, 1.2] times the
    mean height, and tilted otherwise.

    Args:
        poly: four points ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]``,
            presumably ordered top-left, top-right, bottom-right, bottom-left.

    Returns:
        bool: ``True`` if the box is tilted.
    """
    p1, p2, p3, p4 = poly
    mean_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
    diagonal_drop = p3[1] - p1[1]
    return not (0.8 * mean_height <= diagonal_drop <= 1.2 * mean_height)
def is_bbox_aligned_rect(points):
    """Check whether four points form an axis-aligned rectangle.

    The points qualify when they use exactly two distinct x values and
    exactly two distinct y values.

    Args:
        points: ``np.ndarray`` of shape (4, 2).

    Returns:
        bool: whether the points describe an axis-aligned rectangle.
    """
    distinct_x = np.unique(points[:, 0]).size
    distinct_y = np.unique(points[:, 1]).size
    return distinct_x == 2 and distinct_y == 2
- # ==================== 图像裁剪和旋转 ====================
def get_rotate_crop_image(img, points):
    """Cut the quadrilateral ``points`` out of ``img`` as an upright patch.

    Axis-aligned rectangles are sliced directly. Anything else, or an
    axis-aligned rectangle whose slice comes out empty, is rectified with a
    perspective warp onto a rectangle sized by the longer of each pair of
    opposite edges. The warped patch is rotated 90 degrees when it is at
    least twice as tall as it is wide.

    Args:
        img: source image array.
        points: ``np.ndarray`` of four corners, mapped in order to the
            top-left, top-right, bottom-right and bottom-left of the output.
            NOTE(review): cv2.getPerspectiveTransform expects float32 — the
            caller in this file converts; confirm for other callers.

    Returns:
        np.ndarray: the cropped (and possibly rotated) image.
    """
    assert len(points) == 4, "shape of points must be 4*2"

    if is_bbox_aligned_rect(points):
        xs, ys = points[:, 0], points[:, 1]
        left, right = int(np.min(xs)), int(np.max(xs))
        top, bottom = int(np.min(ys)), int(np.max(ys))
        patch = img[top:bottom, left:right].copy()
        if patch.shape[0] > 0 and patch.shape[1] > 0:
            return patch

    top_edge = np.linalg.norm(points[0] - points[1])
    bottom_edge = np.linalg.norm(points[2] - points[3])
    left_edge = np.linalg.norm(points[0] - points[3])
    right_edge = np.linalg.norm(points[1] - points[2])
    crop_w = int(max(top_edge, bottom_edge))
    crop_h = int(max(left_edge, right_edge))

    target_corners = np.float32([
        [0, 0],
        [crop_w, 0],
        [crop_w, crop_h],
        [0, crop_h],
    ])
    transform = cv2.getPerspectiveTransform(points, target_corners)
    warped = cv2.warpPerspective(
        img,
        transform,
        (crop_w, crop_h),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC,
    )

    # Tall, narrow crops are most likely vertical text: turn them sideways.
    min_aspect_for_rotation = 2
    out_h, out_w = warped.shape[0:2]
    if out_h * 1.0 / out_w >= min_aspect_for_rotation:
        warped = np.rot90(warped)
    return warped
- # ==================== OCR 结果处理 ====================
def get_ocr_result_list(ocr_res, useful_list, ocr_enable, bgr_image, lang):
    """Convert raw OCR boxes into result dicts in page coordinates.

    Each entry of ``ocr_res`` is either ``(points, (text, score))`` (detection
    plus recognition) or a bare four-point box. The following entries are
    dropped:

    * entries with a recognition score below ``OcrConfidence.min_confidence``;
    * boxes whose width ``p3.x - p1.x`` is below ``OcrConfidence.min_width``.

    Tilted boxes (see ``calculate_is_angle``) are replaced by an axis-aligned
    box with the same centre. All coordinates are then shifted from the
    pasted image back to the original page using ``useful_list``.

    Args:
        ocr_res: iterable of OCR entries as described above.
        useful_list: ``[paste_x, paste_y, xmin, ymin, xmax, ymax, new_width,
            new_height]``. Only ``paste_x``, ``paste_y``, ``xmin`` and
            ``ymin`` are used here.
        ocr_enable: when true, a crop of every kept box is cut from
            ``bgr_image`` and attached as ``'np_img'``, and ``'score'`` is
            set to 1 (presumably recognition runs later — confirm with caller).
        bgr_image: image the incoming box coordinates refer to.
        lang: language tag copied into each result when ``ocr_enable``.

    Returns:
        list[dict]: one dict per kept box with ``'category_id'`` (15),
        ``'poly'`` (four corners flattened into eight coordinates),
        ``'score'`` and ``'text'``; plus ``'np_img'`` and ``'lang'`` when
        ``ocr_enable`` is true.
    """
    # Unpack all eight values so a malformed useful_list still fails loudly.
    paste_x, paste_y, xmin, ymin, _xmax, _ymax, _new_width, _new_height = useful_list
    ocr_result_list = []
    ori_im = bgr_image.copy()

    for box_ocr_res in ocr_res:
        if len(box_ocr_res) == 2:
            p1, p2, p3, p4 = box_ocr_res[0]
            text, score = box_ocr_res[1]
            if score < OcrConfidence.min_confidence:
                continue
        else:
            p1, p2, p3, p4 = box_ocr_res
            text, score = "", 1

        if (p3[0] - p1[0]) < OcrConfidence.min_width:
            continue

        img_crop = None
        if ocr_enable:
            # Crop for every kept box, whatever its entry format, so that
            # 'np_img' always belongs to the current box. The crop uses the
            # uncorrected corners in the pasted-image coordinate system.
            crop_points = np.array([p1, p2, p3, p4]).astype('float32')
            img_crop = get_rotate_crop_image(ori_im, crop_points)

        poly = [p1, p2, p3, p4]
        if calculate_is_angle(poly):
            # Replace the tilted box with an axis-aligned one of the same
            # centre, width p3.x - p1.x and mean vertical-edge height.
            x_center = sum(point[0] for point in poly) / 4
            y_center = sum(point[1] for point in poly) / 4
            box_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            box_width = p3[0] - p1[0]
            poly = [
                [x_center - box_width / 2, y_center - box_height / 2],
                [x_center + box_width / 2, y_center - box_height / 2],
                [x_center + box_width / 2, y_center + box_height / 2],
                [x_center - box_width / 2, y_center + box_height / 2],
            ]

        # Map back from the pasted image to the original page coordinates.
        flat_poly = []
        for point in poly:
            flat_poly += [point[0] - paste_x + xmin, point[1] - paste_y + ymin]

        if ocr_enable:
            ocr_result_list.append({
                'category_id': 15,
                'poly': flat_poly,
                'score': 1,
                'text': text,
                'np_img': img_crop,
                'lang': lang,
            })
        else:
            ocr_result_list.append({
                'category_id': 15,
                'poly': flat_poly,
                'score': float(round(score, 2)),
                'text': text,
            })
    return ocr_result_list
- # ==================== 测试代码 ====================
if __name__ == "__main__":
    # Smoke tests for the helpers in this module; run the file directly.
    print("🧪 Testing OCR Utils...")

    # check_img: an ndarray input must come back as an image.
    print("\n1. Testing check_img")
    img = np.ones((100, 100, 3), dtype=np.uint8)
    assert check_img(img) is not None
    print(" ✅ check_img: PASS")

    # sorted_boxes: boxes on different rows come back top-to-bottom.
    print("\n2. Testing sorted_boxes")
    boxes = np.array([
        [[50, 50], [100, 50], [100, 100], [50, 100]],
        [[10, 10], [60, 10], [60, 60], [10, 60]],
    ])
    sorted_result = sorted_boxes(boxes)
    assert [box[0].tolist() for box in sorted_result] == [[10, 10], [50, 50]]
    print(f" ✅ sorted_boxes: PASS (got {len(sorted_result)} boxes)")

    # bbox_to_points / points_to_bbox must round-trip.
    print("\n3. Testing bbox_to_points / points_to_bbox")
    bbox = [10, 20, 100, 200]
    points = bbox_to_points(bbox)
    bbox_back = points_to_bbox(points)
    assert bbox == bbox_back
    print(" ✅ bbox conversion: PASS")

    # merge_intervals: overlapping intervals collapse, disjoint ones stay.
    print("\n4. Testing merge_intervals")
    intervals = [[1, 3], [2, 6], [8, 10], [15, 18]]
    merged = merge_intervals(intervals)
    assert merged == [[1, 6], [8, 10], [15, 18]]
    print(f" ✅ merge_intervals: {merged}")

    # calculate_is_angle: a rectangle is straight, a skewed quad is tilted.
    print("\n5. Testing calculate_is_angle")
    poly_straight = [[0, 0], [100, 0], [100, 50], [0, 50]]
    poly_angle = [[0, 0], [100, 0], [100, 80], [0, 20]]
    print(f" Straight poly is_angle: {calculate_is_angle(poly_straight)}")
    print(f" Angled poly is_angle: {calculate_is_angle(poly_angle)}")
    assert not calculate_is_angle(poly_straight)
    assert calculate_is_angle(poly_angle)

    print("\n✅ All tests passed!")
|