
Merge pull request #1050 from myhloli/dev

refactor(txt_parse): improve text extraction accuracy with new algorithm
Xiaomeng Zhao 1 year ago
parent commit ead2e67028

+ 27 - 1
magic_pdf/libs/pdf_image_tools.py

@@ -1,4 +1,7 @@
-
+from io import BytesIO
+import cv2
+import numpy as np
+from PIL import Image
 from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.commons import fitz, join_path
 from magic_pdf.libs.hash_utils import compute_sha256
@@ -29,3 +32,26 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
     imageWriter.write(img_hash256_path, byte_data)
 
     return img_hash256_path
+
+
+def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
+
+    # Convert the bbox coordinates into a fitz.Rect object
+    rect = fitz.Rect(*bbox)
+    # Set the zoom factor to 3x for a higher-resolution rendering
+    zoom = fitz.Matrix(3, 3)
+    # Render the clipped region of the page
+    pix = page.get_pixmap(clip=rect, matrix=zoom)
+
+    # Wrap the PNG byte data in a file-like object
+    image_file = BytesIO(pix.tobytes(output='png'))
+    # Open the image with Pillow
+    pil_image = Image.open(image_file)
+    if mode == "cv2":
+        image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
+    elif mode == "pillow":
+        image_result = pil_image
+    else:
+        raise ValueError(f"mode: {mode} is not supported.")
+
+    return image_result
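The new helper renders the clipped region at 3x zoom and returns either a Pillow image or an OpenCV BGR array. A minimal usage sketch (the file name and bbox below are made up for illustration; they are not from this commit):

    import fitz
    from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image

    doc = fitz.open('sample.pdf')                 # hypothetical input document
    page = doc[0]
    bbox = (72.0, 72.0, 300.0, 120.0)             # hypothetical page-space rectangle (x0, y0, x1, y1)
    pil_img = cut_image_to_pil_image(bbox, page, mode='pillow')  # PIL.Image.Image
    cv_img = cut_image_to_pil_image(bbox, page, mode='cv2')      # numpy BGR array, ready for OCR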

+ 1 - 1
magic_pdf/model/sub_modules/model_init.py

@@ -63,7 +63,7 @@ def ocr_model_init(show_log: bool = False,
                    use_dilation=True,
                    det_db_unclip_ratio=1.8,
                    ):
-    if lang is not None:
+    if lang is not None and lang != '':
         model = ModifiedPaddleOCR(
             show_log=show_log,
             det_db_box_thresh=det_db_box_thresh,
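The tightened guard keeps an empty-string lang from being treated as a real language hint: the old `is not None` check let `''` through to the language-specific branch. A hedged sketch of just the selection logic, using a hypothetical `wants_lang_specific_model` helper (the real function goes on to build the model; its else branch is not shown in this hunk):

    def wants_lang_specific_model(lang):
        # mirrors the new guard: only a non-empty lang selects a language-specific OCR model
        return lang is not None and lang != ''

    assert wants_lang_specific_model('en') is True
    assert wants_lang_specific_model('') is False    # old check passed: '' is not None
    assert wants_lang_specific_model(None) is False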

+ 2 - 0
magic_pdf/pdf_parse_by_ocr.py

@@ -9,6 +9,7 @@ def parse_pdf_by_ocr(pdf_bytes,
                      start_page_id=0,
                      end_page_id=None,
                      debug_mode=False,
+                     lang=None,
                      ):
     dataset = PymuDocDataset(pdf_bytes)
     return pdf_parse_union(dataset,
@@ -18,4 +19,5 @@ def parse_pdf_by_ocr(pdf_bytes,
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,
+                           lang=lang,
                            )

+ 2 - 0
magic_pdf/pdf_parse_by_txt.py

@@ -10,6 +10,7 @@ def parse_pdf_by_txt(
     start_page_id=0,
     end_page_id=None,
     debug_mode=False,
+    lang=None,
 ):
     dataset = PymuDocDataset(pdf_bytes)
     return pdf_parse_union(dataset,
@@ -19,4 +20,5 @@ def parse_pdf_by_txt(
                            start_page_id=start_page_id,
                            end_page_id=end_page_id,
                            debug_mode=debug_mode,
+                           lang=lang,
                            )

+ 190 - 27
magic_pdf/pdf_parse_union_core_v2.py

@@ -18,7 +18,21 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.local_math import float_equal
+from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
 from magic_pdf.model.magic_model import MagicModel
+
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # prevent albumentations from checking for updates
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+
+try:
+    import torchtext
+
+    if torchtext.__version__ >= "0.18.0":
+        torchtext.disable_torchtext_deprecation_warning()
+except ImportError:
+    pass
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
 from magic_pdf.pre_proc.construct_page_dict import \
@@ -74,7 +88,150 @@ def __replace_STX_ETX(text_str: str):
     return text_str
 
 
-def txt_spans_extract(pdf_page, inline_equations, interline_equations):
+def chars_to_content(span):
+    # # Previously: sort chars by the x coordinate of char['bbox']
+    # span['chars'] = sorted(span['chars'], key=lambda x: x['bbox'][0])
+
+    # Sort chars by the x coordinate of the center point of char['bbox']
+    span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
+    content = ''
+
+    # Compute the average char width
+    if len(span['chars']) == 0:
+        span['content'] = content
+        del span['chars']
+        return
+    else:
+        char_width_sum = sum([char['bbox'][2] - char['bbox'][0] for char in span['chars']])
+        char_avg_width = char_width_sum / len(span['chars'])
+
+    for char in span['chars']:
+        # If the gap between this char's x0 and the previous char's x1 exceeds one average char width, insert a space
+        if char['bbox'][0] - span['chars'][span['chars'].index(char) - 1]['bbox'][2] > char_avg_width:
+            content += ' '
+        content += char['c']
+    span['content'] = __replace_STX_ETX(content)
+    del span['chars']
+
+
+LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',')
+def fill_char_in_spans(spans, all_chars):
+
+    for char in all_chars:
+        for span in spans:
+            # Check whether the char is a line-stop (trailing punctuation) character
+            if char['c'] in LINE_STOP_FLAG:
+                char_is_line_stop_flag = True
+            else:
+                char_is_line_stop_flag = False
+            if calculate_char_in_span(char['bbox'], span['bbox'], char_is_line_stop_flag):
+                span['chars'].append(char)
+                break
+
+    for span in spans:
+        chars_to_content(span)
+
+
+# A more robust containment check based on the char's center point
+def calculate_char_in_span(char_bbox, span_bbox, char_is_line_stop_flag):
+    char_center_x = (char_bbox[0] + char_bbox[2]) / 2
+    char_center_y = (char_bbox[1] + char_bbox[3]) / 2
+    span_center_y = (span_bbox[1] + span_bbox[3]) / 2
+    span_height = span_bbox[3] - span_bbox[1]
+
+    if (
+        span_bbox[0] < char_center_x < span_bbox[2]
+        and span_bbox[1] < char_center_y < span_bbox[3]
+        and abs(char_center_y - span_center_y) < span_height / 4  # the char's vertical center must be within 1/4 of the span height from the span's vertical center
+    ):
+        return True
+    else:
+        # If the char is a line-stop character, skip the center-point test and use an alternative check: its left edge must fall inside the span, with the same height constraint as before.
+        # This gives trailing punctuation a chance to join the span; such a char should also sit close to the span's right edge.
+        if char_is_line_stop_flag:
+            if (
+                (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
+                and span_bbox[1] < char_center_y < span_bbox[3]
+                and abs(char_center_y - span_center_y) < span_height / 4
+            ):
+                return True
+        else:
+            return False
+
+
+def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
+
+    useful_spans = []
+    unuseful_spans = []
+    for span in spans:
+        for block in all_bboxes:
+            if block[7] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
+                continue
+            else:
+                if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
+                    useful_spans.append(span)
+                    break
+        for block in all_discarded_blocks:
+            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
+                unuseful_spans.append(span)
+                break
+
+    text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
+
+    # @todo: after collecting the chars, first drop those with a large skew angle
+    all_pymu_chars = []
+    for block in text_blocks:
+        for line in block['lines']:
+            for span in line['spans']:
+                all_pymu_chars.extend(span['chars'])
+
+    new_spans = []
+
+    for span in useful_spans:
+        if span['type'] in [ContentType.Text]:
+            span['chars'] = []
+            new_spans.append(span)
+
+    for span in unuseful_spans:
+        if span['type'] in [ContentType.Text]:
+            span['chars'] = []
+            new_spans.append(span)
+
+    fill_char_in_spans(new_spans, all_pymu_chars)
+
+    empty_spans = []
+    for span in new_spans:
+        if len(span['content']) == 0:
+            empty_spans.append(span)
+    if len(empty_spans) > 0:
+
+        # Initialize the OCR model
+        atom_model_manager = AtomModelSingleton()
+        ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name="ocr",
+            ocr_show_log=False,
+            det_db_box_thresh=0.3,
+            lang=lang
+        )
+
+        for span in empty_spans:
+            spans.remove(span)
+            # Crop an image of the span's bbox
+            span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode="cv2")
+            ocr_res = ocr_model.ocr(span_img, det=False)
+            # logger.info(f"ocr_res: {ocr_res}")
+            # logger.info(f"empty_span: {span}")
+            if len(ocr_res) > 0:
+                if len(ocr_res[0]) > 0:
+                    ocr_text, ocr_score = ocr_res[0][0]
+                    if ocr_score > 0.5 and len(ocr_text) > 0:
+                        span['content'] = ocr_text
+                        spans.append(span)
+
+    return spans
+
+
+def txt_spans_extract_v1(pdf_page, inline_equations, interline_equations):
     text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
     char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
         'blocks'
@@ -464,18 +621,16 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
 
 
 def parse_page_core(
-    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
 ):
     need_drop = False
     drop_reason = []
 
     """从magic_model对象中获取后面会用到的区块信息"""
-    # img_blocks = magic_model.get_imgs(page_id)
-    # table_blocks = magic_model.get_tables(page_id)
-
     img_groups = magic_model.get_imgs_v2(page_id)
     table_groups = magic_model.get_tables_v2(page_id)
 
+    """对image和table的区块分组"""
     img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
         img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
     )
@@ -519,38 +674,20 @@ def parse_page_core(
             page_h,
         )
 
+    """获取所有的spans信息"""
     spans = magic_model.get_all_spans(page_id)
 
-    """根据parse_mode,构造spans"""
-    if parse_mode == SupportedPdfParseMethod.TXT:
-        """ocr 中文本类的 span 用 pymu spans 替换!"""
-        pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
-        spans = replace_text_span(pymu_spans, spans)
-    elif parse_mode == SupportedPdfParseMethod.OCR:
-        pass
-    else:
-        raise Exception('parse_mode must be txt or ocr')
-
     """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
     """顺便删除大水印并保留abandon的span"""
     spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
 
-    """删除重叠spans中置信度较低的那些"""
-    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
-    """删除重叠spans中较小的那些"""
-    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
-    """对image和table截图"""
-    spans = ocr_cut_image_and_table(
-        spans, page_doc, page_id, pdf_bytes_md5, imageWriter
-    )
-
     """先处理不需要排版的discarded_blocks"""
     discarded_block_with_spans, spans = fill_spans_in_blocks(
         all_discarded_blocks, spans, 0.4
     )
     fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
 
-    """如果当前页面没有bbox则跳过"""
+    """如果当前页面没有有效的bbox则跳过"""
     if len(all_bboxes) == 0:
         logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
         return ocr_construct_page_component_v2(
@@ -568,7 +705,32 @@ def parse_page_core(
             drop_reason,
         )
 
-    """将span填入blocks中"""
+    """删除重叠spans中置信度较低的那些"""
+    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
+    """删除重叠spans中较小的那些"""
+    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
+
+    """根据parse_mode,构造spans,主要是文本类的字符填充"""
+    if parse_mode == SupportedPdfParseMethod.TXT:
+
+        """之前的公式替换方案"""
+        # pymu_spans = txt_spans_extract_v1(page_doc, inline_equations, interline_equations)
+        # spans = replace_text_span(pymu_spans, spans)
+
+        """ocr 中文本类的 span 用 pymu spans 替换!"""
+        spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
+
+    elif parse_mode == SupportedPdfParseMethod.OCR:
+        pass
+    else:
+        raise Exception('parse_mode must be txt or ocr')
+
+    """对image和table截图"""
+    spans = ocr_cut_image_and_table(
+        spans, page_doc, page_id, pdf_bytes_md5, imageWriter
+    )
+
+    """span填充进block"""
     block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
 
     """对block进行fix操作"""
@@ -618,6 +780,7 @@ def pdf_parse_union(
     start_page_id=0,
     end_page_id=None,
     debug_mode=False,
+    lang=None,
 ):
     pdf_bytes_md5 = compute_md5(dataset.data_bits())
 
@@ -654,7 +817,7 @@ def pdf_parse_union(
         """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
             page_info = parse_page_core(
-                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
             )
         else:
             page_info = page.get_page_info()
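The core of this refactor is txt_spans_extract_v2: PyMuPDF rawdict chars are assigned to detected spans with a center-point containment test, and only spans that end up empty are re-OCRed from a 3x crop. A standalone sketch of that geometric test (simplified: the line-stop fallback is omitted, and the bboxes below are invented for illustration):

    def char_in_span(char_bbox, span_bbox):
        # the char's center must lie inside the span, and its vertical center must
        # stay within 1/4 of the span height from the span's vertical center
        cx = (char_bbox[0] + char_bbox[2]) / 2
        cy = (char_bbox[1] + char_bbox[3]) / 2
        span_cy = (span_bbox[1] + span_bbox[3]) / 2
        span_h = span_bbox[3] - span_bbox[1]
        return (span_bbox[0] < cx < span_bbox[2]
                and span_bbox[1] < cy < span_bbox[3]
                and abs(cy - span_cy) < span_h / 4)

    span = (100, 200, 300, 220)                          # a 20pt-tall text span
    assert char_in_span((110, 202, 118, 218), span)      # centered char: assigned
    assert not char_in_span((110, 196, 118, 212), span)  # center 6pt off the span's mid-line (> 20/4): rejected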

+ 3 - 0
magic_pdf/user_api.py

@@ -30,6 +30,7 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
         start_page_id=start_page_id,
         end_page_id=end_page_id,
         debug_mode=is_debug,
+        lang=lang,
     )
 
     pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
@@ -53,6 +54,7 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
         start_page_id=start_page_id,
         end_page_id=end_page_id,
         debug_mode=is_debug,
+        lang=lang,
     )
 
     pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
@@ -80,6 +82,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
                 debug_mode=is_debug,
+                lang=lang,
             )
         except Exception as e:
             logger.exception(e)
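With the plumbing above, a language hint given at the user API level now reaches the per-span OCR fallback. A hedged call sketch (pdf_bytes, model_list and image_writer are assumed to be prepared by the caller, and the full parse_txt_pdf signature is truncated in this hunk; it is assumed to accept lang as a keyword since it forwards it):

    from magic_pdf.user_api import parse_txt_pdf

    # hypothetical inputs prepared by the caller
    pdf_info = parse_txt_pdf(pdf_bytes, model_list, image_writer, is_debug=False, lang='en')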