
feat: refactor image-processing logic; streamline PDF conversion, result saving, and Markdown content handling

zhch158_admin 1 month ago
parent
commit
5924d69342

+ 43 - 261
zhch/ppstructurev3_single_process.py

@@ -24,252 +24,15 @@ from dotenv import load_dotenv
 load_dotenv(override=True)
 
 from utils import (
-    get_image_files_from_dir,
-    get_image_files_from_list,
-    get_image_files_from_csv,
     collect_pid_files,
-    load_images_from_pdf,
-    normalize_financial_numbers,
-    normalize_markdown_table
 )
 
-def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
-    """
-    Convert a PDF into per-page image files.
-    
-    Args:
-        pdf_file: Path to the PDF file
-        output_dir: Output directory
-        dpi: Image resolution
-        
-    Returns:
-        List of paths to the generated image files
-    """
-    pdf_path = Path(pdf_file)
-    if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
-        print(f"❌ Invalid PDF file: {pdf_path}")
-        return []
-
-    # If no output directory is given, use a directory named after the PDF
-    if output_dir is None:
-        output_path = pdf_path.parent / f"{pdf_path.stem}"
-    else:
-        output_path = Path(output_dir) / f"{pdf_path.stem}"
-    output_path = output_path.resolve()
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    try:
-        # Load the PDF pages as images using the helper from doc_utils
-        images = load_images_from_pdf(str(pdf_path), dpi=dpi)
-        
-        image_paths = []
-        for i, image in enumerate(images):
-            # Build the image file name
-            image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
-            image_path = output_path / image_filename
-
-            # Save the image
-            image.save(str(image_path))
-            image_paths.append(str(image_path))
-            
-        print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
-        return image_paths
-        
-    except Exception as e:
-        print(f"❌ Error converting PDF {pdf_path}: {e}")
-        traceback.print_exc()
-        return []
-
-def get_input_files(args) -> List[str]:
-    """
-    Build the list of input files, handling PDF and image files uniformly.
-    
-    Args:
-        args: Parsed command-line arguments
-        
-    Returns:
-        List of image file paths ready for processing
-    """
-    input_files = []
-    
-    # Collect the raw input files
-    if args.input_csv:
-        raw_files = get_image_files_from_csv(args.input_csv, "fail")
-    elif args.input_file_list:
-        raw_files = get_image_files_from_list(args.input_file_list)
-    elif args.input_file:
-        raw_files = [Path(args.input_file).resolve()]
-    else:
-        input_dir = Path(args.input_dir).resolve()
-        if not input_dir.exists():
-            print(f"❌ Input directory does not exist: {input_dir}")
-            return []
-        
-        # Gather all supported files (images and PDFs)
-        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
-        pdf_extensions = ['.pdf']
-        
-        raw_files = []
-        for ext in image_extensions + pdf_extensions:
-            raw_files.extend(list(input_dir.glob(f"*{ext}")))
-            raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
-        
-        raw_files = [str(f) for f in raw_files]
-    
-    # Handle PDF and image files separately
-    pdf_count = 0
-    image_count = 0
-    
-    for file_path in raw_files:
-        file_path = Path(file_path)
-        
-        if file_path.suffix.lower() == '.pdf':
-            # Convert the PDF to images
-            print(f"📄 Processing PDF: {file_path.name}")
-            pdf_images = convert_pdf_to_images(
-                str(file_path), 
-                args.output_dir,
-                dpi=args.pdf_dpi
-            )
-            input_files.extend(pdf_images)
-            pdf_count += 1
-        else:
-            # Add image files directly
-            if file_path.exists():
-                input_files.append(str(file_path))
-                image_count += 1
-    
-    print(f"📊 Input summary:")
-    print(f"  PDF files processed: {pdf_count}")
-    print(f"  Image files found: {image_count}")
-    print(f"  Total image files to process: {len(input_files)}")
-    
-    return input_files
-
-def normalize_pipeline_result(result: Dict[str, Any], normalize_numbers: bool = True) -> Dict[str, Any]:
-    """
-    Apply number normalization to a pipeline result.
-    
-    Args:
-        result: Result object returned by the pipeline
-        normalize_numbers: Whether to enable number normalization
-        
-    Returns:
-        Dict describing the normalization that was applied
-    """
-    if not normalize_numbers:
-        return {
-            "normalize_numbers": False,
-            "changes_applied": False,
-            "character_changes_count": 0,
-            "parsing_res_tables_count": 0,
-            "table_res_list_count": 0,
-            "table_consistency_fixed": False
-        }
-    
-    changes_count = 0
-    original_data = {}
-    
-    # Back up the original data
-    if 'parsing_res_list' in result:
-        original_data['parsing_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['parsing_res_list']]
-
-    if 'table_res_list' in result:
-        original_data['table_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['table_res_list']]
-
-    try:
-        # 1. Normalize the text content in parsing_res_list
-        if 'parsing_res_list' in result:
-            for item in result['parsing_res_list']:
-                if 'block_content' in item and item['block_content']:
-                    original_content = str(item['block_content'])
-                    normalized_content = original_content
-                    
-                    # Choose the normalization method based on block_label
-                    if 'block_label' in item and item['block_label'] == 'table':
-                        normalized_content = normalize_markdown_table(original_content)
-                    
-                    if original_content != normalized_content:
-                        item['block_content'] = normalized_content
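-                        # NOTE: zip stops at the shorter string, so this count is approximate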
-                        changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
-        
-        # 2. Normalize the HTML tables in table_res_list
-        if 'table_res_list' in result:
-            for table_item in result['table_res_list']:
-                if 'pred_html' in table_item and table_item['pred_html']:
-                    original_html = str(table_item['pred_html'])
-                    normalized_html = normalize_markdown_table(original_html)
-                    
-                    if original_html != normalized_html:
-                        table_item['pred_html'] = normalized_html
-                        changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
-        
-        # Count the tables
-        parsing_res_tables_count = 0
-        table_res_list_count = 0
-        if 'parsing_res_list' in result:
-            parsing_res_tables_count = len([item for item in result['parsing_res_list'] 
-                                          if 'block_label' in item and item['block_label'] == 'table'])
-        if 'table_res_list' in result:
-            table_res_list_count = len(result['table_res_list'])
-        
-        # Check whether table consistency needs fixing (only counted here; an actual fix may require more complex logic)
-        table_consistency_fixed = False
-        if parsing_res_tables_count != table_res_list_count:
-            warnings.warn(f"⚠️ Warning: Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
-                          f"but table_res_list has {table_res_list_count} tables.")
-            table_consistency_fixed = True
-            # Actual repair logic could be added here, e.g. adding or removing table entries as needed,
-            # but lacking concrete rules we only count and warn for now
-        return {
-            "normalize_numbers": normalize_numbers,
-            "changes_applied": changes_count > 0,
-            "character_changes_count": changes_count,
-            "parsing_res_tables_count": parsing_res_tables_count,
-            "table_res_list_count": table_res_list_count,
-            "table_consistency_fixed": table_consistency_fixed
-        }
-        
-    except Exception as e:
-        print(f"⚠️ Warning: Error during normalization: {e}")
-        return {
-            "normalize_numbers": normalize_numbers,
-            "changes_applied": False,
-            "character_changes_count": 0,
-            "normalization_error": str(e)
-        }
-
-def save_normalized_files(result, output_dir: str, filename: str, 
-                         processing_info: Dict[str, Any], normalize_numbers: bool = True):
-    """
-    Save the normalized output files (JSON and Markdown), embedding processing info when changes were applied
-    """
-    output_path = Path(output_dir)
-    
-    # Save the normalized version
-    json_output_path = str(output_path / f"{filename}.json")
-    md_output_path = str(output_path / f"{filename}.md")
-    
-    result.save_to_json(json_output_path)
-    result.save_to_markdown(md_output_path)
-    
-    # If normalization changed anything, add the processing info to the JSON
-    if normalize_numbers and processing_info.get('changes_applied', False):
-        try:
-            # Read the generated JSON file and attach the processing info
-            with open(json_output_path, 'r', encoding='utf-8') as f:
-                json_data = json.load(f)
-            
-            json_data['processing_info'] = processing_info
-            
-            # Re-save the JSON with the processing info included
-            with open(json_output_path, 'w', encoding='utf-8') as f:
-                json.dump(json_data, f, ensure_ascii=False, indent=2)
-                
-        except Exception as e:
-            print(f"⚠️ Warning: Could not add processing info to JSON: {e}")
-    
-    return json_output_path, md_output_path
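+# Shared helpers extracted in this refactor: input collection (including PDF
+# conversion), pruned-result JSON conversion, and output image/Markdown saving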
+from ppstructurev3_utils import (
+    get_input_files,
+    convert_pruned_result_to_json,
+    save_output_images,
+    save_markdown_content
+)
 
 def process_images_unified(image_paths: List[str],
                          pipeline_name: str = "PP-StructureV3",
@@ -327,7 +90,9 @@ def process_images_unified(image_paths: List[str],
                 processing_time = time.time() - start_time
                 
                 # Process the results
-                for result in results:
+                for idx, result in enumerate(results):
+                    if idx > 0:
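+                        # Each input image is expected to produce exactly one result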
+                        raise ValueError("Multiple results found for a single image")
                     try:
                         input_path = Path(result["input_path"])
                         
@@ -337,17 +102,30 @@ def process_images_unified(image_paths: List[str],
                         else:
                             output_filename = f"{input_path.stem}"
                         
-                        # Apply number normalization
-                        processing_info = normalize_pipeline_result(result, normalize_numbers)
-                        
-                        # Save the JSON and Markdown files (with normalization applied)
-                        json_output_path, md_output_path = save_normalized_files(
-                            result, output_dir, output_filename, processing_info, normalize_numbers
+                        # Convert to and save the standard JSON format
+                        json_content = result.json['res']
+                        json_output_path, converted_json = convert_pruned_result_to_json(
+                            json_content, 
+                            str(input_path), 
+                            output_dir,
+                            output_filename,
+                            normalize_numbers=normalize_numbers
+                        )
+
+                        # Save the output images
+                        img_content = result.img
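+                        # result.img is assumed to map output names to rendered images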
+                        saved_images = save_output_images(img_content, str(output_dir), output_filename)
+
+                        # Save the Markdown content
+                        markdown_content = result.markdown
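+                        # key_text/key_images name the entries expected inside result.markdown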
+                        md_output_path = save_markdown_content(
+                            markdown_content, 
+                            output_dir, 
+                            output_filename,
+                            normalize_numbers=normalize_numbers,
+                            key_text='markdown_texts',
+                            key_images='markdown_images'
                         )
-                        
-                        # If table consistency was fixed, print a notice
-                        if processing_info.get('table_consistency_fixed', False):
-                            print(f"🔧 Fixed table consistency issue: {input_path.name}")
                         
                         # Record the processing result
                         all_results.append({
@@ -358,9 +136,9 @@ def process_images_unified(image_paths: List[str],
                             "output_json": json_output_path,
                             "output_md": md_output_path,
                             "is_pdf_page": "_page_" in input_path.name,  # 标记是否为PDF页面
-                            "processing_info": processing_info
-                        })
-                        
+                            "processing_info": converted_json.get('processing_info', {})
+                        })
+
                     except Exception as e:
                         print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
                         traceback.print_exc()
@@ -533,11 +311,15 @@ if __name__ == "__main__":
         
         # Default configuration
         default_config = {
-            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
-            "output_dir": "./OmniDocBench_PPStructureV3_Results",
+            # "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/2023年度报告母公司.pdf",
+            "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results/2023年度报告母公司/2023年度报告母公司_page_027.png",
+            "output_dir": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results",
+            "collect_results": f"/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
+            # "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
+            # "output_dir": "./OmniDocBench_PPStructureV3_Results",
+            # "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
             "pipeline": "./my_config/PP-StructureV3.yaml",
-            "device": "gpu:0",
-            "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
+            "device": "gpu:3",
         }
         
         # Build the arguments