Forráskód Böngészése

refactor: update text span extraction to use new version and improve character handling

myhloli 5 hónapja
szülő
commit
d2de6d801a

+ 2 - 2
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -9,7 +9,7 @@ from mineru.utils.model_utils import clean_memory
 from mineru.utils.pipeline_magic_model import MagicModel
 from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
 from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
-    remove_overlaps_min_spans, txt_spans_extract_v2
+    remove_overlaps_min_spans, txt_spans_extract_v3
 from mineru.version import __version__
 from mineru.utils.hash_utils import str_md5
 
@@ -79,7 +79,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
         pass
     else:
         """使用新版本的混合ocr方案."""
-        spans = txt_spans_extract_v2(page, spans, page_pil_img, scale)
+        spans = txt_spans_extract_v3(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
 
     """先处理不需要排版的discarded_blocks"""
     discarded_block_with_spans, spans = fill_spans_in_blocks(

+ 2 - 2
mineru/cli/common.py

@@ -215,8 +215,8 @@ def do_parse(
 
 
 if __name__ == "__main__":
-    pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf"
-    # pdf_path = "../../demo/pdfs/demo1.pdf"
+    # pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf"
+    pdf_path = "../../demo/pdfs/demo1.pdf"
     with open(pdf_path, "rb") as f:
         try:
            do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,)

+ 241 - 0
mineru/utils/span_pre_proc.py

@@ -1,5 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import re
+import statistics
+
 import cv2
 import numpy as np
 from loguru import logger
@@ -116,6 +118,7 @@ def __replace_unicode(text: str):
     return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
 
 
+"""textpage.get_text_bounded方案"""
 def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
 
     textpage = pdf_page.get_textpage()
@@ -162,6 +165,7 @@ def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
     return spans
 
 
+"""pdf_text dict方案 span级别"""
 def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
 
     page_dict = get_page(pdf_page)
@@ -224,6 +228,243 @@ def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
     return spans
 
 
+"""pdf_text dict方案 char级别"""
+def txt_spans_extract_v3(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
+
+    page_dict = get_page(pdf_page)
+
+    page_all_chars = []
+    page_all_lines = []
+    for block in page_dict['blocks']:
+        for line in block['lines']:
+            if 0 < abs(line['rotation']) < 90:
+                # 旋转角度在0-90度之间的行,直接跳过
+                continue
+            page_all_lines.append(line)
+            for span in line['spans']:
+                for char in span['chars']:
+                    page_all_chars.append(char)
+
+    # 计算所有sapn的高度的中位数
+    span_height_list = []
+    for span in spans:
+        if span['type'] in [ContentType.TEXT]:
+            span_height = span['bbox'][3] - span['bbox'][1]
+            span['height'] = span_height
+            span['width'] = span['bbox'][2] - span['bbox'][0]
+            span_height_list.append(span_height)
+    if len(span_height_list) == 0:
+        return spans
+    else:
+        median_span_height = statistics.median(span_height_list)
+
+    useful_spans = []
+    unuseful_spans = []
+    # 纵向span的两个特征:1. 高度超过多个line 2. 高宽比超过某个值
+    vertical_spans = []
+    for span in spans:
+        if span['type'] in [ContentType.TEXT]:
+            for block in all_bboxes + all_discarded_blocks:
+                if block[7] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
+                    continue
+                if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
+                    if span['height'] > median_span_height * 3 and span['height'] > span['width'] * 3:
+                        vertical_spans.append(span)
+                    elif block in all_bboxes:
+                        useful_spans.append(span)
+                    else:
+                        unuseful_spans.append(span)
+                    break
+
+    """垂直的span框直接用line进行填充"""
+    if len(vertical_spans) > 0:
+        for pdfium_line in page_all_lines:
+            for span in vertical_spans:
+                if calculate_overlap_area_in_bbox1_area_ratio(pdfium_line['bbox'].bbox, span['bbox']) > 0.5:
+                    for pdfium_span in pdfium_line['spans']:
+                        span['content'] += pdfium_span['text']
+                    break
+
+        for span in vertical_spans:
+            if len(span['content']) == 0:
+                spans.remove(span)
+
+    """水平的span框先用char填充,再用ocr填充空的span框"""
+    new_spans = []
+
+    for span in useful_spans + unuseful_spans:
+        if span['type'] in [ContentType.TEXT]:
+            span['chars'] = []
+            new_spans.append(span)
+
+    need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars)
+
+    """对未填充的span进行ocr"""
+    if len(need_ocr_spans) > 0:
+
+        for span in need_ocr_spans:
+            # 对span的bbox截图再ocr
+            span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
+            span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
+            # 计算span的对比度,低于0.20的span不进行ocr
+            if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
+                spans.remove(span)
+                continue
+
+            span['content'] = ''
+            span['score'] = 1.0
+            span['np_img'] = span_img
+
+    return spans
+
+
+def fill_char_in_spans(spans, all_chars):
+
+    # 简单从上到下排一下序
+    spans = sorted(spans, key=lambda x: x['bbox'][1])
+
+    for char in all_chars:
+
+        for span in spans:
+            if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
+                span['chars'].append(char)
+                break
+
+    need_ocr_spans = []
+    for span in spans:
+        chars_to_content(span)
+        # 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
+        if len(span['content']) * span['height'] < span['width'] * 0.5:
+            # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
+            need_ocr_spans.append(span)
+        del span['height'], span['width']
+    return need_ocr_spans
+
+
+LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
+LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
+
+def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
+    char_center_x = (char_bbox[0] + char_bbox[2]) / 2
+    char_center_y = (char_bbox[1] + char_bbox[3]) / 2
+    span_center_y = (span_bbox[1] + span_bbox[3]) / 2
+    span_height = span_bbox[3] - span_bbox[1]
+
+    if (
+        span_bbox[0] < char_center_x < span_bbox[2]
+        and span_bbox[1] < char_center_y < span_bbox[3]
+        and abs(char_center_y - span_center_y) < span_height * span_height_radio  # 字符的中轴和span的中轴高度差不能超过1/4span高度
+    ):
+        return True
+    else:
+        # 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
+        # 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
+        if char in LINE_STOP_FLAG:
+            if (
+                (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
+                and char_center_x > span_bbox[0]
+                and span_bbox[1] < char_center_y < span_bbox[3]
+                and abs(char_center_y - span_center_y) < span_height * span_height_radio
+            ):
+                return True
+        elif char in LINE_START_FLAG:
+            if (
+                span_bbox[0] < char_bbox[2] < (span_bbox[0] + span_height)
+                and char_center_x < span_bbox[2]
+                and span_bbox[1] < char_center_y < span_bbox[3]
+                and abs(char_center_y - span_center_y) < span_height * span_height_radio
+            ):
+                return True
+        else:
+            return False
+
+
+def chars_to_content(span):
+    # 检查span中的char是否为空
+    if len(span['chars']) == 0:
+        pass
+    else:
+        # 先给chars按char['bbox']的中心点的x坐标排序
+        span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
+
+        # Calculate the width of each character
+        char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
+        # Calculate the median width
+        median_width = statistics.median(char_widths)
+
+        # 通过x轴重叠比率移除一部分char
+        span = remove_x_overlapping_chars(span, median_width)
+
+        content = ''
+        for char in span['chars']:
+
+            # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
+            char1 = char
+            char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
+            if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['char'] != ' ' and char2['char'] != ' ':
+                content += f"{char['char']} "
+            else:
+                content += char['char']
+
+        content = __replace_unicode(content)
+        content = __replace_ligatures(content)
+        content = __replace_ligatures(content)
+        span['content'] = content.strip()
+
+    del span['chars']
+
+
+def remove_x_overlapping_chars(span, median_width):
+    """
+    Remove characters from a span that overlap significantly on the x-axis.
+
+    Args:
+        median_width:
+        span (dict): A span containing a list of chars, each with bbox coordinates
+                    in the format [x0, y0, x1, y1]
+
+    Returns:
+        dict: The span with overlapping characters removed
+    """
+    if 'chars' not in span or len(span['chars']) < 2:
+        return span
+
+    overlap_threshold = median_width * 0.3
+
+    i = 0
+    while i < len(span['chars']) - 1:
+        char1 = span['chars'][i]
+        char2 = span['chars'][i + 1]
+
+        # Calculate overlap width
+        x_left = max(char1['bbox'][0], char2['bbox'][0])
+        x_right = min(char1['bbox'][2], char2['bbox'][2])
+
+        if x_right > x_left:  # There is overlap
+            overlap_width = x_right - x_left
+
+            if overlap_width > overlap_threshold:
+                if char1['char'] == char2['char'] or char1['char'] == ' ' or char2['char'] == ' ':
+                    # Determine which character to remove
+                    width1 = char1['bbox'][2] - char1['bbox'][0]
+                    width2 = char2['bbox'][2] - char2['bbox'][0]
+                    if width1 < width2:
+                        # Remove the narrower character
+                        span['chars'].pop(i)
+                    else:
+                        span['chars'].pop(i + 1)
+                else:
+                    i += 1
+
+                # Don't increment i since we need to check the new pair
+            else:
+                i += 1
+        else:
+            i += 1
+
+    return span
+
+
 def calculate_contrast(img, img_mode) -> float:
     """
     计算给定图像的对比度。