"""PPStructureV3公共函数""" import json import traceback import warnings import base64 from pathlib import Path from PIL import Image from typing import List, Dict, Any, Union import numpy as np from utils import ( load_images_from_pdf, normalize_markdown_table, normalize_financial_numbers ) def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]: """ 将PDF转换为图像文件 Args: pdf_file: PDF文件路径 output_dir: 输出目录 dpi: 图像分辨率 Returns: 生成的图像文件路径列表 """ pdf_path = Path(pdf_file) if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf': print(f"❌ Invalid PDF file: {pdf_path}") return [] # 如果没有指定输出目录,使用PDF同名目录 if output_dir is None: output_path = pdf_path.parent / f"{pdf_path.stem}" else: output_path = Path(output_dir) / f"{pdf_path.stem}" output_path = output_path.resolve() output_path.mkdir(parents=True, exist_ok=True) try: # 使用doc_utils中的函数加载PDF图像 images = load_images_from_pdf(str(pdf_path), dpi=dpi) image_paths = [] for i, image in enumerate(images): # 生成图像文件名 image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png" image_path = output_path / image_filename # 保存图像 image.save(str(image_path)) image_paths.append(str(image_path)) print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images") return image_paths except Exception as e: print(f"❌ Error converting PDF {pdf_path}: {e}") traceback.print_exc() return [] def convert_pruned_result_to_json(pruned_result: Dict[str, Any], input_image_path: str, output_dir: str, filename: str, normalize_numbers: bool = True) -> tuple[str, Dict[str, Any]]: """ 将API返回结果转换为标准JSON格式,并支持数字标准化 """ if not pruned_result: return "", {} # 构造标准格式的JSON converted_json = { "input_path": input_image_path, "page_index": None, "model_settings": pruned_result.get('model_settings', {}), "parsing_res_list": pruned_result.get('parsing_res_list', []), "doc_preprocessor_res": { "input_path": None, "page_index": None, "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}), "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0) }, "layout_det_res": { "input_path": None, "page_index": None, "boxes": pruned_result.get('layout_det_res', {}).get('boxes', []) }, "overall_ocr_res": { "input_path": None, "page_index": None, "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}), "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []), "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}), "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'), "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []), "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0), "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False), "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []), "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []), "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []), "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', []) }, "table_res_list": pruned_result.get('table_res_list', []) } # 数字标准化处理 original_json = converted_json.copy() changes_count = 0 if normalize_numbers: # 1. 标准化 parsing_res_list 中的文本内容 for item in converted_json.get('parsing_res_list', []): if 'block_content' in item: original_content = item['block_content'] normalized_content = original_content # 根据block_label类型选择标准化方法 if item.get('block_label') == 'table': normalized_content = normalize_markdown_table(original_content) # else: # normalized_content = normalize_financial_numbers(original_content) if original_content != normalized_content: item['block_content'] = normalized_content changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n]) # 2. 标准化 table_res_list 中的HTML表格 for table_item in converted_json.get('table_res_list', []): if 'pred_html' in table_item: original_html = table_item['pred_html'] normalized_html = normalize_markdown_table(original_html) if original_html != normalized_html: table_item['pred_html'] = normalized_html changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n]) # 检查是否需要修复表格一致性(这里只做统计,实际修复可能需要更复杂的逻辑) # 统计表格数量 parsing_res_tables_count = 0 table_res_list_count = 0 if 'parsing_res_list' in converted_json: parsing_res_tables_count = len([item for item in converted_json['parsing_res_list'] if 'block_label' in item and item['block_label'] == 'table']) if 'table_res_list' in converted_json: table_res_list_count = len(converted_json["table_res_list"]) table_consistency_fixed = False if parsing_res_tables_count != table_res_list_count: warnings.warn(f"⚠️ Warning: {filename} Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, " f"but table_res_list has {table_res_list_count} tables.") table_consistency_fixed = True # 这里可以添加实际的修复逻辑,例如根据需要添加或删除表格项 # 但由于缺乏具体规则,暂时只做统计和警告 # 3. 标准化 overall_ocr_res 中的识别文本 # ocr_res = converted_json.get('overall_ocr_res', {}) # if 'rec_texts' in ocr_res: # original_texts = ocr_res['rec_texts'][:] # normalized_texts = [] # for text in original_texts: # normalized_text = normalize_financial_numbers(text) # normalized_texts.append(normalized_text) # if text != normalized_text: # changes_count += len([1 for o, n in zip(text, normalized_text) if o != n]) # ocr_res['rec_texts'] = normalized_texts # 添加标准化处理信息 converted_json['processing_info'] = { "normalize_numbers": normalize_numbers, "changes_applied": changes_count > 0, "character_changes_count": changes_count, "parsing_res_tables_count": parsing_res_tables_count, "table_res_list_count": table_res_list_count, "table_consistency_fixed": table_consistency_fixed } # if changes_count > 0: # print(f"🔧 已标准化 {changes_count} 个字符(全角→半角)") else: converted_json['processing_info'] = { "normalize_numbers": False, "changes_applied": False, "character_changes_count": 0 } # 保存JSON文件 output_path = Path(output_dir).resolve() output_path.mkdir(parents=True, exist_ok=True) json_file_path = output_path / f"{filename}.json" with open(json_file_path, 'w', encoding='utf-8') as f: json.dump(converted_json, f, ensure_ascii=False, indent=2) # 如果启用了标准化且有变化,保存原始版本用于对比 if normalize_numbers and changes_count > 0: original_output_path = output_path / f"{filename}_original.json" with open(original_output_path, 'w', encoding='utf-8') as f: json.dump(original_json, f, ensure_ascii=False, indent=2) return str(output_path), converted_json def save_image(image: Union[Image.Image, str, np.ndarray], output_path: str) -> str: """ 保存单个图像到指定路径 Args: image: 要保存的图像,可以是PIL Image对象、base64字符串或numpy数组 output_path: 输出文件路径 Returns: 保存的图像文件路径 """ try: if isinstance(image, Image.Image): image.save(output_path) elif isinstance(image, str): # 处理base64字符串 img_data = base64.b64decode(image) with open(output_path, 'wb') as f: f.write(img_data) elif isinstance(image, np.ndarray): # 处理numpy数组 pil_image = Image.fromarray(image) pil_image.save(output_path) else: raise ValueError(f"Unsupported image type: {type(image)}") # print(f"📷 Saved image: {output_path}") return str(output_path) except Exception as e: print(f"❌ Error saving image {output_path}: {e}") return "" def save_output_images(output_images: Dict[str, Any], output_dir: str, output_filename: str) -> Dict[str, str]: """ 保存API返回的输出图像 Args: output_images: 图像数组字典或PIL Image对象字典 output_dir: 输出目录 output_filename: 输出文件名前缀 Returns: 保存的图像文件路径字典 """ if not output_images: return {} output_path = Path(output_dir).resolve() output_path.mkdir(parents=True, exist_ok=True) saved_images = {} for img_name, img_data in output_images.items(): try: # 生成文件名 img_filename = f"{output_filename}_{img_name}.jpg" img_path = output_path / img_filename save_image(img_data, str(img_path)) saved_images[img_name] = str(img_path) except Exception as e: print(f"❌ Error saving image {img_name}: {e}") print(f" Image data type: {type(img_data)}") if hasattr(img_data, 'shape'): print(f" Image shape: {img_data.shape}") traceback.print_exc() return saved_images def save_markdown_content(markdown_data: Dict[str, Any], output_dir: str, filename: str, normalize_numbers: bool = True, key_text: str = 'text', key_images: str = 'images', json_data: Dict[str, Any] = None) -> str: """ 保存Markdown内容,支持数字标准化和表格补全 """ if not markdown_data and not json_data: return "" output_path = Path(output_dir).resolve() output_path.mkdir(parents=True, exist_ok=True) # 🎯 优先使用json_data生成完整内容 if json_data: return save_markdown_content_enhanced(json_data, str(output_path), filename, normalize_numbers) # 原有逻辑保持不变 markdown_text = markdown_data.get(key_text, '') # 数字标准化处理 changes_count = 0 if normalize_numbers and markdown_text: original_markdown_text = markdown_text markdown_text = normalize_markdown_table(markdown_text) changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n]) # if changes_count > 0: # print(f"🔧 Markdown中已标准化 {changes_count} 个字符(全角→半角)") md_file_path = output_path / f"{filename}.md" with open(md_file_path, 'w', encoding='utf-8') as f: f.write(markdown_text) # 如果启用了标准化且有变化,保存原始版本用于对比 if normalize_numbers and changes_count > 0: original_output_path = output_path / f"{filename}_original.md" with open(original_output_path, 'w', encoding='utf-8') as f: f.write(original_markdown_text) # 保存Markdown中的图像 markdown_images = markdown_data.get(key_images, {}) for img_path, img_data in markdown_images.items(): try: full_img_path = output_path / img_path full_img_path.parent.mkdir(parents=True, exist_ok=True) save_image(img_data, str(full_img_path)) except Exception as e: print(f"❌ Error saving Markdown image {img_path}: {e}") return str(md_file_path) def save_markdown_content_enhanced(json_data: Dict[str, Any], output_dir: str, filename: str, normalize_numbers: bool = True) -> str: """ 增强版Markdown内容保存,同时处理parsing_res_list和table_res_list """ if not json_data: return "" output_path = Path(output_dir).resolve() output_path.mkdir(parents=True, exist_ok=True) markdown_content = [] # 处理 parsing_res_list parsing_res_list = json_data.get('parsing_res_list', []) table_res_list = json_data.get('table_res_list', []) table_index = 0 # 用于匹配table_res_list中的表格 for item in parsing_res_list: block_label = item.get('block_label', '') block_content = item.get('block_content', '') if block_label == 'table': # 如果是表格,优先使用table_res_list中的详细HTML if table_index < len(table_res_list): detailed_html = table_res_list[table_index].get('pred_html', block_content) if normalize_numbers: detailed_html = normalize_markdown_table(detailed_html) # 转换为居中显示的HTML markdown_content.append(f'