
feat: add number normalization, improve result-saving logic, and support table consistency checks

zhch158_admin 1 month ago
commit 93cfe7bf65
1 changed file with 159 additions and 19 deletions
+ 159 - 19  zhch/ppstructurev3_single_process.py

@@ -28,7 +28,9 @@ from utils import (
     get_image_files_from_list,
     get_image_files_from_csv,
     collect_pid_files,
-    load_images_from_pdf
+    load_images_from_pdf,
+    normalize_financial_numbers,
+    normalize_markdown_table
 )
 
 def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
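The two helpers added to the import list come from `utils`, which is not part of this diff. A minimal sketch of what they plausibly look like, assuming full-width-to-ASCII digit conversion and line-wise table normalization (the names match the imports; the bodies are illustrative only):

```python
# Hypothetical sketch of the utils helpers imported above; the real
# implementations are not shown in this commit.
FULLWIDTH_TO_ASCII = str.maketrans("０１２３４５６７８９，．", "0123456789,.")

def normalize_financial_numbers(text: str) -> str:
    """Convert full-width digits and punctuation to ASCII, e.g. '１，２３４' -> '1,234'."""
    return text.translate(FULLWIDTH_TO_ASCII)

def normalize_markdown_table(table: str) -> str:
    """Apply number normalization line by line, preserving the table markup."""
    return "\n".join(normalize_financial_numbers(line) for line in table.splitlines())
```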
@@ -144,12 +146,138 @@ def get_input_files(args) -> List[str]:
     
     return input_files
 
+def normalize_pipeline_result(result: Dict[str, Any], normalize_numbers: bool = True) -> Dict[str, Any]:
+    """
+    对pipeline结果进行数字标准化处理
+    
+    Args:
+        result: pipeline返回的结果对象
+        normalize_numbers: 是否启用数字标准化
+        
+    Returns:
+        包含标准化信息的字典
+    """
+    if not normalize_numbers:
+        return {
+            "normalize_numbers": False,
+            "changes_applied": False,
+            "character_changes_count": 0,
+            "parsing_res_tables_count": 0,
+            "table_res_list_count": 0,
+            "table_consistency_fixed": False
+        }
+    
+    changes_count = 0
+    original_data = {}
+    
+    # Back up the original data so it can be restored if normalization fails
+    if 'parsing_res_list' in result:
+        original_data['parsing_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['parsing_res_list']]
+
+    if 'table_res_list' in result:
+        original_data['table_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['table_res_list']]
+
+    try:
+        # 1. Normalize text content in parsing_res_list
+        if 'parsing_res_list' in result:
+            for item in result['parsing_res_list']:
+                if 'block_content' in item and item['block_content']:
+                    original_content = str(item['block_content'])
+                    
+                    # Choose the normalization method based on block_label;
+                    # non-table text goes through plain financial-number normalization
+                    if 'block_label' in item and item['block_label'] == 'table':
+                        normalized_content = normalize_markdown_table(original_content)
+                    else:
+                        normalized_content = normalize_financial_numbers(original_content)
+                    
+                    if original_content != normalized_content:
+                        item['block_content'] = normalized_content
+                        # Approximate count: position-wise diff, truncated to the shorter string
+                        changes_count += sum(1 for o, n in zip(original_content, normalized_content) if o != n)
+        
+        # 2. Normalize HTML tables in table_res_list
+        if 'table_res_list' in result:
+            for table_item in result['table_res_list']:
+                if 'pred_html' in table_item and table_item['pred_html']:
+                    original_html = str(table_item['pred_html'])
+                    normalized_html = normalize_markdown_table(original_html)
+                    
+                    if original_html != normalized_html:
+                        table_item['pred_html'] = normalized_html
+                        changes_count += sum(1 for o, n in zip(original_html, normalized_html) if o != n)
+        
+        # Count tables on both sides
+        parsing_res_tables_count = 0
+        table_res_list_count = 0
+        if 'parsing_res_list' in result:
+            parsing_res_tables_count = sum(1 for item in result['parsing_res_list']
+                                           if 'block_label' in item and item['block_label'] == 'table')
+        if 'table_res_list' in result:
+            table_res_list_count = len(result['table_res_list'])
+        
+        # Check table consistency (only detected and flagged here; an actual
+        # repair would need more involved logic)
+        table_consistency_fixed = False
+        if parsing_res_tables_count != table_res_list_count:
+            print(f"⚠️ Warning: Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
+                  f"but table_res_list has {table_res_list_count} tables.")
+            table_consistency_fixed = True  # flags the detected mismatch; no repair is applied
+
+        return {
+            "normalize_numbers": normalize_numbers,
+            "changes_applied": changes_count > 0,
+            "character_changes_count": changes_count,
+            "parsing_res_tables_count": parsing_res_tables_count,
+            "table_res_list_count": table_res_list_count,
+            "table_consistency_fixed": table_consistency_fixed
+        }
+        
+    except Exception as e:
+        print(f"⚠️ Warning: Error during normalization: {e}")
+        # Roll back any partial changes using the backup taken above
+        for key in original_data:
+            for item, backup in zip(result[key], original_data[key]):
+                for field in ('block_content', 'pred_html'):
+                    if field in backup:
+                        item[field] = backup[field]
+        return {
+            "normalize_numbers": normalize_numbers,
+            "changes_applied": False,
+            "character_changes_count": 0,
+            "normalization_error": str(e)
+        }
+
+def save_normalized_files(result, output_dir: str, filename: str, 
+                         processing_info: Dict[str, Any], normalize_numbers: bool = True):
+    """
+    保存标准化处理后的文件,包括原始版本
+    """
+    output_path = Path(output_dir)
+    
+    # Save the (possibly normalized) result
+    json_output_path = str(output_path / f"{filename}.json")
+    md_output_path = str(output_path / f"{filename}.md")
+    
+    result.save_to_json(json_output_path)
+    result.save_to_markdown(md_output_path)
+    
+    # If normalization changed anything, record the processing info in the JSON
+    if normalize_numbers and processing_info.get('changes_applied', False):
+        try:
+            # Read back the generated JSON and attach the processing info
+            with open(json_output_path, 'r', encoding='utf-8') as f:
+                json_data = json.load(f)
+            
+            json_data['processing_info'] = processing_info
+            
+            # Re-save the JSON with the processing info included
+            with open(json_output_path, 'w', encoding='utf-8') as f:
+                json.dump(json_data, f, ensure_ascii=False, indent=2)
+                
+        except Exception as e:
+            print(f"⚠️ Warning: Could not add processing info to JSON: {e}")
+    
+    return json_output_path, md_output_path
+
 def process_images_unified(image_paths: List[str],
                          pipeline_name: str = "PP-StructureV3",
                          device: str = "gpu:0",
-                         output_dir: str = "./output") -> List[Dict[str, Any]]:
+                         output_dir: str = "./output",
+                         normalize_numbers: bool = True) -> List[Dict[str, Any]]:
     """
-    Unified image-processing function (adapted from ppstructurev3_single_process.py)
+    Unified image-processing function with optional number normalization
     """
    # Create the output directory
     output_path = Path(output_dir)
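For reference, the summary dict returned by `normalize_pipeline_result` (and embedded by `save_normalized_files` under a `processing_info` key in the saved JSON) has this shape; the values below are illustrative:

```python
# Illustrative processing_info payload, as embedded in the saved JSON
processing_info = {
    "normalize_numbers": True,
    "changes_applied": True,
    "character_changes_count": 42,     # approximate position-wise count
    "parsing_res_tables_count": 3,
    "table_res_list_count": 3,
    "table_consistency_fixed": False,  # True when a table-count mismatch was flagged
}
```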
@@ -174,6 +302,7 @@ def process_images_unified(image_paths: List[str],
     total_images = len(image_paths)
     
     print(f"Processing {total_images} images one by one")
+    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
     
    # Show progress with tqdm
     with tqdm(total=total_images, desc="Processing images", unit="img", 
@@ -208,12 +337,17 @@ def process_images_unified(image_paths: List[str],
                         else:
                             output_filename = f"{input_path.stem}"
                         
-                        # Save JSON and Markdown files
-                        json_output_path = str(Path(output_dir, f"{output_filename}.json"))
-                        md_output_path = str(Path(output_dir, f"{output_filename}.md"))
+                        # Apply number normalization
+                        processing_info = normalize_pipeline_result(result, normalize_numbers)
+                        
+                        # Save JSON and Markdown files (with normalization info)
+                        json_output_path, md_output_path = save_normalized_files(
+                            result, output_dir, output_filename, processing_info, normalize_numbers
+                        )
                         
-                        result.save_to_json(json_output_path)
-                        result.save_to_markdown(md_output_path)
+                        # Report detected table-consistency issues
+                        if processing_info.get('table_consistency_fixed', False):
+                            print(f"🔧 Table count mismatch detected in: {input_path.name}")
                         
                        # Record the processing result
                         all_results.append({
@@ -223,7 +357,8 @@ def process_images_unified(image_paths: List[str],
                             "device": device,
                             "output_json": json_output_path,
                             "output_md": md_output_path,
-                            "is_pdf_page": "_page_" in input_path.name  # mark whether this is a PDF page
+                            "is_pdf_page": "_page_" in input_path.name,  # mark whether this is a PDF page
+                            "processing_info": processing_info
                         })
                         
                     except Exception as e:
@@ -278,11 +413,14 @@ def main():
     parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
     parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
     parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF to image conversion")
+    parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化")
     parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
     parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
 
     args = parser.parse_args()
     
+    normalize_numbers = not args.no_normalize
+    
     try:
        # Collect and preprocess the input files
         print("🔄 Preprocessing input files...")
@@ -304,7 +442,8 @@ def main():
             input_files,
             args.pipeline,
             args.device,
-            args.output_dir
+            args.output_dir,
+            normalize_numbers=normalize_numbers
         )
         total_time = time.time() - start_time
         
@@ -312,6 +451,7 @@ def main():
         success_count = sum(1 for r in results if r.get('success', False))
         error_count = len(results) - success_count
         pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
+        total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)
         
         print(f"\n" + "="*60)
         print(f"✅ Processing completed!")
@@ -323,6 +463,8 @@ def main():
         print(f"  Failed: {error_count}")
         if len(input_files) > 0:
             print(f"  Success rate: {success_count / len(input_files) * 100:.2f}%")
+        if normalize_numbers:
+            print(f"  Total normalized characters: {total_changes}")
         print(f"⏱️ Performance:")
         print(f"  Total time: {total_time:.2f} seconds")
         if total_time > 0:
@@ -343,6 +485,8 @@ def main():
             "device": args.device,
             "pipeline": args.pipeline,
             "pdf_dpi": args.pdf_dpi,
+            "normalize_numbers": normalize_numbers,
+            "total_character_changes": total_changes,
             "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
         }
         
@@ -359,7 +503,7 @@ def main():
         
         print(f"💾 Results saved to: {output_file}")
 
-		# If no collect_results path is given, use a default filename in the same directory as output_dir
+        # If no collect_results path is given, use a default filename in the same directory as output_dir
         if not args.collect_results:
             output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
         else:
@@ -387,7 +531,7 @@ if __name__ == "__main__":
         # 如果没有命令行参数,使用默认配置运行
         print("ℹ️  No command line arguments provided. Running with default configuration...")
         
-        # Default configuration (batch_size removed)
+        # Default configuration
         default_config = {
             "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
             "output_dir": "./OmniDocBench_PPStructureV3_Results",
@@ -396,18 +540,14 @@ if __name__ == "__main__":
             "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
         }
         
-        # default_config = {
-        #     "input_csv": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
-        #     "output_dir": "./OmniDocBench_PPStructureV3_Results",
-        #     "pipeline": "./my_config/PP-StructureV3.yaml",
-        #     "device": "gpu:0",
-        #     "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
-        # }
        # Build the argument list
         sys.argv = [sys.argv[0]]
         for key, value in default_config.items():
             sys.argv.extend([f"--{key}", str(value)])
         
+        # Optionally disable normalization:
+        # sys.argv.append("--no-normalize")
+        
        # Test mode
         # sys.argv.append("--test_mode")
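To verify a run, the embedded metadata can be read back from any saved result. A minimal sketch; the file path is hypothetical, substitute a JSON produced by an actual run:

```python
# Inspect the normalization metadata embedded in one saved result.
import json

with open("./OmniDocBench_PPStructureV3_Results/sample_page_001.json", encoding="utf-8") as f:
    data = json.load(f)

info = data.get("processing_info", {})
print(f"characters normalized: {info.get('character_changes_count', 0)}")
print(f"table counts consistent: {not info.get('table_consistency_fixed', False)}")
```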