ソースを参照

feat(zhch): 优化图像处理功能,添加警告抑制和进度统计信息

zhch158_admin 3 ヶ月 前
コミット
727ae8d328
1 ファイル変更47 行追加18 行削除
  1. 47 18
      zhch/ppstructurev3_single_process.py

+ 47 - 18
zhch/ppstructurev3_single_process.py

@@ -4,10 +4,17 @@ import os
 import traceback
 import argparse
 import sys
+import warnings
 from pathlib import Path
 from typing import List, Dict, Any
 import cv2
 import numpy as np
+
+# 抑制特定警告
+warnings.filterwarnings("ignore", message="To copy construct from a tensor")
+warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
+warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
+
 from paddlex import create_pipeline
 from paddlex.utils.device import constr_device, parse_device
 from tqdm import tqdm
@@ -41,6 +48,9 @@ def process_images_single_process(image_paths: List[str],
     print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
     
     try:
+        # 设置环境变量以减少警告
+        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
+        
         # 初始化pipeline
         pipeline = create_pipeline(pipeline_name, device=device)
         print(f"Pipeline initialized successfully on {device}")
@@ -55,8 +65,10 @@ def process_images_single_process(image_paths: List[str],
     
     print(f"Processing {total_images} images with batch size {batch_size}")
     
-    # 使用tqdm显示进度
-    with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
+    # 使用tqdm显示进度,添加更多统计信息
+    with tqdm(total=total_images, desc="Processing images", unit="img", 
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        
         # 按批次处理图像
         for i in range(0, total_images, batch_size):
             batch = image_paths[i:i + batch_size]
@@ -120,11 +132,15 @@ def process_images_single_process(image_paths: List[str],
                 
                 # 更新进度条
                 success_count = sum(1 for r in batch_results if r.get('success', False))
+                total_success = sum(1 for r in all_results if r.get('success', False))
+                avg_time = batch_processing_time / len(batch)
+                
                 pbar.update(len(batch))
                 pbar.set_postfix({
                     'batch_time': f"{batch_processing_time:.2f}s",
-                    'batch_success': f"{success_count}/{len(batch)}",
-                    'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{len(all_results)}"
+                    'avg_time': f"{avg_time:.2f}s/img",
+                    'success': f"{total_success}/{len(all_results)}",
+                    'rate': f"{total_success/len(all_results)*100:.1f}%"
                 })
                 
             except Exception as e:
@@ -185,7 +201,8 @@ def main():
             print(f"No image files found in {input_dir}")
             return 1
         
-        image_files = [str(f) for f in image_files]
+        # 去重并排序
+        image_files = sorted(list(set(str(f) for f in image_files)))
         print(f"Found {len(image_files)} image files")
         
         if args.test_mode:
@@ -205,6 +222,13 @@ def main():
                     if device_id >= gpu_count:
                         print(f"GPU {device_id} not available (only {gpu_count} GPUs), falling back to GPU 0")
                         args.device = "gpu:0"
+                    
+                    # 显示GPU信息
+                    if args.verbose:
+                        for i in range(gpu_count):
+                            props = paddle.device.cuda.get_device_properties(i)
+                            print(f"GPU {i}: {props.name} - {props.total_memory // 1024**3}GB")
+                        
             except Exception as e:
                 print(f"Error checking GPU availability: {e}, falling back to CPU")
                 args.device = "cpu"
@@ -227,16 +251,19 @@ def main():
         success_count = sum(1 for r in results if r.get('success', False))
         error_count = len(results) - success_count
         
-        print(f"\n" + "="*50)
-        print(f"Processing completed!")
-        print(f"Total files: {len(image_files)}")
-        print(f"Successful: {success_count}")
-        print(f"Failed: {error_count}")
+        print(f"\n" + "="*60)
+        print(f"✅ Processing completed!")
+        print(f"📊 Statistics:")
+        print(f"  Total files: {len(image_files)}")
+        print(f"  Successful: {success_count}")
+        print(f"  Failed: {error_count}")
         if len(image_files) > 0:
-            print(f"Success rate: {success_count / len(image_files) * 100:.2f}%")
-        print(f"Total time: {total_time:.2f} seconds")
+            print(f"  Success rate: {success_count / len(image_files) * 100:.2f}%")
+        print(f"⏱️ Performance:")
+        print(f"  Total time: {total_time:.2f} seconds")
         if total_time > 0:
-            print(f"Throughput: {len(image_files) / total_time:.2f} images/second")
+            print(f"  Throughput: {len(image_files) / total_time:.2f} images/second")
+            print(f"  Avg time per image: {total_time / len(image_files):.2f} seconds")
         
         # 保存结果统计
         stats = {
@@ -246,9 +273,11 @@ def main():
             "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
             "total_time": total_time,
             "throughput": len(image_files) / total_time if total_time > 0 else 0,
+            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
             "batch_size": args.batch_size,
             "device": args.device,
-            "pipeline": args.pipeline
+            "pipeline": args.pipeline,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
         }
         
         # 保存最终结果
@@ -261,23 +290,23 @@ def main():
         with open(output_file, 'w', encoding='utf-8') as f:
             json.dump(final_results, f, ensure_ascii=False, indent=2)
         
-        print(f"Results saved to: {output_file}")
+        print(f"💾 Results saved to: {output_file}")
         
         return 0
         
     except Exception as e:
-        print(f"Processing failed: {e}", file=sys.stderr)
+        print(f"Processing failed: {e}", file=sys.stderr)
         traceback.print_exc()
         return 1
 
 
 if __name__ == "__main__":
     print(f"🚀 启动单进程OCR程序...")
-    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
+    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
     
     if len(sys.argv) == 1:
         # 如果没有命令行参数,使用默认配置运行
-        print("No command line arguments provided. Running with default configuration...")
+        print("ℹ️  No command line arguments provided. Running with default configuration...")
         
         # 默认配置
         default_config = {