瀏覽代碼

Merge pull request #1664 from myhloli/dev

feat(pdf_parse): improve OCR processing and contrast filtering
Xiaomeng Zhao 9 月之前
父節點
當前提交
6e1fba9345

+ 1 - 1
magic_pdf/filter/__init__.py

@@ -23,7 +23,7 @@ def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
                 pdf_meta['image_info_per_page'],
                 pdf_meta['text_len_per_page'],
                 pdf_meta['imgs_per_page'],
-                pdf_meta['text_layout_per_page'],
+                # pdf_meta['text_layout_per_page'],
                 pdf_meta['invalid_chars'],
             )
             if is_text_pdf:

+ 6 - 4
magic_pdf/filter/pdf_classify_by_type.py

@@ -305,7 +305,8 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
 
 
 def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
-             text_layout_list: list, invalid_chars: bool):
+             # text_layout_list: list,
+             invalid_chars: bool):
     """
     这里的图片和页面长度单位是pts
     :param total_page:
@@ -321,7 +322,7 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
         'by_text_len': classify_by_text_len(text_len_list, total_page),
         'by_avg_words': classify_by_avg_words(text_len_list),
         'by_img_num': classify_by_img_num(img_sz_list, img_num_list),
-        'by_text_layout': classify_by_text_layout(text_layout_list),
+        # 'by_text_layout': classify_by_text_layout(text_layout_list),
         'by_img_narrow_strips': classify_by_img_narrow_strips(page_width, page_height, img_sz_list),
         'by_invalid_chars': invalid_chars,
     }
@@ -332,9 +333,10 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
         return False, results
     else:
         logger.warning(
-            f"pdf is not classified by area and text_len, by_image_area: {results['by_image_area']},"
+            f"OCR needed based on classification result, by_image_area: {results['by_image_area']},"
             f" by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']},"
-            f" by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']},"
+            # f" by_text_layout: {results['by_text_layout']},"
+            f" by_img_narrow_strips: {results['by_img_narrow_strips']},"
             f" by_invalid_chars: {results['by_invalid_chars']}",
             file=sys.stderr)  # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法
         return False, results

+ 4 - 4
magic_pdf/filter/pdf_meta_scan.py

@@ -356,9 +356,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
         # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
         text_len_per_page = get_pdf_textlen_per_page(doc)
         # logger.info(f"text_len_per_page: {text_len_per_page}")
-        text_layout_per_page = get_pdf_text_layout_per_page(doc)
+        # text_layout_per_page = get_pdf_text_layout_per_page(doc)
         # logger.info(f"text_layout_per_page: {text_layout_per_page}")
-        text_language = get_language(doc)
+        # text_language = get_language(doc)
         # logger.info(f"text_language: {text_language}")
         invalid_chars = check_invalid_chars(pdf_bytes)
         # logger.info(f"invalid_chars: {invalid_chars}")
@@ -372,8 +372,8 @@ def pdf_meta_scan(pdf_bytes: bytes):
             'page_height_pts': int(page_height_pts),
             'image_info_per_page': image_info_per_page,
             'text_len_per_page': text_len_per_page,
-            'text_layout_per_page': text_layout_per_page,
-            'text_language': text_language,
+            # 'text_layout_per_page': text_layout_per_page,
+            # 'text_language': text_language,
             # "svgs_per_page": svgs_per_page,
             'imgs_per_page': imgs_per_page,  # 增加每页img数量list
             'junk_img_bojids': junk_img_bojids,  # 增加垃圾图片的bojid list

+ 11 - 1
magic_pdf/libs/pdf_check.py

@@ -4,6 +4,7 @@ from loguru import logger
 import re
 from io import BytesIO
 from pdfminer.high_level import extract_text
+from pdfminer.layout import LAParams
 
 
 def calculate_sample_count(total_page: int):
@@ -41,7 +42,16 @@ def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
     sample_docs = extract_pages(src_pdf_bytes)
     sample_pdf_bytes = sample_docs.tobytes()
     sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
-    text = extract_text(sample_pdf_file_like_object)
+    laparams = LAParams(
+        line_overlap=0.5,
+        char_margin=2.0,
+        line_margin=0.5,
+        word_margin=0.1,
+        boxes_flow=None,
+        detect_vertical=False,
+        all_texts=False,
+    )
+    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
     text = text.replace("\n", "")
     # logger.info(text)
     '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''

+ 4 - 3
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -1,4 +1,5 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import time
 from collections import Counter
 from uuid import uuid4
 
@@ -102,9 +103,9 @@ class YOLOv11LangDetModel(object):
             temp_images = split_images(image)
             for temp_image in temp_images:
                 all_images.append(resize_images_to_224(temp_image))
-
-        images_lang_res = self.batch_predict(all_images, batch_size=8)
-        # logger.info(f"images_lang_res: {images_lang_res}")
+        # langdetect_start = time.time()
+        images_lang_res = self.batch_predict(all_images, batch_size=256)
+        # logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
         if len(images_lang_res) > 0:
             count_dict = Counter(images_lang_res)
             language = max(count_dict, key=count_dict.get)

+ 39 - 7
magic_pdf/pdf_parse_union_core_v2.py

@@ -6,8 +6,10 @@ import statistics
 import time
 from typing import List
 
+import cv2
 import fitz
 import torch
+import numpy as np
 from loguru import logger
 
 from magic_pdf.config.enums import SupportedPdfParseMethod
@@ -127,16 +129,15 @@ def fill_char_in_spans(spans, all_chars):
                 span['chars'].append(char)
                 break
 
-    empty_spans = []
-
+    need_ocr_spans = []
     for span in spans:
         chars_to_content(span)
         # 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
         if len(span['content']) * span['height'] < span['width'] * 0.5:
             # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
-            empty_spans.append(span)
+            need_ocr_spans.append(span)
         del span['height'], span['width']
-    return empty_spans
+    return need_ocr_spans
 
 
 # 使用鲁棒性更强的中心点坐标判断
@@ -190,6 +191,31 @@ def remove_tilted_line(text_blocks):
             block['lines'].remove(line)
 
 
+def calculate_contrast(img, img_mode) -> float:
+    """
+    计算给定图像的对比度。
+    :param img: 图像,类型为numpy.ndarray
+    :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
+    :return: 图像的对比度值
+    """
+    if img_mode == 'rgb':
+        # 将RGB图像转换为灰度图
+        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    elif img_mode == 'bgr':
+        # 将BGR图像转换为灰度图
+        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    else:
+        raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
+
+    # 计算均值和标准差
+    mean_value = np.mean(gray_img)
+    std_dev = np.std(gray_img)
+    # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
+    contrast = std_dev / (mean_value + 1e-6)
+    # logger.info(f"contrast: {contrast}")
+    return round(contrast, 2)
+
+
 def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
     # cid用0xfffd表示,连字符拆开
     # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
@@ -274,9 +300,9 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
             span['chars'] = []
             new_spans.append(span)
 
-    empty_spans = fill_char_in_spans(new_spans, all_pymu_chars)
+    need_ocr_spans = fill_char_in_spans(new_spans, all_pymu_chars)
 
-    if len(empty_spans) > 0:
+    if len(need_ocr_spans) > 0:
 
         # 初始化ocr模型
         atom_model_manager = AtomModelSingleton()
@@ -287,9 +313,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
             lang=lang
         )
 
-        for span in empty_spans:
+        for span in need_ocr_spans:
             # 对span的bbox截图再ocr
             span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
+
+            # 计算span的对比度,低于0.20的span不进行ocr
+            if calculate_contrast(span_img, img_mode='bgr') <= 0.20:
+                spans.remove(span)
+                continue
+
             ocr_res = ocr_model.ocr(span_img, det=False)
             if ocr_res and len(ocr_res) > 0:
                 if len(ocr_res[0]) > 0: