Browse Source

feat: 添加工具模块,包含CUDA环境检测、文件处理等功能

zhch158_admin 3 months ago
parent
commit
252e4c292c
4 changed files with 251 additions and 0 deletions
  1. 40 0
      zhch/utils/__init__.py
  2. 54 0
      zhch/utils/check_cuda_env.py
  3. 24 0
      zhch/utils/cuda_utils.py
  4. 133 0
      zhch/utils/file_utils.py

+ 40 - 0
zhch/utils/__init__.py

@@ -0,0 +1,40 @@
+"""
+工具模块包
+
+包含CUDA环境检测、文件处理等实用工具函数
+"""
+
+from .check_cuda_env import (
+    check_nvidia_environment,
+)
+
+from .cuda_utils import (
+    monitor_gpu_memory,
+    detect_available_gpus
+)
+
+from .file_utils import (
+    get_image_files_from_dir,
+    get_image_files_from_list,
+    get_image_files_from_csv,
+    split_files,
+    create_temp_file_list
+)
+
+__all__ = [
+    # CUDA环境检测
+    'check_nvidia_environment',
+    # CUDA工具
+    'monitor_gpu_memory',
+    'detect_available_gpus',
+    # 文件工具
+    'get_image_files_from_dir',
+    'get_image_files_from_list',
+    'get_image_files_from_csv',
+    'split_files',
+    'create_temp_file_list'
+]
+
+__version__ = "1.0.0"
+__author__ = "zhch158"
+__description__ = "PaddleX工具模块包,提供CUDA环境检测和文件处理功能"

+ 54 - 0
zhch/utils/check_cuda_env.py

@@ -0,0 +1,54 @@
+import subprocess
+import sys
+
+def check_nvidia_environment():
+    print("=== NVIDIA环境检查 ===")
+    
+    # 检查nvidia-smi
+    try:
+        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=10)
+        if result.returncode == 0:
+            print("✓ nvidia-smi 正常运行")
+            # 提取驱动版本
+            lines = result.stdout.split('\n')
+            for line in lines:
+                if 'Driver Version:' in line:
+                    print(f"  {line.strip()}")
+                    break
+        else:
+            print("✗ nvidia-smi 失败:")
+            print(result.stderr)
+    except Exception as e:
+        print(f"✗ nvidia-smi 错误: {e}")
+    
+    # 检查NVML
+    try:
+        import pynvml
+        pynvml.nvmlInit()
+        driver_version = pynvml.nvmlSystemGetDriverVersion()
+        print(f"✓ NVML初始化成功,驱动版本: {driver_version}")
+        
+        device_count = pynvml.nvmlDeviceGetCount()
+        print(f"✓ 检测到 {device_count} 个GPU设备")
+        
+        for i in range(device_count):
+            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+            name = pynvml.nvmlDeviceGetName(handle)
+            print(f"  GPU {i}: {name}")
+            
+    except Exception as e:
+        print(f"✗ NVML错误: {e}")
+    
+    # 检查PaddlePaddle
+    try:
+        import paddle
+        print(f"✓ PaddlePaddle版本: {paddle.__version__}")
+        print(f"✓ CUDA编译支持: {paddle.device.is_compiled_with_cuda()}")
+        if paddle.device.is_compiled_with_cuda():
+            gpu_count = paddle.device.cuda.device_count()
+            print(f"✓ PaddlePaddle检测到 {gpu_count} 个GPU")
+    except Exception as e:
+        print(f"✗ PaddlePaddle错误: {e}")
+
+if __name__ == "__main__":
+    check_nvidia_environment()

+ 24 - 0
zhch/utils/cuda_utils.py

@@ -0,0 +1,24 @@
+import paddle
+from typing import List
+def detect_available_gpus() -> List[int]:
+    """检测可用的GPU"""
+    try:
+        gpu_count = paddle.device.cuda.device_count()
+        available_gpus = list(range(gpu_count))
+        print(f"检测到 {gpu_count} 个可用GPU: {available_gpus}")
+        return available_gpus
+    except Exception as e:
+        print(f"GPU检测失败: {e}")
+        return []
+
+def monitor_gpu_memory(gpu_ids: List[int] = [0, 1]):
+    """监控GPU内存使用情况"""
+    try:
+        for gpu_id in gpu_ids:
+            paddle.device.set_device(f"gpu:{gpu_id}")
+            total = paddle.device.cuda.get_device_properties(gpu_id).total_memory / 1024**3
+            allocated = paddle.device.cuda.memory_allocated() / 1024**3
+            reserved = paddle.device.cuda.memory_reserved() / 1024**3
+            print(f"GPU {gpu_id} - 显存: {total:.2f}GB, 已分配: {allocated:.2f}GB, 已预留: {reserved:.2f}GB")
+    except Exception as e:
+        print(f"GPU内存监控失败: {e}")

+ 133 - 0
zhch/utils/file_utils.py

@@ -0,0 +1,133 @@
+import tempfile
+from pathlib import Path
+from typing import List
+
+def split_files(file_list: List[str], num_splits: int) -> List[List[str]]:
+    """
+    将文件列表分割成指定数量的子列表
+    
+    Args:
+        file_list: 文件路径列表
+        num_splits: 分割数量
+        
+    Returns:
+        分割后的文件列表
+    """
+    if num_splits <= 0:
+        return [file_list]
+    
+    chunk_size = len(file_list) // num_splits
+    remainder = len(file_list) % num_splits
+    
+    chunks = []
+    start = 0
+    
+    for i in range(num_splits):
+        # 前remainder个chunk多分配一个文件
+        current_chunk_size = chunk_size + (1 if i < remainder else 0)
+        if current_chunk_size > 0:
+            chunks.append(file_list[start:start + current_chunk_size])
+            start += current_chunk_size
+    
+    return [chunk for chunk in chunks if chunk]  # 过滤空列表
+
+def create_temp_file_list(file_chunk: List[str]) -> str:
+    """
+    创建临时文件列表文件
+    
+    Args:
+        file_chunk: 文件路径列表
+        
+    Returns:
+        临时文件路径
+    """
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
+        for file_path in file_chunk:
+            f.write(f"{file_path}\n")
+        return f.name
+
+def get_image_files_from_dir(input_dir: Path, max_files: int = None) -> List[str]:
+    """
+    从目录获取图像文件列表
+    
+    Args:
+        input_dir: 输入目录
+        max_files: 最大文件数量限制
+        
+    Returns:
+        图像文件路径列表
+    """
+    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+    image_files = []
+    
+    for ext in image_extensions:
+        image_files.extend(list(input_dir.glob(f"*{ext}")))
+        image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
+    
+    # 去重并排序
+    image_files = sorted(list(set(str(f) for f in image_files)))
+    
+    # 限制文件数量
+    if max_files:
+        image_files = image_files[:max_files]
+    
+    return image_files
+
+def get_image_files_from_list(file_list_path: str) -> List[str]:
+    """
+    从文件列表获取图像文件列表
+    
+    Args:
+        file_list_path: 文件列表路径
+        
+    Returns:
+        图像文件路径列表
+    """
+    print(f"📄 Reading file list from: {file_list_path}")
+    
+    with open(file_list_path, 'r', encoding='utf-8') as f:
+        image_files = [line.strip() for line in f if line.strip()]
+    
+    # 验证文件存在性
+    valid_files = []
+    missing_files = []
+    
+    for file_path in image_files:
+        if Path(file_path).exists():
+            valid_files.append(file_path)
+        else:
+            missing_files.append(file_path)
+    
+    if missing_files:
+        print(f"⚠️ Warning: {len(missing_files)} files not found:")
+        for missing_file in missing_files[:5]:  # 只显示前5个
+            print(f"  - {missing_file}")
+        if len(missing_files) > 5:
+            print(f"  ... and {len(missing_files) - 5} more")
+    
+    print(f"✅ Found {len(valid_files)} valid files out of {len(image_files)} in list")
+    return valid_files
+
+def get_image_files_from_csv(csv_file: str, status_filter: str = "fail") -> List[str]:
+    """
+    从CSV文件获取图像文件列表
+
+    Args:
+        csv_file: CSV文件路径
+        status_filter: 状态过滤器
+
+    Returns:
+        图像文件路径列表
+    """
+    print(f"📄 Reading image files from CSV: {csv_file}")
+
+	# 读取CSV文件, 表头:image_path,status
+    image_files = []
+    with open(csv_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            # 需要去掉表头, 按“,”分割,读取文件名,状态
+            image_file, status = line.strip().split(",")
+            if status.lower() == status_filter.lower():
+                image_files.append(image_file)
+
+    return image_files