소스 검색

Merge pull request #2878 from myhloli/dev

refactor: adjust minimum confidence threshold and enhance merging logic in ocr_utils.py
Xiaomeng Zhao 4 달 전
부모
커밋
275e662ea9
2개의 변경된 파일, 127개의 추가 그리고 16개의 삭제
  1. 43 12
      mineru/utils/ocr_utils.py
  2. 84 4
      mineru/utils/span_block_fix.py

+ 43 - 12
mineru/utils/ocr_utils.py

@@ -5,9 +5,11 @@ import numpy as np
 
 
 class OcrConfidence:
-    min_confidence = 0.6
+    min_confidence = 0.5
     min_width = 3
 
+LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD = 4  # 一般情况下,行宽度超过高度4倍时才是一个正常的横向文本块
+
 
 def merge_spans_to_line(spans, threshold=0.6):
     if len(spans) == 0:
@@ -20,7 +22,7 @@ def merge_spans_to_line(spans, threshold=0.6):
         current_line = [spans[0]]
         for span in spans[1:]:
             # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
-            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
+            if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
                 current_line.append(span)
             else:
                 # 否则,开始新行
@@ -33,9 +35,9 @@ def merge_spans_to_line(spans, threshold=0.6):
 
         return lines
 
-def __is_overlaps_y_exceeds_threshold(bbox1,
-                                      bbox2,
-                                      overlap_ratio_threshold=0.8):
+def _is_overlaps_y_exceeds_threshold(bbox1,
+                                     bbox2,
+                                     overlap_ratio_threshold=0.8):
     """检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
     _, y0_1, _, y1_1 = bbox1
     _, y0_2, _, y1_2 = bbox2
@@ -45,7 +47,21 @@ def __is_overlaps_y_exceeds_threshold(bbox1,
     # max_height = max(height1, height2)
     min_height = min(height1, height2)
 
-    return (overlap / min_height) > overlap_ratio_threshold
+    return (overlap / min_height) > overlap_ratio_threshold if min_height > 0 else False
+
+
+def _is_overlaps_x_exceeds_threshold(bbox1,
+                                     bbox2,
+                                     overlap_ratio_threshold=0.8):
+    """Return True when the x-axis overlap of two bboxes, divided by the narrower bbox's width, exceeds overlap_ratio_threshold; False when that width is not positive."""
+    x0_1, _, x1_1, _ = bbox1
+    x0_2, _, x1_2, _ = bbox2
+
+    overlap = max(0, min(x1_1, x1_2) - max(x0_1, x0_2))
+    width1, width2 = x1_1 - x0_1, x1_2 - x0_2
+    min_width = min(width1, width2)
+
+    return (overlap / min_width) > overlap_ratio_threshold if min_width > 0 else False
 
 
 def img_decode(content: bytes):
@@ -178,7 +194,7 @@ def update_det_boxes(dt_boxes, mfd_res):
         masks_list = []
         for mf_box in mfd_res:
             mf_bbox = mf_box['bbox']
-            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
+            if _is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
                 masks_list.append([mf_bbox[0], mf_bbox[2]])
         text_x_range = [text_bbox[0], text_bbox[2]]
         text_remove_mask_range = remove_intervals(text_x_range, masks_list)
@@ -266,12 +282,27 @@ def merge_det_boxes(dt_boxes):
         for span in line:
             line_bbox_list.append(span['bbox'])
 
-        # Merge overlapping text regions within the same line
-        merged_spans = merge_overlapping_spans(line_bbox_list)
+        # 计算整行的宽度和高度
+        min_x = min(bbox[0] for bbox in line_bbox_list)
+        max_x = max(bbox[2] for bbox in line_bbox_list)
+        min_y = min(bbox[1] for bbox in line_bbox_list)
+        max_y = max(bbox[3] for bbox in line_bbox_list)
+        line_width = max_x - min_x
+        line_height = max_y - min_y
+
+        # 只有当行宽度超过高度4倍时才进行合并
+        if line_width > line_height * LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD:
 
-        # Convert the merged text regions back to point format and add them to the new detection box list
-        for span in merged_spans:
-            new_dt_boxes.append(bbox_to_points(span))
+            # Merge overlapping text regions within the same line
+            merged_spans = merge_overlapping_spans(line_bbox_list)
+
+            # Convert the merged text regions back to point format and add them to the new detection box list
+            for span in merged_spans:
+                new_dt_boxes.append(bbox_to_points(span))
+        else:
+            # 不进行合并,直接添加原始区域
+            for bbox in line_bbox_list:
+                new_dt_boxes.append(bbox_to_points(bbox))
 
     new_dt_boxes.extend(angle_boxes_list)
 

