Browse Source

feat(zhch): 优化图像处理功能,添加警告抑制和进度统计信息

zhch158_admin 3 months ago
parent
commit
727ae8d328
1 changed files with 47 additions and 18 deletions
  1. 47 18
      zhch/ppstructurev3_single_process.py

+ 47 - 18
zhch/ppstructurev3_single_process.py

@@ -4,10 +4,17 @@ import os
 import traceback
 import traceback
 import argparse
 import argparse
 import sys
 import sys
+import warnings
 from pathlib import Path
 from pathlib import Path
 from typing import List, Dict, Any
 from typing import List, Dict, Any
 import cv2
 import cv2
 import numpy as np
 import numpy as np
+
+# 抑制特定警告
+warnings.filterwarnings("ignore", message="To copy construct from a tensor")
+warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
+warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
+
 from paddlex import create_pipeline
 from paddlex import create_pipeline
 from paddlex.utils.device import constr_device, parse_device
 from paddlex.utils.device import constr_device, parse_device
 from tqdm import tqdm
 from tqdm import tqdm
@@ -41,6 +48,9 @@ def process_images_single_process(image_paths: List[str],
     print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
     print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
     
     
     try:
     try:
+        # 设置环境变量以减少警告
+        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
+        
         # 初始化pipeline
         # 初始化pipeline
         pipeline = create_pipeline(pipeline_name, device=device)
         pipeline = create_pipeline(pipeline_name, device=device)
         print(f"Pipeline initialized successfully on {device}")
         print(f"Pipeline initialized successfully on {device}")
@@ -55,8 +65,10 @@ def process_images_single_process(image_paths: List[str],
     
     
     print(f"Processing {total_images} images with batch size {batch_size}")
     print(f"Processing {total_images} images with batch size {batch_size}")
     
     
-    # 使用tqdm显示进度
-    with tqdm(total=total_images, desc="Processing images", unit="img") as pbar:
+    # 使用tqdm显示进度,添加更多统计信息
+    with tqdm(total=total_images, desc="Processing images", unit="img", 
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        
         # 按批次处理图像
         # 按批次处理图像
         for i in range(0, total_images, batch_size):
         for i in range(0, total_images, batch_size):
             batch = image_paths[i:i + batch_size]
             batch = image_paths[i:i + batch_size]
@@ -120,11 +132,15 @@ def process_images_single_process(image_paths: List[str],
                 
                 
                 # 更新进度条
                 # 更新进度条
                 success_count = sum(1 for r in batch_results if r.get('success', False))
                 success_count = sum(1 for r in batch_results if r.get('success', False))
+                total_success = sum(1 for r in all_results if r.get('success', False))
+                avg_time = batch_processing_time / len(batch)
+                
                 pbar.update(len(batch))
                 pbar.update(len(batch))
                 pbar.set_postfix({
                 pbar.set_postfix({
                     'batch_time': f"{batch_processing_time:.2f}s",
                     'batch_time': f"{batch_processing_time:.2f}s",
-                    'batch_success': f"{success_count}/{len(batch)}",
-                    'total_success': f"{sum(1 for r in all_results if r.get('success', False))}/{len(all_results)}"
+                    'avg_time': f"{avg_time:.2f}s/img",
+                    'success': f"{total_success}/{len(all_results)}",
+                    'rate': f"{total_success/len(all_results)*100:.1f}%"
                 })
                 })
                 
                 
             except Exception as e:
             except Exception as e:
@@ -185,7 +201,8 @@ def main():
             print(f"No image files found in {input_dir}")
             print(f"No image files found in {input_dir}")
             return 1
             return 1
         
         
-        image_files = [str(f) for f in image_files]
+        # 去重并排序
+        image_files = sorted(list(set(str(f) for f in image_files)))
         print(f"Found {len(image_files)} image files")
         print(f"Found {len(image_files)} image files")
         
         
         if args.test_mode:
         if args.test_mode:
@@ -205,6 +222,13 @@ def main():
                     if device_id >= gpu_count:
                     if device_id >= gpu_count:
                         print(f"GPU {device_id} not available (only {gpu_count} GPUs), falling back to GPU 0")
                         print(f"GPU {device_id} not available (only {gpu_count} GPUs), falling back to GPU 0")
                         args.device = "gpu:0"
                         args.device = "gpu:0"
+                    
+                    # 显示GPU信息
+                    if args.verbose:
+                        for i in range(gpu_count):
+                            props = paddle.device.cuda.get_device_properties(i)
+                            print(f"GPU {i}: {props.name} - {props.total_memory // 1024**3}GB")
+                        
             except Exception as e:
             except Exception as e:
                 print(f"Error checking GPU availability: {e}, falling back to CPU")
                 print(f"Error checking GPU availability: {e}, falling back to CPU")
                 args.device = "cpu"
                 args.device = "cpu"
@@ -227,16 +251,19 @@ def main():
         success_count = sum(1 for r in results if r.get('success', False))
         success_count = sum(1 for r in results if r.get('success', False))
         error_count = len(results) - success_count
         error_count = len(results) - success_count
         
         
-        print(f"\n" + "="*50)
-        print(f"Processing completed!")
-        print(f"Total files: {len(image_files)}")
-        print(f"Successful: {success_count}")
-        print(f"Failed: {error_count}")
+        print(f"\n" + "="*60)
+        print(f"✅ Processing completed!")
+        print(f"📊 Statistics:")
+        print(f"  Total files: {len(image_files)}")
+        print(f"  Successful: {success_count}")
+        print(f"  Failed: {error_count}")
         if len(image_files) > 0:
         if len(image_files) > 0:
-            print(f"Success rate: {success_count / len(image_files) * 100:.2f}%")
-        print(f"Total time: {total_time:.2f} seconds")
+            print(f"  Success rate: {success_count / len(image_files) * 100:.2f}%")
+        print(f"⏱️ Performance:")
+        print(f"  Total time: {total_time:.2f} seconds")
         if total_time > 0:
         if total_time > 0:
-            print(f"Throughput: {len(image_files) / total_time:.2f} images/second")
+            print(f"  Throughput: {len(image_files) / total_time:.2f} images/second")
+            print(f"  Avg time per image: {total_time / len(image_files):.2f} seconds")
         
         
         # 保存结果统计
         # 保存结果统计
         stats = {
         stats = {
@@ -246,9 +273,11 @@ def main():
             "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
             "success_rate": success_count / len(image_files) if len(image_files) > 0 else 0,
             "total_time": total_time,
             "total_time": total_time,
             "throughput": len(image_files) / total_time if total_time > 0 else 0,
             "throughput": len(image_files) / total_time if total_time > 0 else 0,
+            "avg_time_per_image": total_time / len(image_files) if len(image_files) > 0 else 0,
             "batch_size": args.batch_size,
             "batch_size": args.batch_size,
             "device": args.device,
             "device": args.device,
-            "pipeline": args.pipeline
+            "pipeline": args.pipeline,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
         }
         }
         
         
         # 保存最终结果
         # 保存最终结果
@@ -261,23 +290,23 @@ def main():
         with open(output_file, 'w', encoding='utf-8') as f:
         with open(output_file, 'w', encoding='utf-8') as f:
             json.dump(final_results, f, ensure_ascii=False, indent=2)
             json.dump(final_results, f, ensure_ascii=False, indent=2)
         
         
-        print(f"Results saved to: {output_file}")
+        print(f"💾 Results saved to: {output_file}")
         
         
         return 0
         return 0
         
         
     except Exception as e:
     except Exception as e:
-        print(f"Processing failed: {e}", file=sys.stderr)
+        print(f"Processing failed: {e}", file=sys.stderr)
         traceback.print_exc()
         traceback.print_exc()
         return 1
         return 1
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     print(f"🚀 启动单进程OCR程序...")
     print(f"🚀 启动单进程OCR程序...")
-    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
+    print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
     
     
     if len(sys.argv) == 1:
     if len(sys.argv) == 1:
         # 如果没有命令行参数,使用默认配置运行
         # 如果没有命令行参数,使用默认配置运行
-        print("No command line arguments provided. Running with default configuration...")
+        print("ℹ️  No command line arguments provided. Running with default configuration...")
         
         
         # 默认配置
         # 默认配置
         default_config = {
         default_config = {