
feat: Add MinerU vLLM batch processing for image/PDF files, with support for multiple input formats and output generation

zhch158_admin committed 1 month ago
commit 2d8f911dd1
1 file changed, 644 insertions, 0 deletions

+ 644 - 0
zhch/mineru2_vllm_multthreads.py

@@ -0,0 +1,644 @@
+"""
+批量处理图片/PDF文件并生成符合评测要求的预测结果(MinerU版本)
+
+根据 MinerU demo.py 框架调用方式:
+- 输入:支持 PDF 和各种图片格式
+- 输出:每个文件对应的 .md、.json 文件,所有图片保存为单独的图片文件
+- 调用方式:通过 vlm-http-client 连接到 MinerU vLLM 服务器
+"""
+
+import os
+import sys
+import json
+import shutil
+import time
+import traceback
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from tqdm import tqdm
+import argparse
+
+from loguru import logger
+
+# Import MinerU core components (see demo.py)
+from mineru.cli.common import read_fn, convert_pdf_bytes_to_bytes_by_pypdfium2, prepare_env
+from mineru.data.data_reader_writer import FileBasedDataWriter
+from mineru.utils.draw_bbox import draw_layout_bbox
+from mineru.utils.enum_class import MakeMode
+from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
+from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
+
+# Import utility helpers
+from utils import (
+    get_input_files,
+    collect_pid_files,
+    normalize_markdown_table,
+    normalize_json_table,
+)
+
+class MinerUVLLMProcessor:
+    """MinerU vLLM 处理器 (基于 demo.py 框架)"""
+    
+    def __init__(self, 
+                 server_url: str = "http://127.0.0.1:8121",
+                 timeout: int = 300,
+                 normalize_numbers: bool = False,
+                 debug: bool = False):
+        """
+        初始化处理器
+        
+        Args:
+            server_url: vLLM 服务器地址
+            timeout: 请求超时时间
+            normalize_numbers: 是否标准化数字
+            debug: 是否启用调试模式
+        """
+        self.server_url = server_url.rstrip('/')
+        self.timeout = timeout
+        self.normalize_numbers = normalize_numbers
+        self.debug = debug
+        self.backend = "http-client"  # 固定使用 http-client 后端
+        
+        print(f"MinerU vLLM Processor 初始化完成:")
+        print(f"  - 服务器: {server_url}")
+        print(f"  - 后端: vlm-{self.backend}")
+        print(f"  - 超时: {timeout}s")
+        print(f"  - 数字标准化: {normalize_numbers}")
+        print(f"  - 调试模式: {debug}")
+    
+    def do_parse_single_file(self, 
+                           input_file: str, 
+                           output_dir: str,
+                           start_page_id: int = 0,
+                           end_page_id: Optional[int] = None) -> Dict[str, Any]:
+        """
+        Parse a single file (modeled on the do_parse function in demo.py).
+        
+        Args:
+            input_file: path to the input file
+            output_dir: output directory
+            start_page_id: first page ID to process
+            end_page_id: last page ID to process (None = all pages)
+            
+        Returns:
+            dict: processing result
+        """
+        try:
+            # Prepare the file name and raw bytes
+            file_path = Path(input_file)
+            pdf_file_name = file_path.stem
+            pdf_bytes = read_fn(str(file_path))
+            
+            # Re-encode the PDF byte stream (if needed)
+            if file_path.suffix.lower() == '.pdf':
+                pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(
+                    pdf_bytes, start_page_id, end_page_id
+                )
+            
+            # Prepare the environment (create output directories)
+            # parse_method = "vlm"
+            # local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
+            local_md_dir = Path(output_dir).resolve()
+            local_image_dir = local_md_dir / "images"
+            image_writer = FileBasedDataWriter(local_image_dir.as_posix())
+            md_writer = FileBasedDataWriter(local_md_dir.as_posix())
+            
+            # Analyze the document with the VLM backend (core call);
+            # returns the intermediate JSON and the raw model output
+            middle_json, model_output = vlm_doc_analyze(
+                pdf_bytes, 
+                image_writer=image_writer, 
+                backend=self.backend,
+                server_url=self.server_url
+            )
+            
+            pdf_info = middle_json["pdf_info"]
+            
+            # Write outputs (modeled on demo.py's _process_output)
+            output_files = self._process_output(
+                pdf_info=pdf_info,
+                pdf_bytes=pdf_bytes,
+                pdf_file_name=pdf_file_name,
+                local_md_dir=local_md_dir,
+                local_image_dir=local_image_dir,
+                md_writer=md_writer,
+                middle_json=middle_json,
+                model_output=model_output,
+                original_file_path=str(file_path)
+            )
+            
+            # Collect extraction statistics
+            extraction_stats = self._get_extraction_stats(middle_json)
+            
+            return {
+                "success": True,
+                "pdf_info": pdf_info,
+                "middle_json": middle_json,
+                "model_output": model_output,
+                "output_files": output_files,
+                "extraction_stats": extraction_stats
+            }
+            
+        except Exception as e:
+            logger.error(f"Failed to process {file_path}: {e}")
+            return {
+                "success": False,
+                "error": str(e)
+            }
+    
+    def _process_output(self,
+                       pdf_info,
+                       pdf_bytes,
+                       pdf_file_name,
+                       local_md_dir,
+                       local_image_dir,
+                       md_writer,
+                       middle_json,
+                       model_output,
+                       original_file_path: str) -> Dict[str, str]:
+        """
+        处理输出文件 (改进版的 demo.py _process_output)
+        
+        Args:
+            pdf_info: PDF信息
+            pdf_bytes: PDF字节数据
+            pdf_file_name: PDF文件名
+            local_md_dir: Markdown目录
+            local_image_dir: 图片目录
+            md_writer: Markdown写入器
+            middle_json: 中间JSON数据
+            model_output: 模型输出
+            original_file_path: 原始文件路径
+            
+        Returns:
+            dict: 保存的文件路径信息
+        """
+        saved_files = {}
+        
+        try:
+            # 设置相对图片目录名
+            image_dir = str(os.path.basename(local_image_dir))
+            
+            # 1. Generate and save the Markdown file
+            md_content_str = vlm_union_make(pdf_info, MakeMode.MM_MD, image_dir)
+            
+            # Number normalization (full-width digits to half-width, e.g. "１２３" -> "123")
+            if self.normalize_numbers:
+                original_md = md_content_str
+                md_content_str = normalize_markdown_table(md_content_str)
+                
+                # Character-wise diff; assumes normalization preserves string length
+                changes_count = sum(1 for o, n in zip(original_md, md_content_str) if o != n)
+                if changes_count > 0:
+                    saved_files['md_normalized'] = f"✅ Normalized {changes_count} characters (full-width → half-width)"
+                else:
+                    saved_files['md_normalized'] = "ℹ️ No normalization needed (already in standard form)"
+            
+            md_writer.write_string(f"{pdf_file_name}.md", md_content_str)
+            saved_files['md'] = os.path.join(local_md_dir, f"{pdf_file_name}.md")
+            
+            # 2. Generate and save the content_list JSON file
+            content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
+            content_list_str = json.dumps(content_list, ensure_ascii=False, indent=2)
+            
+            # Number normalization
+            if self.normalize_numbers:
+                original_json = content_list_str
+                content_list_str = normalize_json_table(content_list_str)
+                
+                changes_count = sum(1 for o, n in zip(original_json, content_list_str) if o != n)
+                if changes_count > 0:
+                    saved_files['json_normalized'] = f"✅ Normalized {changes_count} characters (full-width → half-width)"
+                else:
+                    saved_files['json_normalized'] = "ℹ️ No normalization needed (already in standard form)"
+            
+            md_writer.write_string(f"{pdf_file_name}.json", content_list_str)
+            saved_files['json'] = os.path.join(local_md_dir, f"{pdf_file_name}.json")
+            
+            # Draw layout bounding boxes
+            try:
+                draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
+                saved_files['layout_pdf'] = os.path.join(local_md_dir, f"{pdf_file_name}_layout.pdf")
+            except Exception as e:
+                logger.warning(f"Failed to draw layout bbox: {e}")
+
+            # 3. Copy the original file into the images directory (currently disabled)
+            # original_file_path = Path(original_file_path)
+            # output_image_path = os.path.join(local_image_dir, f"{pdf_file_name}{original_file_path.suffix}")
+            # if original_file_path.exists():
+            #     shutil.copy2(str(original_file_path), output_image_path)
+            #     saved_files['image'] = output_image_path
+            
+            # 4. In debug mode, save extra artifacts
+            if self.debug:
+                # Save middle.json
+                middle_json_str = json.dumps(middle_json, ensure_ascii=False, indent=2)
+                if self.normalize_numbers:
+                    middle_json_str = normalize_json_table(middle_json_str)
+                
+                md_writer.write_string(f"{pdf_file_name}_middle.json", middle_json_str)
+                saved_files['middle_json'] = os.path.join(local_md_dir, f"{pdf_file_name}_middle.json")
+                
+                # Save the raw model output
+                if model_output:
+                    model_output_str = json.dumps(model_output, ensure_ascii=False, indent=2)
+                    md_writer.write_string(f"{pdf_file_name}_model.json", model_output_str)
+                    saved_files['model_output'] = os.path.join(local_md_dir, f"{pdf_file_name}_model.json")
+                
+                # # Save the original PDF
+                # md_writer.write(f"{pdf_file_name}_origin.pdf", pdf_bytes)
+                # saved_files['origin_pdf'] = os.path.join(local_md_dir, f"{pdf_file_name}_origin.pdf")
+                
+            
+            logger.info(f"Output saved to: {local_md_dir}")
+            
+        except Exception as e:
+            logger.error(f"Error in _process_output: {e}")
+            if self.debug:
+                traceback.print_exc()
+        
+        return saved_files
+    
+    def _get_extraction_stats(self, middle_json: Dict) -> Dict[str, Any]:
+        """
+        获取提取统计信息
+        
+        Args:
+            middle_json: 中间JSON数据
+            
+        Returns:
+            dict: 统计信息
+        """
+        stats = {
+            "total_blocks": 0,
+            "block_types": {},
+            "total_pages": 0
+        }
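+        # Illustrative final shape (actual values depend on the document):
+        #   {"total_pages": 1, "total_blocks": 12,
+        #    "block_types": {"text": 8, "table": 2, "image": 2}}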
+        
+        try:
+            pdf_info = middle_json.get("pdf_info", [])
+            if isinstance(pdf_info, list):
+                stats["total_pages"] = len(pdf_info)
+                
+                for page_info in pdf_info:
+                    para_blocks = page_info.get("para_blocks", [])
+                    stats["total_blocks"] += len(para_blocks)
+                    
+                    for block in para_blocks:
+                        block_type = block.get("type", "unknown")
+                        stats["block_types"][block_type] = stats["block_types"].get(block_type, 0) + 1
+                        
+        except Exception as e:
+            logger.warning(f"Failed to get extraction stats: {e}")
+        
+        return stats
+    
+    def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
+        """
+        处理单张图片 (保持与原接口兼容)
+        
+        Args:
+            image_path: 图片路径
+            output_dir: 输出目录
+            
+        Returns:
+            dict: 处理结果
+        """
+        start_time = time.time()
+        image_name = Path(image_path).stem
+        
+        result_info = {
+            "image_path": image_path,
+            "processing_time": 0,
+            "success": False,
+            "server": self.server_url,
+            "error": None,
+            "output_files": {},
+            "is_pdf_page": "_page_" in Path(image_path).name,
+            "extraction_stats": {}
+        }
+        
+        try:
+            # Skip if the output files already exist
+            expected_md_path = Path(output_dir) / f"{image_name}.md"
+            expected_json_path = Path(output_dir) / f"{image_name}.json"
+            
+            if expected_md_path.exists() and expected_json_path.exists():
+                result_info.update({
+                    "success": True,
+                    "processing_time": 0,
+                    "output_files": {
+                        "md": str(expected_md_path),
+                        "json": str(expected_json_path)
+                    },
+                    "skipped": True
+                })
+                return result_info
+            
+            # Process via do_parse_single_file
+            parse_result = self.do_parse_single_file(image_path, output_dir)
+            
+            if parse_result["success"]:
+                result_info.update({
+                    "success": True,
+                    "output_files": parse_result["output_files"],
+                    "extraction_stats": parse_result["extraction_stats"]
+                })
+            else:
+                result_info["error"] = parse_result.get("error", "Unknown error")
+            
+        except Exception as e:
+            result_info["error"] = str(e)
+            logger.error(f"Error processing {image_name}: {e}")
+            if self.debug:
+                traceback.print_exc()
+        
+        finally:
+            result_info["processing_time"] = time.time() - start_time
+        
+        return result_info
+
+
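+# Usage sketch for MinerUVLLMProcessor (file name is a placeholder; assumes a
+# MinerU vLLM server is reachable at the default URL):
+#   processor = MinerUVLLMProcessor(server_url="http://127.0.0.1:8121")
+#   result = processor.process_single_image("page_001.png", "./output")
+#   if result["success"]:
+#       print(result["output_files"])
+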
+def process_images_single_process(image_paths: List[str],
+                                processor: MinerUVLLMProcessor,
+                                batch_size: int = 1,
+                                output_dir: str = "./output") -> List[Dict[str, Any]]:
+    """
+    单进程版本的图像处理函数
+    """
+    # 创建输出目录
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    all_results = []
+    total_images = len(image_paths)
+    
+    print(f"Processing {total_images} images with batch size {batch_size}")
+    
+    with tqdm(total=total_images, desc="Processing images", unit="img", 
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        
+        for i in range(0, total_images, batch_size):
+            batch = image_paths[i:i + batch_size]
+            batch_start_time = time.time()
+            batch_results = []
+            
+            try:
+                for image_path in batch:
+                    try:
+                        result = processor.process_single_image(image_path, output_dir)
+                        batch_results.append(result)
+                    except Exception as e:
+                        logger.error(f"Error processing {image_path}: {e}")
+                        batch_results.append({
+                            "image_path": image_path,
+                            "processing_time": 0,
+                            "success": False,
+                            "server": processor.server_url,
+                            "error": str(e)
+                        })
+                
+                batch_processing_time = time.time() - batch_start_time
+                all_results.extend(batch_results)
+                
+                # Update the progress bar
+                success_count = sum(1 for r in batch_results if r.get('success', False))
+                skipped_count = sum(1 for r in batch_results if r.get('skipped', False))
+                total_success = sum(1 for r in all_results if r.get('success', False))
+                total_skipped = sum(1 for r in all_results if r.get('skipped', False))
+                avg_time = batch_processing_time / len(batch) if len(batch) > 0 else 0
+                
+                total_blocks = sum(r.get('extraction_stats', {}).get('total_blocks', 0) for r in batch_results)
+                
+                pbar.update(len(batch))
+                pbar.set_postfix({
+                    'batch_time': f"{batch_processing_time:.2f}s",
+                    'avg_time': f"{avg_time:.2f}s/img",
+                    'success': f"{total_success}/{len(all_results)}",
+                    'skipped': f"{total_skipped}",
+                    'blocks': f"{total_blocks}",
+                    'rate': f"{total_success/len(all_results)*100:.1f}%" if len(all_results) > 0 else "0%"
+                })
+                
+            except Exception as e:
+                logger.error(f"Error processing batch {[Path(p).name for p in batch]}: {e}")
+                error_results = []
+                for img_path in batch:
+                    error_results.append({
+                        "image_path": str(img_path),
+                        "processing_time": 0,
+                        "success": False,
+                        "server": processor.server_url,
+                        "error": str(e)
+                    })
+                all_results.extend(error_results)
+                pbar.update(len(batch))
+    
+    return all_results
+
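+# Sketch: driving the loop directly (paths are placeholders):
+#   processor = MinerUVLLMProcessor()
+#   results = process_images_single_process(["a.png", "b.png"], processor,
+#                                           batch_size=1, output_dir="./output")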
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="MinerU vLLM Batch Processing (demo.py framework)")
+    
+    # Input argument group
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument("--input_file", type=str, help="Input file (PDF or image)")
+    input_group.add_argument("--input_dir", type=str, help="Input directory (PDF and image files)")
+    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
+    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
+
+    # Output arguments
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+    
+    # MinerU vLLM arguments
+    parser.add_argument("--server_url", type=str, default="http://127.0.0.1:8121", 
+                       help="MinerU vLLM server URL")
+    parser.add_argument("--timeout", type=int, default=300, help="Request timeout in seconds")
+    parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
+    parser.add_argument('--no-normalize', action='store_true', help='Disable number normalization')
+    parser.add_argument('--debug', action='store_true', help='Enable debug mode')
+    
+    # Processing arguments
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 10 images)")
+    parser.add_argument("--collect_results", type=str, help="Collect processing results into the given CSV file")
+    
+    args = parser.parse_args()
+    
+    try:
+        # Gather and preprocess the input files
+        print("🔄 Preprocessing input files...")
+        image_files = get_input_files(args)
+        
+        if not image_files:
+            print("❌ No input files found or processed")
+            return 1
+
+        output_dir = Path(args.output_dir).resolve()
+        print(f"📁 Output dir: {output_dir}")
+        print(f"📊 Found {len(image_files)} image files to process")
+        
+        if args.test_mode:
+            image_files = image_files[:10]
+            print(f"🧪 Test mode: processing only {len(image_files)} images")
+        
+        print(f"🌐 Using server: {args.server_url}")
+        print(f"📦 Batch size: {args.batch_size}")
+        print(f"⏱️ Timeout: {args.timeout}s")
+        
+        # 创建处理器
+        processor = MinerUVLLMProcessor(
+            server_url=args.server_url,
+            timeout=args.timeout,
+            normalize_numbers=not args.no_normalize,
+            debug=args.debug
+        )
+        
+        # 开始处理
+        start_time = time.time()
+        results = process_images_single_process(
+            image_files,
+            processor,
+            args.batch_size,
+            str(output_dir)
+        )
+        
+        total_time = time.time() - start_time
+        
+        # Aggregate results
+        success_count = sum(1 for r in results if r.get('success', False))
+        skipped_count = sum(1 for r in results if r.get('skipped', False))
+        error_count = len(results) - success_count
+        pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
+        
+        # Aggregate extracted block statistics
+        total_blocks = sum(r.get('extraction_stats', {}).get('total_blocks', 0) for r in results)
+        block_type_stats = {}
+        for result in results:
+            if 'extraction_stats' in result and 'block_types' in result['extraction_stats']:
+                for block_type, count in result['extraction_stats']['block_types'].items():
+                    block_type_stats[block_type] = block_type_stats.get(block_type, 0) + count
+        
+        print(f"\n" + "="*60)
+        print(f"✅ Processing completed!")
+        print(f"📊 Statistics:")
+        print(f"  Total files processed: {len(image_files)}")
+        print(f"  PDF pages processed: {pdf_page_count}")
+        print(f"  Regular images processed: {len(image_files) - pdf_page_count}")
+        print(f"  Successful: {success_count}")
+        print(f"  Skipped: {skipped_count}")
+        print(f"  Failed: {error_count}")
+        if len(image_files) > 0:
+            print(f"  Success rate: {success_count / len(image_files) * 100:.2f}%")
+        
+        print(f"📋 Content Extraction:")
+        print(f"  Total blocks extracted: {total_blocks}")
+        if block_type_stats:
+            print(f"  Block types:")
+            for block_type, count in sorted(block_type_stats.items()):
+                print(f"    {block_type}: {count}")
+        
+        print(f"⏱️ Performance:")
+        print(f"  Total time: {total_time:.2f} seconds")
+        if total_time > 0:
+            print(f"  Throughput: {len(image_files) / total_time:.2f} images/second")
+            print(f"  Avg time per image: {total_time / len(image_files):.2f} seconds")
+        
+        print(f"\n📁 Output Structure (demo.py compatible):")
+        print(f"  output_dir/")
+        print(f"  ├── filename.md              # Markdown content")
+        print(f"  ├── filename.json            # Content list")
+        print(f"  ├── filename_layout.json     # Debug: layout bbox")
+        print(f"  └── images/                  # Extracted images")
+        print(f"      └── filename.png")
+        if args.debug:
+            print(f"  ├── filename_middle.json    # Debug: middle JSON")
+            print(f"  └── filename_model.json     # Debug: model output")
+
+        # Save summary statistics
+        stats = {
+            "total_files": len(image_files),
+            "pdf_pages": pdf_page_count,
+            "regular_images": len(image_files) - pdf_page_count,
+            "success_count": success_count,
+            "skipped_count": skipped_count,
+            "error_count": error_count,
+            "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
+            "total_time": total_time,
+            "throughput": len(image_files) / total_time if total_time > 0 else 0,
+            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
+            "batch_size": args.batch_size,
+            "server": args.server_url,
+            "backend": "vlm-http-client",
+            "timeout": args.timeout,
+            "pdf_dpi": args.dpi,
+            "total_blocks": total_blocks,
+            "block_type_stats": block_type_stats,
+            "normalization_enabled": not args.no_normalize,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
+        }
+        
+        # Save the final results
+        output_file_name = Path(output_dir).name
+        output_file = os.path.join(output_dir, f"{output_file_name}_results.json")
+        final_results = {
+            "stats": stats,
+            "results": results
+        }
+        
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(final_results, f, ensure_ascii=False, indent=2)
+        
+        print(f"💾 Results saved to: {output_file}")
+
+        # Collect processed-file statuses into a CSV
+        if not args.collect_results:
+            output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
+        else:
+            output_file_processed = Path(args.collect_results).resolve()
+            
+        processed_files = collect_pid_files(output_file)
+        with open(output_file_processed, 'w', encoding='utf-8') as f:
+            f.write("image_path,status\n")
+            for file_path, status in processed_files:
+                f.write(f"{file_path},{status}\n")
+        print(f"💾 Processed files saved to: {output_file_processed}")
+
+        return 0
+        
+    except Exception as e:
+        logger.error(f"Processing failed: {e}")
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == "__main__":
+    print(f"🚀 启动MinerU vLLM统一PDF/图像处理程序...")
+    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
+    
+    if len(sys.argv) == 1:
+        # 如果没有命令行参数,使用默认配置运行
+        print("ℹ️  No command line arguments provided. Running with default configuration...")
+        
+        # 默认配置
+        default_config = {
+            "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
+            "output_dir": "./output",
+            "collect_results": f"./output/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
+            "server_url": "http://127.0.0.1:8121",
+            "timeout": "300",
+            "batch_size": "1",
+            "dpi": "200",
+        }
+        
+        # Build sys.argv from the default configuration
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            sys.argv.extend([f"--{key}", str(value)])
+        
+        # Enable test mode and debug mode
+        sys.argv.append("--test_mode")
+        sys.argv.append("--debug")
+    
+    sys.exit(main())