Forráskód Böngészése

refactor: update OCR span extraction logic and improve PDF page processing

myhloli 5 hónapja
szülő
commit
1ed61cb5d6

+ 1 - 1
mineru/backend/pipeline/batch_analyze.py

@@ -238,7 +238,7 @@ class BatchAnalyze:
                             res_area = get_coords_and_area(res)[4]
                             if res_area > 0:
                                 ratio = ocr_res_area / res_area
-                                if ratio > 0.25:
+                                if ratio > 0.3:
                                     res["category_id"] = 1
                                 else:
                                     continue

+ 2 - 2
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -9,7 +9,7 @@ from mineru.utils.model_utils import clean_memory
 from mineru.utils.pipeline_magic_model import MagicModel
 from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
 from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
-    remove_overlaps_min_spans, txt_spans_extract
+    remove_overlaps_min_spans, txt_spans_extract_v2
 from mineru.version import __version__
 from mineru.utils.hash_utils import str_md5
 
@@ -79,7 +79,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
         pass
     else:
         """使用新版本的混合ocr方案."""
-        spans = txt_spans_extract(page, spans, page_pil_img, scale)
+        spans = txt_spans_extract_v2(page, spans, page_pil_img, scale)
 
     """先处理不需要排版的discarded_blocks"""
     discarded_block_with_spans, spans = fill_spans_in_blocks(

+ 7 - 4
mineru/cli/common.py

@@ -95,8 +95,10 @@ def do_parse(
 ):
 
     if backend == "pipeline":
-        for pdf_bytes in pdf_bytes_list:
-            pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
+        for idx, pdf_bytes in enumerate(pdf_bytes_list):
+            new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
+            pdf_bytes_list[idx] = new_pdf_bytes
+
         infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = pipeline_doc_analyze(pdf_bytes_list, p_lang_list, parse_method=parse_method, formula_enable=p_formula_enable,table_enable=p_table_enable)
 
         for idx, model_list in enumerate(infer_results):
@@ -213,9 +215,10 @@ def do_parse(
 
 
 if __name__ == "__main__":
-    pdf_path = "../../demo/pdfs/demo2.pdf"
+    pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf"
+    # pdf_path = "../../demo/pdfs/demo1.pdf"
     with open(pdf_path, "rb") as f:
         try:
-           do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"],)
+           do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,)
         except Exception as e:
             logger.exception(e)

+ 40 - 0
mineru/utils/pdf_text_tool.py

@@ -0,0 +1,40 @@
+from typing import List
+import math
+
+import pypdfium2 as pdfium
+from pdftext.pdf.chars import get_chars, deduplicate_chars
+from pdftext.pdf.pages import get_spans, get_lines, assign_scripts, get_blocks
+
+
+def get_page(
+    page: pdfium.PdfPage,
+    quote_loosebox: bool =True,
+    superscript_height_threshold: float = 0.7,
+    line_distance_threshold: float = 0.1,
+) -> dict:
+
+        textpage = page.get_textpage()
+        page_bbox: List[float] = page.get_bbox()
+        page_width = math.ceil(abs(page_bbox[2] - page_bbox[0]))
+        page_height = math.ceil(abs(page_bbox[1] - page_bbox[3]))
+
+        page_rotation = 0
+        try:
+            page_rotation = page.get_rotation()
+        except:
+            pass
+
+        chars = deduplicate_chars(get_chars(textpage, page_bbox, page_rotation, quote_loosebox))
+        spans = get_spans(chars, superscript_height_threshold=superscript_height_threshold, line_distance_threshold=line_distance_threshold)
+        lines = get_lines(spans)
+        assign_scripts(lines, height_threshold=superscript_height_threshold, line_distance_threshold=line_distance_threshold)
+        blocks = get_blocks(lines)
+
+        page = {
+            "bbox": page_bbox,
+            "width": page_width,
+            "height": page_height,
+            "rotation": page_rotation,
+            "blocks": blocks
+        }
+        return page

+ 69 - 1
mineru/utils/span_pre_proc.py

@@ -2,11 +2,13 @@
 import re
 import cv2
 import numpy as np
+from loguru import logger
 
 from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
     get_minbox_if_overlap_by_ratio
 from mineru.utils.enum_class import BlockType, ContentType
 from mineru.utils.pdf_image_tools import get_crop_img
+from mineru.utils.pdf_text_tool import get_page
 
 
 def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
@@ -114,7 +116,7 @@ def __replace_unicode(text: str):
     return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
 
 
-def txt_spans_extract(pdf_page, spans, pil_img, scale):
+def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
 
     textpage = pdf_page.get_textpage()
     width, height = pdf_page.get_size()
@@ -128,6 +130,10 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale):
                     height - span_bbox[3] + cropbox[1],
                     span_bbox[2] + cropbox[0],
                     height - span_bbox[1] + cropbox[1]]
+        # logger.info(f"span bbox: {span_bbox}, rect_box: {rect_box}")
+        middle_height = (rect_box[1] + rect_box[3]) / 2
+        rect_box[1] = middle_height - 1
+        rect_box[3] = middle_height + 1
         text = textpage.get_text_bounded(left=rect_box[0], top=rect_box[1],
                                          right=rect_box[2], bottom=rect_box[3])
         if text and len(text) > 0:
@@ -156,6 +162,68 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale):
     return spans
 
 
+def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
+
+    page_dict = get_page(pdf_page)
+
+    page_all_spans = []
+    for block in page_dict['blocks']:
+        for line in block['lines']:
+            if 0 < abs(line['rotation']) < 90:
+                # 旋转角度在0-90度之间的行,直接跳过
+                continue
+            for span in line['spans']:
+                page_all_spans.append(span)
+
+    need_ocr_spans = []
+    for span in spans:
+        if span['type'] in [ContentType.TEXT]:
+            span['sub_spans'] = []
+            matched_spans = []
+            for page_span in page_all_spans:
+                if calculate_overlap_area_in_bbox1_area_ratio(page_span['bbox'].bbox, span['bbox']) > 0.5:
+                    span['sub_spans'].append(page_span)
+                    matched_spans.append(page_span)
+
+            # 从page_all_spans中移除已匹配的元素
+            page_all_spans = [span for span in page_all_spans if span not in matched_spans]
+
+            # 对sub_spans按照bbox的x坐标进行排序
+            span['sub_spans'].sort(key=lambda x: x['bbox'].x_start)
+            # 对sub_spans的content进行拼接
+            span_content = ''.join([sub_span['text'] for sub_span in span['sub_spans']])
+
+            if span_content and len(span_content) > 0:
+                span_content = __replace_unicode(span_content)
+                span_content = __replace_ligatures(span_content)
+                span['content'] = span_content.strip()
+                span['score'] = 1.0
+            else:
+                need_ocr_spans.append(span)
+
+            # 移除span的sub_spans
+            span.pop('sub_spans', None)
+        else:
+            pass
+
+    if len(need_ocr_spans) > 0:
+
+        for span in need_ocr_spans:
+            # 对span的bbox截图再ocr
+            span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
+            span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
+            # 计算span的对比度,低于0.20的span不进行ocr
+            if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
+                spans.remove(span)
+                continue
+
+            span['content'] = ''
+            span['score'] = 1.0
+            span['np_img'] = span_img
+
+    return spans
+
+
 def calculate_contrast(img, img_mode) -> float:
     """
     计算给定图像的对比度。