+ 84 - 4
mineru/utils/span_block_fix.py

@@ -1,8 +1,10 @@
 # Copyright (c) Opendatalab. All rights reserved.
 from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from mineru.utils.enum_class import BlockType, ContentType
-from mineru.utils.ocr_utils import __is_overlaps_y_exceeds_threshold
+from mineru.utils.ocr_utils import _is_overlaps_y_exceeds_threshold, _is_overlaps_x_exceeds_threshold
 
+VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD = 2
+VERTICAL_SPAN_IN_BLOCK_THRESHOLD = 0.8
 
 def fill_spans_in_blocks(blocks, spans, radio):
     """将allspans中的span按位置关系,放入blocks中."""
@@ -71,8 +73,26 @@ def fix_text_block(block):
     for span in block['spans']:
         if span['type'] == ContentType.INTERLINE_EQUATION:
             span['type'] = ContentType.INLINE_EQUATION
-    block_lines = merge_spans_to_line(block['spans'])
-    sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
+
+    # 假设block中的span超过80%的数量高度是宽度的两倍以上,则认为是纵向文本块
+    vertical_span_count = sum(
+        1 for span in block['spans']
+        if (span['bbox'][3] - span['bbox'][1]) / (span['bbox'][2] - span['bbox'][0]) > VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD
+    )
+    total_span_count = len(block['spans'])
+    if total_span_count == 0:
+        vertical_ratio = 0
+    else:
+        vertical_ratio = vertical_span_count / total_span_count
+
+    if vertical_ratio > VERTICAL_SPAN_IN_BLOCK_THRESHOLD:
+        # 如果是纵向文本块,则按纵向lines处理
+        block_lines = merge_spans_to_vertical_line(block['spans'])
+        sort_block_lines = vertical_line_sort_spans_from_top_to_bottom(block_lines)
+    else:
+        block_lines = merge_spans_to_line(block['spans'])
+        sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
+
     block['lines'] = sort_block_lines
     del block['spans']
     return block
@@ -103,7 +123,7 @@ def merge_spans_to_line(spans, threshold=0.6):
                 continue
 
             # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
-            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
+            if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
                 current_line.append(span)
             else:
                 # 否则,开始新行
@@ -117,6 +137,44 @@ def merge_spans_to_line(spans, threshold=0.6):
         return lines
 
 
+def merge_spans_to_vertical_line(spans, threshold=0.6):
+    """Group the spans of vertical text into vertical lines (columns), ordered right to left; returns a list of span lists."""
+    if len(spans) == 0:
+        return []
+    else:
+        # Sort by the right edge (bbox[2]) in descending order, i.e. right to left
+        spans.sort(key=lambda span: span['bbox'][2], reverse=True)  # sorts the caller's list in place
+
+        vertical_lines = []
+        current_line = [spans[0]]
+
+        for span in spans[1:]:
+            # Special-type elements (interline equation, image, table) form a column of their own
+            if span['type'] in [
+                ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
+                ContentType.TABLE
+            ] or any(s['type'] in [
+                ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
+                ContentType.TABLE
+            ] for s in current_line):
+                vertical_lines.append(current_line)
+                current_line = [span]
+                continue
+
+            # If this span overlaps the last span of the current column on the x-axis, add it to the current column
+            if _is_overlaps_x_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
+                current_line.append(span)
+            else:
+                vertical_lines.append(current_line)
+                current_line = [span]
+
+        # Append the last column
+        if current_line:
+            vertical_lines.append(current_line)
+
+        return vertical_lines
+
+
 # 将每一个line中的span从左到右排序
 def line_sort_spans_by_left_to_right(lines):
     line_objects = []
@@ -136,6 +194,28 @@ def line_sort_spans_by_left_to_right(lines):
     return line_objects
 
 
+def vertical_line_sort_spans_from_top_to_bottom(vertical_lines):
+    line_objects = []
+    for line in vertical_lines:
+        # Sort by the y0 coordinate (bbox[1]), top to bottom; the sort is in place
+        line.sort(key=lambda span: span['bbox'][1])
+
+        # Compute the bounding box of the whole column
+        line_bbox = [
+            min(span['bbox'][0] for span in line),  # x0
+            min(span['bbox'][1] for span in line),  # y0
+            max(span['bbox'][2] for span in line),  # x1
+            max(span['bbox'][3] for span in line),  # y1
+        ]
+
+        # Assemble the result: one dict per column with its bbox and sorted spans
+        line_objects.append({
+            'bbox': line_bbox,
+            'spans': line,
+        })
+    return line_objects
+
+
 def fix_block_spans(block_with_spans):
     fix_blocks = []
     for block in block_with_spans: