Browse Source

feat: 更新工具模块,新增PDF转换为图像功能,优化CUDA检测及Flash Attention状态检查

zhch158_admin 1 month ago
parent
commit
0be5cc4398
4 changed files with 168 additions and 11 deletions
  1. 2 0
      zhch/utils/__init__.py
  2. 6 6
      zhch/utils/cuda_utils.py
  3. 120 5
      zhch/utils/doc_utils.py
  4. 40 0
      zhch/utils/verify_flash_attention.py

+ 2 - 0
zhch/utils/__init__.py

@@ -20,6 +20,7 @@ from .file_utils import (
     split_files,
     create_temp_file_list,
     collect_pid_files,
+    get_input_files,
 )
 
 from .doc_utils import (
@@ -46,6 +47,7 @@ __all__ = [
     'split_files',
     'create_temp_file_list',
     'collect_pid_files',
+    'get_input_files',
     # 金融数字标准化
     'normalize_financial_numbers',
     'normalize_markdown_table',

+ 6 - 6
zhch/utils/cuda_utils.py

@@ -1,9 +1,9 @@
-import paddle
+import torch
 from typing import List
 def detect_available_gpus() -> List[int]:
     """检测可用的GPU"""
     try:
-        gpu_count = paddle.device.cuda.device_count()
+        gpu_count = torch.cuda.device_count()
         available_gpus = list(range(gpu_count))
         print(f"检测到 {gpu_count} 个可用GPU: {available_gpus}")
         return available_gpus
@@ -15,10 +15,10 @@ def monitor_gpu_memory(gpu_ids: List[int] = [0, 1]):
     """监控GPU内存使用情况"""
     try:
         for gpu_id in gpu_ids:
-            paddle.device.set_device(f"gpu:{gpu_id}")
-            total = paddle.device.cuda.get_device_properties(gpu_id).total_memory / 1024**3
-            allocated = paddle.device.cuda.memory_allocated() / 1024**3
-            reserved = paddle.device.cuda.memory_reserved() / 1024**3
+            torch.cuda.set_device(gpu_id)
+            total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
+            allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
+            reserved = torch.cuda.memory_reserved(gpu_id) / 1024**3
             print(f"GPU {gpu_id} - 显存: {total:.2f}GB, 已分配: {allocated:.2f}GB, 已预留: {reserved:.2f}GB")
     except Exception as e:
         print(f"GPU内存监控失败: {e}")

+ 120 - 5
zhch/utils/doc_utils.py

@@ -2,6 +2,8 @@ import tempfile
 from pathlib import Path
 from typing import List, Tuple
 import json
+from .doc_utils import load_images_from_pdf
+import traceback
 
 def split_files(file_list: List[str], num_splits: int) -> List[List[str]]:
     """
@@ -47,7 +49,7 @@ def create_temp_file_list(file_chunk: List[str]) -> str:
             f.write(f"{file_path}\n")
         return f.name
 
-def get_image_files_from_dir(input_dir: Path, max_files: int = None) -> List[str]:
+def get_image_files_from_dir(input_dir: Path, pattern: str = "*", max_files: int = None) -> List[str]:
     """
     从目录获取图像文件列表
     
@@ -62,9 +64,9 @@ def get_image_files_from_dir(input_dir: Path, max_files: int = None) -> List[str
     image_files = []
     
     for ext in image_extensions:
-        image_files.extend(list(input_dir.glob(f"*{ext}")))
-        image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
-    
+        image_files.extend(list(input_dir.glob(f"{pattern}{ext}")))
+        image_files.extend(list(input_dir.glob(f"{pattern}{ext.upper()}")))
+
     # 去重并排序
     image_files = sorted(list(set(str(f) for f in image_files)))
     
@@ -174,4 +176,117 @@ def collect_pid_files(pid_output_file: str) -> List[Tuple[str, str]]:
         image_path = file_result.get("image_path", "")
         status = "success" if file_result.get("success", False) else "fail"
         file_list.append((image_path, status))
-    return file_list
+    return file_list
+
+def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
+    """
+    将PDF转换为图像文件
+    
+    Args:
+        pdf_file: PDF文件路径
+        output_dir: 输出目录
+        dpi: 图像分辨率
+        
+    Returns:
+        生成的图像文件路径列表
+    """
+    pdf_path = Path(pdf_file)
+    if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
+        print(f"❌ Invalid PDF file: {pdf_path}")
+        return []
+
+    # 如果没有指定输出目录,使用PDF同名目录
+    if output_dir is None:
+        output_path = pdf_path.parent / f"{pdf_path.stem}"
+    else:
+        output_path = Path(output_dir) / f"{pdf_path.stem}"
+    output_path = output_path.resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # 使用doc_utils中的函数加载PDF图像
+        images = load_images_from_pdf(str(pdf_path), dpi=dpi)
+        
+        image_paths = []
+        for i, image in enumerate(images):
+            # 生成图像文件名
+            image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
+            image_path = output_path / image_filename
+
+            # 保存图像
+            image.save(str(image_path))
+            image_paths.append(str(image_path))
+            
+        print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
+        return image_paths
+        
+    except Exception as e:
+        print(f"❌ Error converting PDF {pdf_path}: {e}")
+        traceback.print_exc()
+        return []
+
+def get_input_files(args) -> List[str]:
+    """
+    获取输入文件列表,统一处理PDF和图像文件
+    
+    Args:
+        args: 命令行参数
+        
+    Returns:
+        处理后的图像文件路径列表
+    """
+    input_files = []
+    
+    # 获取原始输入文件
+    if args.input_csv:
+        raw_files = get_image_files_from_csv(args.input_csv, "fail")
+    elif args.input_file_list:
+        raw_files = get_image_files_from_list(args.input_file_list)
+    elif args.input_file:
+        raw_files = [Path(args.input_file).resolve()]
+    else:
+        input_dir = Path(args.input_dir).resolve()
+        if not input_dir.exists():
+            print(f"❌ Input directory does not exist: {input_dir}")
+            return []
+        
+        # 获取所有支持的文件(图像和PDF)
+        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+        pdf_extensions = ['.pdf']
+        
+        raw_files = []
+        for ext in image_extensions + pdf_extensions:
+            raw_files.extend(list(input_dir.glob(f"*{ext}")))
+            raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
+        
+        raw_files = [str(f) for f in raw_files]
+    
+    # 分别处理PDF和图像文件
+    pdf_count = 0
+    image_count = 0
+    
+    for file_path in raw_files:
+        file_path = Path(file_path)
+        
+        if file_path.suffix.lower() == '.pdf':
+            # 转换PDF为图像
+            print(f"📄 Processing PDF: {file_path.name}")
+            pdf_images = convert_pdf_to_images(
+                str(file_path), 
+                args.output_dir,
+                dpi=args.pdf_dpi
+            )
+            input_files.extend(pdf_images)
+            pdf_count += 1
+        else:
+            # 直接添加图像文件
+            if file_path.exists():
+                input_files.append(str(file_path))
+                image_count += 1
+    
+    print(f"📊 Input summary:")
+    print(f"  PDF files processed: {pdf_count}")
+    print(f"  Image files found: {image_count}")
+    print(f"  Total image files to process: {len(input_files)}")
+    
+    return sorted(list(set(str(f) for f in input_files)))

+ 40 - 0
zhch/utils/verify_flash_attention.py

@@ -0,0 +1,40 @@
+import torch
+import subprocess
+import pkg_resources
+
+def check_flash_attention():
+    print("🔍 Flash Attention 状态检查")
+    print("=" * 50)
+    
+    # 检查已安装的包
+    try:
+        flash_attn_version = pkg_resources.get_distribution("flash-attn").version
+        print(f"✅ flash-attn: {flash_attn_version}")
+    except:
+        print("❌ flash-attn: 未安装")
+    
+    try:
+        flashinfer_version = pkg_resources.get_distribution("flashinfer").version
+        print(f"✅ flashinfer: {flashinfer_version}")
+    except:
+        print("❌ flashinfer: 未安装")
+    
+    # 检查 CUDA 可用性
+    print(f"\n🔧 CUDA 状态:")
+    print(f"CUDA 可用: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        print(f"CUDA 版本: {torch.version.cuda}")
+        print(f"GPU 数量: {torch.cuda.device_count()}")
+        for i in range(torch.cuda.device_count()):
+            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+    
+    # 检查 Flash Attention 功能
+    try:
+        import flash_attn
+        print(f"\n✅ Flash Attention 可导入")
+        print(f"Flash Attention 版本: {flash_attn.__version__}")
+    except ImportError as e:
+        print(f"\n❌ Flash Attention 导入失败: {e}")
+
+if __name__ == "__main__":
+    check_flash_attention()