
feat: Add single and multi-threaded processing for OmniDocBench images

- Implemented DotsOCRProcessor class for handling image processing.
- Created `process_images_single_process` for single-threaded image processing.
- Added `process_images_concurrent` for multi-threaded image processing using ThreadPoolExecutor.
- Enhanced result saving functionality to output .md, .json, and annotated layout images.
- Introduced command-line arguments for input/output directories, server settings, and processing options.
- Added support for batch processing and error handling.
- Included progress tracking with tqdm for better user feedback during processing.
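- Example invocation (illustrative only; server address and paths are placeholders, not part of this commit): `python zhch/OmniDocBench_DotsOCR_multthreads.py --input_dir <images_dir> --output_dir <results_dir> --ip 127.0.0.1 --port 8101 --use_threading --max_workers 3`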
zhch158_admin committed 3 months ago
parent
commit
3d6e265827
2 files changed, 635 insertions, 0 deletions
  1. zhch/OmniDocBench_DotsOCR single.py (+0 -0, renamed)
  2. zhch/OmniDocBench_DotsOCR_multthreads.py (+635 -0)

+ 0 - 0
zhch/OmniDocBench_DotsOCR.py → zhch/OmniDocBench_DotsOCR single.py


+ 635 - 0
zhch/OmniDocBench_DotsOCR_multthreads.py

@@ -0,0 +1,635 @@
+"""
+批量处理 OmniDocBench 图片并生成符合评测要求的预测结果
+
+根据 OmniDocBench 评测要求:
+- 输入:OpenDataLab___OmniDocBench/images 下的所有 .jpg 图片
+- 输出:每个图片对应的 .md、.json 和带标注的 layout 图片文件
+- 输出目录:用于后续的 end2end 评测
+"""
+
+import os
+import sys
+import json
+import tempfile
+import uuid
+import shutil
+import time
+import traceback
+import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+from PIL import Image
+from tqdm import tqdm
+import argparse
+
+# Import dots.ocr modules
+from dots_ocr.parser import DotsOCRParser
+from dots_ocr.utils import dict_promptmode_to_prompt
+from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
+
+# Import helper functions
+from utils import (
+    get_image_files_from_dir,
+    get_image_files_from_list,
+    get_image_files_from_csv,
+    collect_pid_files
+)
+
+class DotsOCRProcessor:
+    """DotsOCR 处理器"""
+    
+    def __init__(self, 
+                 ip: str = "127.0.0.1", 
+                 port: int = 8101, 
+                 model_name: str = "DotsOCR",
+                 prompt_mode: str = "prompt_layout_all_en",
+                 dpi: int = 200,
+                 min_pixels: int = MIN_PIXELS,
+                 max_pixels: int = MAX_PIXELS):
+        """
+        初始化处理器
+        
+        Args:
+            ip: vLLM 服务器 IP
+            port: vLLM 服务器端口
+            model_name: 模型名称
+            prompt_mode: 提示模式
+            dpi: PDF 处理 DPI
+            min_pixels: 最小像素数
+            max_pixels: 最大像素数
+        """
+        self.ip = ip
+        self.port = port
+        self.model_name = model_name
+        self.prompt_mode = prompt_mode
+        self.dpi = dpi
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        
+        # Initialize the parser
+        self.parser = DotsOCRParser(
+            ip=ip,
+            port=port,
+            dpi=dpi,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+            model_name=model_name
+        )
+        
+        print(f"DotsOCR Parser 初始化完成:")
+        print(f"  - 服务器: {ip}:{port}")
+        print(f"  - 模型: {model_name}")
+        print(f"  - 提示模式: {prompt_mode}")
+        print(f"  - 像素范围: {min_pixels} - {max_pixels}")
+    
+    def create_temp_session_dir(self) -> Tuple[str, str]:
+        """Create a temporary session directory; returns (temp_dir, session_id)."""
+        session_id = uuid.uuid4().hex[:8]
+        temp_dir = os.path.join(tempfile.gettempdir(), f"omnidocbench_batch_{session_id}")
+        os.makedirs(temp_dir, exist_ok=True)
+        return temp_dir, session_id
+    
+    def save_results_to_output_dir(self, result: Dict, image_name: str, output_dir: str) -> Dict[str, str]:
+        """
+        将处理结果保存到输出目录
+        
+        Args:
+            result: 解析结果
+            image_name: 图片文件名(不含扩展名)
+            output_dir: 输出目录
+            
+        Returns:
+            dict: 保存的文件路径
+        """
+        saved_files = {}
+        
+        try:
+            # 1. Save the Markdown file (required by the OmniDocBench evaluation)
+            output_md_path = os.path.join(output_dir, f"{image_name}.md")
+            md_content = ""
+            
+            # Prefer the header/footer-free version (as the OmniDocBench evaluation requires)
+            if 'md_content_nohf_path' in result and os.path.exists(result['md_content_nohf_path']):
+                with open(result['md_content_nohf_path'], 'r', encoding='utf-8') as f:
+                    md_content = f.read()
+            elif 'md_content_path' in result and os.path.exists(result['md_content_path']):
+                with open(result['md_content_path'], 'r', encoding='utf-8') as f:
+                    md_content = f.read()
+            else:
+                md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
+            
+            with open(output_md_path, 'w', encoding='utf-8') as f:
+                f.write(md_content)
+            saved_files['md'] = output_md_path
+            
+            # 2. Save the JSON file
+            output_json_path = os.path.join(output_dir, f"{image_name}.json")
+            json_data = {}
+            
+            if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
+                with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
+                    json_data = json.load(f)
+            else:
+                json_data = {
+                    "error": "未能提取到有效的布局信息",
+                    "cells": []
+                }
+            
+            with open(output_json_path, 'w', encoding='utf-8') as f:
+                json.dump(json_data, f, ensure_ascii=False, indent=2)
+            saved_files['json'] = output_json_path
+            
+            # 3. Save the annotated layout image
+            output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
+            
+            if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
+                # Copy the layout image directly
+                shutil.copy2(result['layout_image_path'], output_layout_image_path)
+                saved_files['layout_image'] = output_layout_image_path
+            else:
+                # No layout image was produced; fall back to a copy of the original image
+                try:
+                    original_image = Image.open(result.get('original_image_path', ''))
+                    original_image.save(output_layout_image_path, 'JPEG', quality=95)
+                    saved_files['layout_image'] = output_layout_image_path
+                except Exception:
+                    saved_files['layout_image'] = None
+            
+            # # 4. Optional: save a copy of the original image
+            # output_original_image_path = os.path.join(output_dir, f"{image_name}_original.jpg")
+            # if 'original_image_path' in result and os.path.exists(result['original_image_path']):
+            #     shutil.copy2(result['original_image_path'], output_original_image_path)
+            #     saved_files['original_image'] = output_original_image_path
+            
+        except Exception as e:
+            print(f"Error saving results for {image_name}: {e}")
+            
+        return saved_files
+    
+    def process_single_image(self, image_path: str, output_dir: str) -> Dict[str, Any]:
+        """
+        处理单张图片
+        
+        Args:
+            image_path: 图片路径
+            output_dir: 输出目录
+            
+        Returns:
+            dict: 处理结果
+        """
+        start_time = time.time()
+        image_name = Path(image_path).stem
+        
+        result_info = {
+            "image_path": image_path,
+            "processing_time": 0,
+            "success": False,
+            "device": f"{self.ip}:{self.port}",
+            "error": None,
+            "output_files": {}
+        }
+        
+        try:
+            # Skip work if the output files already exist
+            output_md_path = os.path.join(output_dir, f"{image_name}.md")
+            output_json_path = os.path.join(output_dir, f"{image_name}.json")
+            output_layout_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
+            
+            if all(os.path.exists(p) for p in [output_md_path, output_json_path, output_layout_path]):
+                result_info.update({
+                    "success": True,
+                    "processing_time": 0,
+                    "output_files": {
+                        "md": output_md_path,
+                        "json": output_json_path,
+                        "layout_image": output_layout_path
+                    },
+                    "skipped": True
+                })
+                return result_info
+            
+            # Create a temporary session directory
+            temp_dir, session_id = self.create_temp_session_dir()
+            
+            try:
+                # Load the image
+                image = Image.open(image_path)
+                
+                # Process the image with DotsOCRParser
+                filename = f"omnidocbench_{session_id}"
+                results = self.parser.parse_image(
+                    input_path=image,
+                    filename=filename,
+                    prompt_mode=self.prompt_mode,
+                    save_dir=temp_dir,
+                    fitz_preprocess=True  # apply fitz preprocessing to the image
+                )
+                
+                # Validate the parse result
+                if not results:
+                    raise Exception("Parser returned no results")
+                
+                result = results[0]  # parse_image returns a single-result list
+                
+                # Record the original image path so save_results_to_output_dir can fall back to it
+                result['original_image_path'] = image_path
+                
+                # Save all result files to the output directory
+                saved_files = self.save_results_to_output_dir(result, image_name, output_dir)
+                
+                # Verify the saved files
+                success_count = sum(1 for path in saved_files.values() if path and os.path.exists(path))
+                
+                if success_count >= 2:  # at least the .md and .json were saved
+                    result_info.update({
+                        "success": True,
+                        "output_files": saved_files
+                    })
+                else:
+                    raise Exception(f"Incomplete save ({success_count}/3 files)")
+                
+            finally:
+                # Clean up the temporary directory
+                if os.path.exists(temp_dir):
+                    shutil.rmtree(temp_dir, ignore_errors=True)
+                
+        except Exception as e:
+            result_info["error"] = str(e)
+            
+        finally:
+            result_info["processing_time"] = time.time() - start_time
+            
+        return result_info
+
+
+def process_images_single_process(image_paths: List[str],
+                                  processor: DotsOCRProcessor,
+                                  batch_size: int = 1,
+                                  output_dir: str = "./output") -> List[Dict[str, Any]]:
+    """
+    Single-process image-processing function.
+    
+    Args:
+        image_paths: list of image paths
+        processor: DotsOCRProcessor instance
+        batch_size: batch size
+        output_dir: output directory
+        
+    Returns:
+        list of per-image processing results
+    """
+    # Create the output directory
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    all_results = []
+    total_images = len(image_paths)
+    
+    print(f"Processing {total_images} images with batch size {batch_size}")
+    
+    # Show progress with tqdm, including extra statistics
+    with tqdm(total=total_images, desc="Processing images", unit="img", 
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        
+        # Process images in batches (DotsOCR usually handles one image at a time)
+        for i in range(0, total_images, batch_size):
+            batch = image_paths[i:i + batch_size]
+            batch_start_time = time.time()
+            batch_results = []
+            
+            try:
+                # Process each image in the batch
+                for image_path in batch:
+                    try:
+                        result = processor.process_single_image(image_path, output_dir)
+                        batch_results.append(result)
+                        
+                    except Exception as e:
+                        print(f"Error processing {image_path}: {e}", file=sys.stderr)
+                        traceback.print_exc()
+                        
+                        batch_results.append({
+                            "image_path": image_path,
+                            "processing_time": 0,
+                            "success": False,
+                            "device": f"{processor.ip}:{processor.port}",
+                            "error": str(e)
+                        })
+                
+                batch_processing_time = time.time() - batch_start_time
+                all_results.extend(batch_results)
+                
+                # Update the progress bar (running totals over all results so far)
+                total_success = sum(1 for r in all_results if r.get('success', False))
+                total_skipped = sum(1 for r in all_results if r.get('skipped', False))
+                avg_time = batch_processing_time / len(batch)
+                
+                pbar.update(len(batch))
+                pbar.set_postfix({
+                    'batch_time': f"{batch_processing_time:.2f}s",
+                    'avg_time': f"{avg_time:.2f}s/img",
+                    'success': f"{total_success}/{len(all_results)}",
+                    'skipped': f"{total_skipped}",
+                    'rate': f"{total_success/len(all_results)*100:.1f}%"
+                })
+                
+            except Exception as e:
+                print(f"Error processing batch {[Path(p).name for p in batch]}: {e}", file=sys.stderr)
+                traceback.print_exc()
+                
+                # Record an error result for every image in the batch
+                error_results = []
+                for img_path in batch:
+                    error_results.append({
+                        "image_path": str(img_path),
+                        "processing_time": 0,
+                        "success": False,
+                        "device": f"{processor.ip}:{processor.port}",
+                        "error": str(e)
+                    })
+                all_results.extend(error_results)
+                pbar.update(len(batch))
+    
+    return all_results
+
+
+def process_images_concurrent(image_paths: List[str],
+                              processor: DotsOCRProcessor,
+                              batch_size: int = 1,
+                              output_dir: str = "./output",
+                              max_workers: int = 3) -> List[Dict[str, Any]]:
+    """Concurrent image-processing function using a ThreadPoolExecutor."""
+    
+    from concurrent.futures import ThreadPoolExecutor, as_completed
+    
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    
+    def process_batch(batch_images):
+        """处理一批图像"""
+        batch_results = []
+        for image_path in batch_images:
+            try:
+                result = processor.process_single_image(image_path, output_dir)
+                batch_results.append(result)
+            except Exception as e:
+                batch_results.append({
+                    "image_path": image_path,
+                    "processing_time": 0,
+                    "success": False,
+                    "device": f"{processor.ip}:{processor.port}",
+                    "error": str(e)
+                })
+        return batch_results
+    
+    # Split the images into batches
+    batches = [image_paths[i:i + batch_size] for i in range(0, len(image_paths), batch_size)]
+    
+    all_results = []
+    
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all batches
+        future_to_batch = {executor.submit(process_batch, batch): batch for batch in batches}
+        
+        # Show progress with tqdm
+        with tqdm(total=len(image_paths), desc="Processing images") as pbar:
+            for future in as_completed(future_to_batch):
+                try:
+                    batch_results = future.result()
+                    all_results.extend(batch_results)
+                    
+                    # Update progress
+                    success_count = sum(1 for r in batch_results if r.get('success', False))
+                    pbar.update(len(batch_results))
+                    pbar.set_postfix({'batch_success': f"{success_count}/{len(batch_results)}"})
+                    
+                except Exception as e:
+                    batch = future_to_batch[future]
+                    # Record an error result for every image in the batch
+                    error_results = [
+                        {
+                            "image_path": img_path,
+                            "processing_time": 0,
+                            "success": False,
+                            "device": f"{processor.ip}:{processor.port}",
+                            "error": str(e)
+                        }
+                        for img_path in batch
+                    ]
+                    all_results.extend(error_results)
+                    pbar.update(len(batch))
+    
+    return all_results
+
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="DotsOCR OmniDocBench Processing")
+    
+    # Input argument group
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument("--input_dir", type=str, help="Input directory")
+    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
+    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
+
+    # Output arguments (required: args.output_dir is dereferenced unconditionally below)
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+    
+    # DotsOCR arguments
+    parser.add_argument("--ip", type=str, default="127.0.0.1", help="vLLM server IP")
+    parser.add_argument("--port", type=int, default=8101, help="vLLM server port")
+    parser.add_argument("--model_name", type=str, default="DotsOCR", help="Model name")
+    parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en", 
+                       choices=list(dict_promptmode_to_prompt.keys()), help="Prompt mode")
+    parser.add_argument("--min_pixels", type=int, default=MIN_PIXELS, help="Minimum pixels")
+    parser.add_argument("--max_pixels", type=int, default=MAX_PIXELS, help="Maximum pixels")
+    parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
+    
+    # Processing arguments
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+    parser.add_argument("--input_pattern", type=str, default="*", help="Input file pattern")
+    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 10 images)")
+    parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
+
+    # Concurrency arguments
+    parser.add_argument("--max_workers", type=int, default=3, 
+                       help="Maximum number of concurrent workers (should match vLLM data-parallel-size)")
+    parser.add_argument("--use_threading", action="store_true", 
+                       help="Use multi-threading")
+    
+    args = parser.parse_args()
+    
+    try:
+        # Build the list of image files
+        if args.input_csv:
+            # Read from a CSV file
+            image_files = get_image_files_from_csv(args.input_csv, "fail")
+            print(f"📊 Loaded {len(image_files)} files from CSV with status filter: fail")
+        elif args.input_file_list:
+            # Read from a file list
+            image_files = get_image_files_from_list(args.input_file_list)
+        else:
+            # Read from a directory
+            input_dir = Path(args.input_dir).resolve()
+            print(f"📁 Input dir: {input_dir}")
+            
+            if not input_dir.exists():
+                print(f"❌ Input directory does not exist: {input_dir}")
+                return 1
+
+            image_files = get_image_files_from_dir(input_dir, args.input_pattern)
+
+        output_dir = Path(args.output_dir).resolve()
+        print(f"📁 Output dir: {output_dir}")
+        print(f"📊 Found {len(image_files)} image files")
+        
+        if args.test_mode:
+            image_files = image_files[:10]
+            print(f"🧪 Test mode: processing only {len(image_files)} images")
+        
+        print(f"🌐 Using server: {args.ip}:{args.port}")
+        print(f"📦 Batch size: {args.batch_size}")
+        print(f"🎯 Prompt mode: {args.prompt_mode}")
+        
+        # Create the processor
+        processor = DotsOCRProcessor(
+            ip=args.ip,
+            port=args.port,
+            model_name=args.model_name,
+            prompt_mode=args.prompt_mode,
+            dpi=args.dpi,
+            min_pixels=args.min_pixels,
+            max_pixels=args.max_pixels
+        )
+        
+        # Start processing
+        start_time = time.time()
+        
+        # Choose the processing mode
+        if args.use_threading:
+            results = process_images_concurrent(
+                image_files,
+                processor,
+                args.batch_size,
+                str(output_dir),
+                args.max_workers
+            )
+        else:
+            results = process_images_single_process(
+                image_files,
+                processor,
+                args.batch_size,
+                str(output_dir)
+            )
+        
+        total_time = time.time() - start_time
+        
+        # Summarize results
+        success_count = sum(1 for r in results if r.get('success', False))
+        skipped_count = sum(1 for r in results if r.get('skipped', False))
+        error_count = len(results) - success_count
+        
+        print(f"\n" + "="*60)
+        print(f"✅ Processing completed!")
+        print(f"📊 Statistics:")
+        print(f"  Total files: {len(image_files)}")
+        print(f"  Successful: {success_count}")
+        print(f"  Skipped: {skipped_count}")
+        print(f"  Failed: {error_count}")
+        if len(image_files) > 0:
+            print(f"  Success rate: {success_count / len(image_files) * 100:.2f}%")
+        print(f"⏱️ Performance:")
+        print(f"  Total time: {total_time:.2f} seconds")
+        if total_time > 0:
+            print(f"  Throughput: {len(image_files) / total_time:.2f} images/second")
+            print(f"  Avg time per image: {total_time / len(image_files):.2f} seconds")
+        
+        # Save run statistics
+        stats = {
+            "total_files": len(image_files),
+            "success_count": success_count,
+            "skipped_count": skipped_count,
+            "error_count": error_count,
+            "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
+            "total_time": total_time,
+            "throughput": len(image_files) / total_time if total_time > 0 else 0,
+            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
+            "batch_size": args.batch_size,
+            "server": f"{args.ip}:{args.port}",
+            "model": args.model_name,
+            "prompt_mode": args.prompt_mode,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
+        }
+        
+        # Save the final results
+        output_file_name = Path(output_dir).name
+        output_file = os.path.join(output_dir, f"{output_file_name}_results.json")
+        final_results = {
+            "stats": stats,
+            "results": results
+        }
+        
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(final_results, f, ensure_ascii=False, indent=2)
+        
+        print(f"💾 Results saved to: {output_file}")
+
+        # Collect processed files into a CSV, if requested
+        if args.collect_results:
+            processed_files = collect_pid_files(output_file)
+            output_file_processed = Path(args.collect_results).resolve()
+            with open(output_file_processed, 'w', encoding='utf-8') as f:
+                f.write("image_path,status\n")
+                for file_path, status in processed_files:
+                    f.write(f"{file_path},{status}\n")
+            print(f"💾 Processed files saved to: {output_file_processed}")
+
+        return 0
+        
+    except Exception as e:
+        print(f"❌ Processing failed: {e}", file=sys.stderr)
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == "__main__":
+    print(f"🚀 启动DotsOCR单进程程序...")
+    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
+    
+    if len(sys.argv) == 1:
+        # No command-line arguments were given; run with the default configuration
+        print("ℹ️  No command line arguments provided. Running with default configuration...")
+        
+        # Default configuration
+        default_config = {
+            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
+            "output_dir": "./OmniDocBench_DotsOCR_Results",
+            "ip": "10.192.72.11",
+            "port": "8101",
+            "model_name": "DotsOCR",
+            "prompt_mode": "prompt_layout_all_en",
+            "batch_size": "1",
+            "max_workers": "3",
+            "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
+        }
+        
+        # To reprocess failed files, use this configuration instead
+        # default_config = {
+        #     "input_csv": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
+        #     "output_dir": "./OmniDocBench_DotsOCR_Results",
+        #     "ip": "127.0.0.1",
+        #     "port": "8101",
+        #     "collect_results": f"./OmniDocBench_DotsOCR_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
+        # }
+        
+        # Build sys.argv from the default configuration
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            sys.argv.extend([f"--{key}", str(value)])
+        
+        # Enable multi-threading; uncomment the next line to run in test mode
+        sys.argv.append("--use_threading")
+        # sys.argv.append("--test_mode")
+    
+    sys.exit(main())
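
A minimal usage sketch of the new module's programmatic entry points (hypothetical, not part of this commit; it assumes the script's directory is on `sys.path`, that a vLLM server is already serving the DotsOCR model, and that the image paths and output directory shown are placeholders):

```python
# Hypothetical sketch: drive the new processor directly from Python
# instead of the CLI. Server address, image paths, and output
# directory are placeholders.
from OmniDocBench_DotsOCR_multthreads import (
    DotsOCRProcessor,
    process_images_concurrent,
)

processor = DotsOCRProcessor(ip="127.0.0.1", port=8101, model_name="DotsOCR")
results = process_images_concurrent(
    image_paths=["images/page_0001.jpg", "images/page_0002.jpg"],
    processor=processor,
    batch_size=1,    # DotsOCR usually handles one image per request
    output_dir="./OmniDocBench_DotsOCR_Results",
    max_workers=3,   # should match the vLLM data-parallel size
)
print(f"{sum(1 for r in results if r['success'])}/{len(results)} images succeeded")
```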