Эх сурвалжийг харах

feat: 添加PPStructureV3公共函数,支持PDF转换为图像、输入文件处理及结果保存

zhch158_admin 1 сар өмнө
parent
commit
e07d3a6a41
1 өөрчлөгдсөн 388 нэмэгдсэн , 0 устгасан
  1. 388 0
      zhch/ppstructurev3_utils.py

+ 388 - 0
zhch/ppstructurev3_utils.py

@@ -0,0 +1,388 @@
+"""PPStructureV3公共函数"""
+import json
+import traceback
+import warnings
+import base64
+from pathlib import Path
+from PIL import Image
+from typing import List, Dict, Any, Union
+import numpy as np
+
+from utils import (
+    get_image_files_from_dir,
+    get_image_files_from_list,
+    get_image_files_from_csv,
+    load_images_from_pdf,
+    normalize_markdown_table
+)
+
+def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
+    """
+    将PDF转换为图像文件
+    
+    Args:
+        pdf_file: PDF文件路径
+        output_dir: 输出目录
+        dpi: 图像分辨率
+        
+    Returns:
+        生成的图像文件路径列表
+    """
+    pdf_path = Path(pdf_file)
+    if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
+        print(f"❌ Invalid PDF file: {pdf_path}")
+        return []
+
+    # 如果没有指定输出目录,使用PDF同名目录
+    if output_dir is None:
+        output_path = pdf_path.parent / f"{pdf_path.stem}"
+    else:
+        output_path = Path(output_dir) / f"{pdf_path.stem}"
+    output_path = output_path.resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # 使用doc_utils中的函数加载PDF图像
+        images = load_images_from_pdf(str(pdf_path), dpi=dpi)
+        
+        image_paths = []
+        for i, image in enumerate(images):
+            # 生成图像文件名
+            image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
+            image_path = output_path / image_filename
+
+            # 保存图像
+            image.save(str(image_path))
+            image_paths.append(str(image_path))
+            
+        print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
+        return image_paths
+        
+    except Exception as e:
+        print(f"❌ Error converting PDF {pdf_path}: {e}")
+        traceback.print_exc()
+        return []
+
+def get_input_files(args) -> List[str]:
+    """
+    获取输入文件列表,统一处理PDF和图像文件
+    
+    Args:
+        args: 命令行参数
+        
+    Returns:
+        处理后的图像文件路径列表
+    """
+    input_files = []
+    
+    # 获取原始输入文件
+    if args.input_csv:
+        raw_files = get_image_files_from_csv(args.input_csv, "fail")
+    elif args.input_file_list:
+        raw_files = get_image_files_from_list(args.input_file_list)
+    elif args.input_file:
+        raw_files = [Path(args.input_file).resolve()]
+    else:
+        input_dir = Path(args.input_dir).resolve()
+        if not input_dir.exists():
+            print(f"❌ Input directory does not exist: {input_dir}")
+            return []
+        
+        # 获取所有支持的文件(图像和PDF)
+        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+        pdf_extensions = ['.pdf']
+        
+        raw_files = []
+        for ext in image_extensions + pdf_extensions:
+            raw_files.extend(list(input_dir.glob(f"*{ext}")))
+            raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
+        
+        raw_files = [str(f) for f in raw_files]
+    
+    # 分别处理PDF和图像文件
+    pdf_count = 0
+    image_count = 0
+    
+    for file_path in raw_files:
+        file_path = Path(file_path)
+        
+        if file_path.suffix.lower() == '.pdf':
+            # 转换PDF为图像
+            print(f"📄 Processing PDF: {file_path.name}")
+            pdf_images = convert_pdf_to_images(
+                str(file_path), 
+                args.output_dir,
+                dpi=args.pdf_dpi
+            )
+            input_files.extend(pdf_images)
+            pdf_count += 1
+        else:
+            # 直接添加图像文件
+            if file_path.exists():
+                input_files.append(str(file_path))
+                image_count += 1
+    
+    print(f"📊 Input summary:")
+    print(f"  PDF files processed: {pdf_count}")
+    print(f"  Image files found: {image_count}")
+    print(f"  Total image files to process: {len(input_files)}")
+    
+    return input_files
+
+def convert_pruned_result_to_json(pruned_result: Dict[str, Any], 
+                              input_image_path: str, 
+                              output_dir: str, 
+                              filename: str,
+                              normalize_numbers: bool = True) -> tuple[str, Dict[str, Any]]:
+    """
+    将API返回结果转换为标准JSON格式,并支持数字标准化
+    """
+    if not pruned_result:
+        return "", {}
+    
+    # 构造标准格式的JSON
+    converted_json = {
+        "input_path": input_image_path,
+        "page_index": None,
+        "model_settings": pruned_result.get('model_settings', {}),
+        "parsing_res_list": pruned_result.get('parsing_res_list', []),
+        "doc_preprocessor_res": {
+            "input_path": None,
+            "page_index": None,
+            "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}),
+            "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0)
+        },
+        "layout_det_res": {
+            "input_path": None,
+            "page_index": None,
+            "boxes": pruned_result.get('layout_det_res', {}).get('boxes', [])
+        },
+        "overall_ocr_res": {
+            "input_path": None,
+            "page_index": None,
+            "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}),
+            "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []),
+            "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}),
+            "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'),
+            "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []),
+            "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0),
+            "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False),
+            "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []),
+            "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []),
+            "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []),
+            "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', [])
+        },
+        "table_res_list": pruned_result.get('table_res_list', [])
+    }
+    
+    # 数字标准化处理
+    original_json = converted_json.copy()
+    changes_count = 0
+    
+    if normalize_numbers:
+        # 1. 标准化 parsing_res_list 中的文本内容
+        for item in converted_json.get('parsing_res_list', []):
+            if 'block_content' in item:
+                original_content = item['block_content']
+                normalized_content = original_content
+                # 根据block_label类型选择标准化方法
+                if item.get('block_label') == 'table':
+                    normalized_content = normalize_markdown_table(original_content)
+                # else:
+                #     normalized_content = normalize_financial_numbers(original_content)
+                
+                if original_content != normalized_content:
+                    item['block_content'] = normalized_content
+                    changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
+        
+        # 2. 标准化 table_res_list 中的HTML表格
+        for table_item in converted_json.get('table_res_list', []):
+            if 'pred_html' in table_item:
+                original_html = table_item['pred_html']
+                normalized_html = normalize_markdown_table(original_html)
+                
+                if original_html != normalized_html:
+                    table_item['pred_html'] = normalized_html
+                    changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
+
+        # 检查是否需要修复表格一致性(这里只做统计,实际修复可能需要更复杂的逻辑)
+               # 统计表格数量
+        parsing_res_tables_count = 0
+        table_res_list_count = 0
+        if 'parsing_res_list' in converted_json:
+            parsing_res_tables_count = len([item for item in converted_json['parsing_res_list'] 
+                                          if 'block_label' in item and item['block_label'] == 'table'])
+        if 'table_res_list' in converted_json:
+            table_res_list_count = len(converted_json["table_res_list"])
+        table_consistency_fixed = False
+        if parsing_res_tables_count != table_res_list_count:
+            warnings.warn(f"⚠️ Warning: {filename} Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
+                          f"but table_res_list has {table_res_list_count} tables.")
+            table_consistency_fixed = True
+            # 这里可以添加实际的修复逻辑,例如根据需要添加或删除表格项
+            # 但由于缺乏具体规则,暂时只做统计和警告
+
+        # 3. 标准化 overall_ocr_res 中的识别文本
+        # ocr_res = converted_json.get('overall_ocr_res', {})
+        # if 'rec_texts' in ocr_res:
+        #     original_texts = ocr_res['rec_texts'][:]
+        #     normalized_texts = []
+            
+        #     for text in original_texts:
+        #         normalized_text = normalize_financial_numbers(text)
+        #         normalized_texts.append(normalized_text)
+        #         if text != normalized_text:
+        #             changes_count += len([1 for o, n in zip(text, normalized_text) if o != n])
+            
+        #     ocr_res['rec_texts'] = normalized_texts
+        
+        # 添加标准化处理信息
+        converted_json['processing_info'] = {
+            "normalize_numbers": normalize_numbers,
+            "changes_applied": changes_count > 0,
+            "character_changes_count": changes_count,
+            "parsing_res_tables_count": parsing_res_tables_count,
+            "table_res_list_count": table_res_list_count,
+            "table_consistency_fixed": table_consistency_fixed
+        }
+        
+        # if changes_count > 0:
+        #     print(f"🔧 已标准化 {changes_count} 个字符(全角→半角)")
+    else:
+        converted_json['processing_info'] = {
+            "normalize_numbers": False,
+            "changes_applied": False,
+            "character_changes_count": 0
+        }
+    
+    # 保存JSON文件
+    output_path = Path(output_dir).resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    json_file_path = output_path / f"{filename}.json"
+    with open(json_file_path, 'w', encoding='utf-8') as f:
+        json.dump(converted_json, f, ensure_ascii=False, indent=2)
+    
+    # 如果启用了标准化且有变化,保存原始版本用于对比
+    if normalize_numbers and changes_count > 0:
+        original_output_path = output_path / f"{filename}_original.json"
+        with open(original_output_path, 'w', encoding='utf-8') as f:
+            json.dump(original_json, f, ensure_ascii=False, indent=2)
+    
+    return str(output_path), converted_json
+
+def save_image(image: Union[Image.Image, str, np.ndarray], output_path: str) -> str:
+    """
+    保存单个图像到指定路径
+
+    Args:
+        image: 要保存的图像,可以是PIL Image对象、base64字符串或numpy数组
+        output_path: 输出文件路径
+
+    Returns:
+        保存的图像文件路径
+    """
+    try:
+        if isinstance(image, Image.Image):
+            image.save(output_path)
+        elif isinstance(image, str):
+            # 处理base64字符串
+            img_data = base64.b64decode(image)
+            with open(output_path, 'wb') as f:
+                f.write(img_data)
+        elif isinstance(image, np.ndarray):
+            # 处理numpy数组
+            pil_image = Image.fromarray(image)
+            pil_image.save(output_path)
+        else:
+            raise ValueError(f"Unsupported image type: {type(image)}")
+
+        # print(f"📷 Saved image: {output_path}")
+        return str(output_path)
+
+    except Exception as e:
+        print(f"❌ Error saving image {output_path}: {e}")
+        return ""
+
+def save_output_images(output_images: Dict[str, Any], output_dir: str, output_filename: str) -> Dict[str, str]:
+    """
+    保存API返回的输出图像
+    
+    Args:
+        output_images: 图像数组字典或PIL Image对象字典
+        output_dir: 输出目录
+        output_filename: 输出文件名前缀
+        
+    Returns:
+        保存的图像文件路径字典
+    """
+    if not output_images:
+        return {}
+    
+    output_path = Path(output_dir).resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    saved_images = {}
+    
+    for img_name, img_data in output_images.items():
+        try:
+            # 生成文件名
+            img_filename = f"{output_filename}_{img_name}.jpg"
+            img_path = output_path / img_filename
+            save_image(img_data, str(img_path))
+            saved_images[img_name] = str(img_path)
+            
+        except Exception as e:
+            print(f"❌ Error saving image {img_name}: {e}")
+            print(f"   Image data type: {type(img_data)}")
+            if hasattr(img_data, 'shape'):
+                print(f"   Image shape: {img_data.shape}")
+            traceback.print_exc()
+    
+    return saved_images
+
+def save_markdown_content(markdown_data: Dict[str, Any], output_dir: str, 
+                         filename: str, normalize_numbers: bool = True, key_text: str = 'text', key_images: str = 'images') -> str:
+    """
+    保存Markdown内容,支持数字标准化
+    """
+    if not markdown_data:
+        return ""
+    output_path = Path(output_dir).resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    # 保存Markdown文本
+    markdown_text = markdown_data.get(key_text, '')
+    
+    # 数字标准化处理
+    changes_count = 0
+    if normalize_numbers and markdown_text:
+        original_markdown_text = markdown_text
+        markdown_text = normalize_markdown_table(markdown_text)
+        
+        changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
+        # if changes_count > 0:
+        #     print(f"🔧 Markdown中已标准化 {changes_count} 个字符(全角→半角)")
+    
+    md_file_path = output_path / f"{filename}.md"
+    with open(md_file_path, 'w', encoding='utf-8') as f:
+        f.write(markdown_text)
+    
+    # 如果启用了标准化且有变化,保存原始版本用于对比
+    if normalize_numbers and changes_count > 0:
+        original_output_path = output_path / f"{filename}_original.md"
+        with open(original_output_path, 'w', encoding='utf-8') as f:
+            f.write(original_markdown_text)
+
+    # 保存Markdown中的图像
+    markdown_images = markdown_data.get(key_images, {})
+    for img_path, img_data in markdown_images.items():
+        try:
+            full_img_path = output_path / img_path
+            full_img_path.parent.mkdir(parents=True, exist_ok=True)
+            save_image(img_data, str(full_img_path))
+            
+        except Exception as e:
+            print(f"❌ Error saving Markdown image {img_path}: {e}")
+
+    return str(md_file_path)