Răsfoiți Sursa

Merge pull request #2616 from myhloli/dev

Dev
Xiaomeng Zhao 5 luni în urmă
părinte
comite
3cc3f75411

+ 12 - 5
mineru/backend/pipeline/pipeline_analyze.py

@@ -1,6 +1,7 @@
 import os
 import time
-import numpy as np
+from typing import List, Tuple
+import PIL.Image
 import torch
 
 from .model_init import MineruPipelineModel
@@ -150,7 +151,7 @@ def doc_analyze(
 
 
 def batch_image_analyze(
-        images_with_extra_info: list[(np.ndarray, bool, str)],
+        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
         formula_enable=None,
         table_enable=None):
     # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
@@ -163,9 +164,15 @@ def batch_image_analyze(
     device = get_device()
 
     if str(device).startswith('npu'):
-        import torch_npu
-        if torch_npu.npu.is_available():
-            torch.npu.set_compile_mode(jit_compile=False)
+        try:
+            import torch_npu
+            if torch_npu.npu.is_available():
+                torch.npu.set_compile_mode(jit_compile=False)
+        except Exception as e:
+            raise RuntimeError(
+                "NPU is selected as device, but torch_npu is not available. "
+                "Please ensure that the torch_npu package is installed correctly."
+            ) from e
 
     if str(device).startswith('npu') or str(device).startswith('cuda'):
         vram = get_vram(device)

+ 6 - 1
mineru/backend/pipeline/pipeline_middle_json_mkcontent.py

@@ -34,6 +34,8 @@ def make_blocks_to_markdown(paras_of_layout,
             title_level = get_title_level(para_block)
             para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
         elif para_type == BlockType.INTERLINE_EQUATION:
+            if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
+                continue
             if para_block['lines'][0]['spans'][0].get('content', ''):
                 para_text = merge_para_with_text(para_block)
             else:
@@ -201,6 +203,8 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
         if title_level != 0:
             para_content['text_level'] = title_level
     elif para_type == BlockType.INTERLINE_EQUATION:
+        if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
+            return None
         para_content = {
             'type': 'equation',
             'img_path': f"{img_buket_path}/{para_block['lines'][0]['spans'][0].get('image_path', '')}",
@@ -263,7 +267,8 @@ def union_make(pdf_info_dict: list,
         elif make_mode == MakeMode.CONTENT_LIST:
             for para_block in paras_of_layout:
                 para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx)
-                output_content.append(para_content)
+                if para_content:
+                    output_content.append(para_content)
 
     if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
         return '\n\n'.join(output_content)

+ 30 - 23
mineru/backend/vlm/vlm_magic_model.py

@@ -205,35 +205,42 @@ def isolated_formula_clean(txt):
 
 
 def latex_fix(latex):
-    # 白名单分隔符
-    valid_delims_list = [r'(', r')', r'[', r']', r'{', r'}', r'/', r'|',
-                         r'\{', r'\}', r'\lceil', r'\rceil', r'\lfloor',
-                         r'\rfloor', r'\backslash', r'\uparrow', r'\downarrow',
-                         r'\Uparrow', r'\Downarrow', r'\|', r'\.']
-
-    # 为\left后缺失有效分隔符的情况添加点
-    def fix_delim(match):
-        cmd = match.group(1)  # \left 或 \right
-        rest = match.group(2) if len(match.groups()) > 1 else ""
-        if not rest or rest not in valid_delims_list:
-            return cmd + "."
-        return match.group(0)
-
-    LEFT_PATTERN = re.compile(r'(\\left)(\S*)')
-    RIGHT_PATTERN = re.compile(r'(\\right)(\S*)')
+    # valid pairs:
+    # \left\{ ... \right\}
+    # \left( ... \right)
+    # \left| ... \right|
+    # \left\| ... \right\|
+    # \left[ ... \right]
+
     LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')
     RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')
-    LEFT_RIGHT_REMOVE_PATTERN = re.compile(r'\\left\.?|\\right\.?')
-
-    latex = LEFT_PATTERN.sub(lambda m: fix_delim(m), latex)
-    latex = RIGHT_PATTERN.sub(lambda m: fix_delim(m), latex)
-
-
     left_count = len(LEFT_COUNT_PATTERN.findall(latex))  # 不匹配\lefteqn等
     right_count = len(RIGHT_COUNT_PATTERN.findall(latex))  # 不匹配\rightarrow
 
     if left_count != right_count:
-        return LEFT_RIGHT_REMOVE_PATTERN.sub('', latex)
+        for _ in range(2):
+            # replace valid pairs
+            latex = re.sub(r'\\left\\\{', "{", latex) # \left\{
+            latex = re.sub(r"\\left\|", "|", latex) # \left|
+            latex = re.sub(r"\\left\\\|", "|", latex) # \left\|
+            latex = re.sub(r"\\left\(", "(", latex) # \left(
+            latex = re.sub(r"\\left\[", "[", latex) # \left[
+
+            latex = re.sub(r"\\right\\\}", "}", latex) # \right\}
+            latex = re.sub(r"\\right\|", "|", latex) # \right|
+            latex = re.sub(r"\\right\\\|", "|", latex) # \right\|
+            latex = re.sub(r"\\right\)", ")", latex) # \right)
+            latex = re.sub(r"\\right\]", "]", latex) # \right]
+            latex = re.sub(r"\\right\.", "", latex) # \right.
+
+            # then replace invalid pairs (the valid pairs were handled above)
+            latex = re.sub(r'\\left\{', "{", latex)
+            latex = re.sub(r'\\right\}', "}", latex) # \left{ ... \right}
+            latex = re.sub(r'\\left\\\(', "(", latex)
+            latex = re.sub(r'\\right\\\)', ")", latex) # \left\( ... \right\)
+            latex = re.sub(r'\\left\\\[', "[", latex)
+            latex = re.sub(r'\\right\\\]', "]", latex) # \left\[ ... \right\]
+
     return latex
 
 

+ 4 - 5
mineru/cli/client.py

@@ -4,6 +4,8 @@ import click
 from pathlib import Path
 import torch
 from loguru import logger
+
+from mineru.utils.config_reader import get_device
 from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@@ -144,11 +146,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
     def get_device_mode() -> str:
         if device_mode is not None:
             return device_mode
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-        return "cpu"
+        else:
+            return get_device()
     if os.getenv('MINERU_DEVICE_MODE', None) is None:
         os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
 

+ 8 - 1
mineru/utils/config_reader.py

@@ -74,8 +74,15 @@ def get_device():
     else:
         if torch.cuda.is_available():
             return "cuda"
-        if torch.backends.mps.is_available():
+        elif torch.backends.mps.is_available():
             return "mps"
+        else:
+            try:
+                import torch_npu
+                if torch_npu.npu.is_available():
+                    return "npu"
+            except Exception as e:
+                pass
         return "cpu"
 
 

+ 15 - 2
mineru/utils/draw_bbox.py

@@ -1,6 +1,7 @@
 import json
 from io import BytesIO
 
+from loguru import logger
 from pypdf import PdfReader, PdfWriter
 from reportlab.pdfgen import canvas
 
@@ -182,7 +183,13 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         packet.seek(0)
         overlay_pdf = PdfReader(packet)
 
-        page.merge_page(overlay_pdf.pages[0])
+        # 添加检查确保overlay_pdf.pages不为空
+        if len(overlay_pdf.pages) > 0:
+            page.merge_page(overlay_pdf.pages[0])
+        else:
+            # 记录日志并继续处理下一个页面
+            logger.warning(f"layout.pdf: 第{i + 1}页未能生成有效的overlay PDF")
+
         output_pdf.add_page(page)
 
     # 保存结果
@@ -290,7 +297,13 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
         packet.seek(0)
         overlay_pdf = PdfReader(packet)
 
-        page.merge_page(overlay_pdf.pages[0])
+        # 添加检查确保overlay_pdf.pages不为空
+        if len(overlay_pdf.pages) > 0:
+            page.merge_page(overlay_pdf.pages[0])
+        else:
+            # 记录日志并继续处理下一个页面
+            logger.warning(f"span.pdf: 第{i + 1}页未能生成有效的overlay PDF")
+
         output_pdf.add_page(page)
 
     # Save the PDF

+ 84 - 45
mineru/utils/pdf_classify.py

@@ -5,8 +5,13 @@ import numpy as np
 import pypdfium2 as pdfium
 from loguru import logger
 from pdfminer.high_level import extract_text
-from pdfminer.layout import LAParams
-from pypdf import PdfReader
+from pdfminer.pdfparser import PDFParser
+from pdfminer.pdfdocument import PDFDocument
+from pdfminer.pdfpage import PDFPage
+from pdfminer.pdfinterp import PDFResourceManager
+from pdfminer.pdfinterp import PDFPageInterpreter
+from pdfminer.layout import LAParams, LTImage, LTFigure
+from pdfminer.converter import PDFPageAggregator
 
 
 def classify(pdf_bytes):
@@ -41,7 +46,7 @@ def classify(pdf_bytes):
             return 'ocr'
         else:
 
-            if get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check) >= 0.9:
+            if get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check) >= 0.8:
                 return 'ocr'
 
             return 'txt'
@@ -77,60 +82,94 @@ def get_avg_cleaned_chars_per_page(pdf_doc, pages_to_check):
 
     return avg_cleaned_chars_per_page
 
+
 def get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check):
+    # 创建内存文件对象
     pdf_stream = BytesIO(sample_pdf_bytes)
-    pdf_reader = PdfReader(pdf_stream)
+
+    # 创建PDF解析器
+    parser = PDFParser(pdf_stream)
+
+    # 创建PDF文档对象
+    document = PDFDocument(parser)
+
+    # 检查文档是否允许文本提取
+    if not document.is_extractable:
+        # logger.warning("PDF不允许内容提取")
+        return 1.0  # 默认为高覆盖率,因为无法提取内容
+
+    # 创建资源管理器和参数对象
+    rsrcmgr = PDFResourceManager()
+    laparams = LAParams(
+        line_overlap=0.5,
+        char_margin=2.0,
+        line_margin=0.5,
+        word_margin=0.1,
+        boxes_flow=None,
+        detect_vertical=False,
+        all_texts=False,
+    )
+
+    # 创建聚合器
+    device = PDFPageAggregator(rsrcmgr, laparams=laparams)
+
+    # 创建解释器
+    interpreter = PDFPageInterpreter(rsrcmgr, device)
 
     # 记录高图像覆盖率的页面数量
     high_image_coverage_pages = 0
+    page_count = 0
 
-    # 检查前几页的图像
-    for i in range(pages_to_check):
-        page = pdf_reader.pages[i]
+    # 遍历页面
+    for page in PDFPage.create_pages(document):
+        # 控制检查的页数
+        if page_count >= pages_to_check:
+            break
+
+        # 处理页面
+        interpreter.process_page(page)
+        layout = device.get_result()
 
-        # 获取页面尺寸
-        page_width = float(page.mediabox.width)
-        page_height = float(page.mediabox.height)
+        # 页面尺寸
+        page_width = layout.width
+        page_height = layout.height
         page_area = page_width * page_height
 
-        # 估算图像覆盖率
+        # 计算图像覆盖的总面积
         image_area = 0
-        if '/Resources' in page:
-            resources = page['/Resources']
-            if '/XObject' in resources:
-                x_objects = resources['/XObject']
-                # 计算所有图像对象占据的面积
-                for obj_name in x_objects:
-                    try:
-                        obj = x_objects[obj_name]
-                        if obj['/Subtype'] == '/Image':
-                            # 获取图像宽高
-                            width = obj.get('/Width', 0)
-                            height = obj.get('/Height', 0)
-
-                            # 计算图像在页面上的估计面积
-                            # 注意:这是估计值,因为没有考虑图像变换矩阵
-                            scale_factor = 1.0  # 估计缩放因子
-                            img_area = width * height * scale_factor
-                            image_area += img_area
-                    except Exception as e:
-                        # logger.debug(f"处理图像对象时出错: {e}")
-                        continue
-
-        # 估算图像覆盖率
-        estimated_coverage = min(image_area / page_area, 1.0) if page_area > 0 else 0
-        # logger.debug(f"PDF分析: 页面 {i + 1} 图像覆盖率: {estimated_coverage:.2f}")
-        # 基于估计的图像覆盖率
-        if estimated_coverage >= 1:
-            # 如果图像覆盖率超过80%,认为是高图像覆盖率页面
+
+        # 遍历页面元素
+        for element in layout:
+            # 检查是否为图像或图形元素
+            if isinstance(element, (LTImage, LTFigure)):
+                # 计算图像边界框面积
+                img_width = element.width
+                img_height = element.height
+                img_area = img_width * img_height
+                image_area += img_area
+
+        # 计算覆盖率
+        coverage_ratio = min(image_area / page_area, 1.0) if page_area > 0 else 0
+        # logger.debug(f"PDF分析: 页面 {page_count + 1} 图像覆盖率: {coverage_ratio:.2f}")
+
+        # 判断是否为高覆盖率
+        if coverage_ratio >= 0.8:  # 使用80%作为高覆盖率的阈值
             high_image_coverage_pages += 1
-    # 计算高图像覆盖页面比例
-    high_image_coverage_ratio = high_image_coverage_pages / pages_to_check
-    # logger.debug(f"PDF分析: 高图像覆盖页面比例: {high_image_coverage_ratio:.2f}")
 
-    pdf_stream.close()  # 关闭字节流
-    pdf_reader.close()
-    return high_image_coverage_ratio
+        page_count += 1
+
+    # 如果没有处理任何页面,返回0
+    if page_count == 0:
+        return 0.0
+
+    # 计算高图像覆盖率的页面比例
+    high_coverage_ratio = high_image_coverage_pages / page_count
+    # logger.debug(f"PDF分析: 高图像覆盖页面比例: {high_coverage_ratio:.2f}")
+
+    # 关闭资源
+    pdf_stream.close()
+
+    return high_coverage_ratio
 
 
 def extract_pages(src_pdf_bytes: bytes) -> bytes: