Эх сурвалжийг харах

refactor: 移除未使用的函数和导入,优化代码结构

zhch158_admin 3 сар өмнө
parent
commit
7aaf06c23e

+ 11 - 129
zhch/ppstructurev3_scheduler.py

@@ -1,3 +1,7 @@
+"""
+多GPU多进程推理始终有问题,多个进程启动后,paddle底层报错
+目前无法定位原因
+"""
 import json
 import time
 import os
@@ -12,135 +16,14 @@ import threading
 from queue import Queue
 from tqdm import tqdm
 
-def split_files(file_list: List[str], num_splits: int) -> List[List[str]]:
-    """
-    将文件列表分割成指定数量的子列表
-    
-    Args:
-        file_list: 文件路径列表
-        num_splits: 分割数量
-        
-    Returns:
-        分割后的文件列表
-    """
-    if num_splits <= 0:
-        return [file_list]
-    
-    chunk_size = len(file_list) // num_splits
-    remainder = len(file_list) % num_splits
-    
-    chunks = []
-    start = 0
-    
-    for i in range(num_splits):
-        # 前remainder个chunk多分配一个文件
-        current_chunk_size = chunk_size + (1 if i < remainder else 0)
-        if current_chunk_size > 0:
-            chunks.append(file_list[start:start + current_chunk_size])
-            start += current_chunk_size
-    
-    return [chunk for chunk in chunks if chunk]  # 过滤空列表
-
-def create_temp_file_list(file_chunk: List[str]) -> str:
-    """
-    创建临时文件列表文件
-    
-    Args:
-        file_chunk: 文件路径列表
-        
-    Returns:
-        临时文件路径
-    """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
-        for file_path in file_chunk:
-            f.write(f"{file_path}\n")
-        return f.name
-
-def get_image_files_from_dir(input_dir: Path, max_files: int = None) -> List[str]:
-    """
-    从目录获取图像文件列表
-    
-    Args:
-        input_dir: 输入目录
-        max_files: 最大文件数量限制
-        
-    Returns:
-        图像文件路径列表
-    """
-    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
-    image_files = []
-    
-    for ext in image_extensions:
-        image_files.extend(list(input_dir.glob(f"*{ext}")))
-        image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
-    
-    # 去重并排序
-    image_files = sorted(list(set(str(f) for f in image_files)))
-    
-    # 限制文件数量
-    if max_files:
-        image_files = image_files[:max_files]
-    
-    return image_files
+from utils import (
+    get_image_files_from_dir,
+    get_image_files_from_list,
+    get_image_files_from_csv,
+    split_files,
+    create_temp_file_list
+)
 
-def get_image_files_from_list(file_list_path: str) -> List[str]:
-    """
-    从文件列表获取图像文件列表
-    
-    Args:
-        file_list_path: 文件列表路径
-        
-    Returns:
-        图像文件路径列表
-    """
-    print(f"📄 Reading file list from: {file_list_path}")
-    
-    with open(file_list_path, 'r', encoding='utf-8') as f:
-        image_files = [line.strip() for line in f if line.strip()]
-    
-    # 验证文件存在性
-    valid_files = []
-    missing_files = []
-    
-    for file_path in image_files:
-        if Path(file_path).exists():
-            valid_files.append(file_path)
-        else:
-            missing_files.append(file_path)
-    
-    if missing_files:
-        print(f"⚠️ Warning: {len(missing_files)} files not found:")
-        for missing_file in missing_files[:5]:  # 只显示前5个
-            print(f"  - {missing_file}")
-        if len(missing_files) > 5:
-            print(f"  ... and {len(missing_files) - 5} more")
-    
-    print(f"✅ Found {len(valid_files)} valid files out of {len(image_files)} in list")
-    return valid_files
-
-def get_image_files_from_csv(csv_file: str, status_filter: str = "fail") -> List[str]:
-    """
-    从CSV文件获取图像文件列表
-
-    Args:
-        csv_file: CSV文件路径
-        status_filter: 状态过滤器
-
-    Returns:
-        图像文件路径列表
-    """
-    print(f"📄 Reading image files from CSV: {csv_file}")
-
-	# 读取CSV文件, 表头:image_path,status
-    image_files = []
-    with open(csv_file, 'r', encoding='utf-8') as f:
-        for line in f:
-            # 需要去掉表头, 按“,”分割,读取文件名,状态
-            image_file, status = line.strip().split(",")
-            if status.lower() == status_filter.lower():
-                image_files.append(image_file)
-
-    return image_files
 
 def collect_pid_files(pid_output_file: str) -> List[Tuple[str, str]]:
     """
@@ -351,7 +234,6 @@ def main():
     parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Multi-Process Scheduler")
     
     # 输入输出参数
-    # 输入输出参数
     input_group = parser.add_mutually_exclusive_group(required=True)
     input_group.add_argument("--input_dir", type=str, help="Input directory")
     input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")