| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- """PPStructureV3公共函数"""
import base64
import copy
import json
import traceback
import warnings
from pathlib import Path
from typing import List, Dict, Any, Union

import numpy as np
from PIL import Image

from utils import (
    load_images_from_pdf,
    normalize_markdown_table
)
- def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
- """
- 将PDF转换为图像文件
-
- Args:
- pdf_file: PDF文件路径
- output_dir: 输出目录
- dpi: 图像分辨率
-
- Returns:
- 生成的图像文件路径列表
- """
- pdf_path = Path(pdf_file)
- if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
- print(f"❌ Invalid PDF file: {pdf_path}")
- return []
- # 如果没有指定输出目录,使用PDF同名目录
- if output_dir is None:
- output_path = pdf_path.parent / f"{pdf_path.stem}"
- else:
- output_path = Path(output_dir) / f"{pdf_path.stem}"
- output_path = output_path.resolve()
- output_path.mkdir(parents=True, exist_ok=True)
- try:
- # 使用doc_utils中的函数加载PDF图像
- images = load_images_from_pdf(str(pdf_path), dpi=dpi)
-
- image_paths = []
- for i, image in enumerate(images):
- # 生成图像文件名
- image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
- image_path = output_path / image_filename
- # 保存图像
- image.save(str(image_path))
- image_paths.append(str(image_path))
-
- print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
- return image_paths
-
- except Exception as e:
- print(f"❌ Error converting PDF {pdf_path}: {e}")
- traceback.print_exc()
- return []
def convert_pruned_result_to_json(pruned_result: Dict[str, Any],
                                  input_image_path: str,
                                  output_dir: str,
                                  filename: str,
                                  normalize_numbers: bool = True) -> tuple[str, Dict[str, Any]]:
    """
    Convert an API pruned result into the standard JSON layout and save it.

    Args:
        pruned_result: Result dict returned by the API. Missing sections, or
            sections whose value is ``None``, are replaced by empty defaults.
            The caller's dict is never modified.
        input_image_path: Stored verbatim as the top-level ``input_path``.
        output_dir: Directory the JSON file(s) are written to (created if
            needed).
        filename: Base name of the output. ``<filename>.json`` is always
            written; ``<filename>_original.json`` (pre-normalization content)
            is written too when normalization changed anything.
        normalize_numbers: When truthy, run ``normalize_markdown_table`` over
            ``block_content`` of table blocks in ``parsing_res_list`` and over
            ``pred_html`` in ``table_res_list``, and record statistics under
            ``processing_info``.

    Returns:
        ``(output_directory, converted_json)``. The first element is the
        resolved output *directory*, not the JSON file path. Returns
        ``("", {})`` when ``pruned_result`` is empty.
    """
    if not pruned_result:
        return "", {}

    def _section(name: str) -> Dict[str, Any]:
        # `or {}` also covers sections that are present but explicitly None.
        return pruned_result.get(name) or {}

    def _char_changes(before: str, after: str) -> int:
        # Positional differences plus the length delta, so a normalization
        # that only inserts or removes characters is still counted.
        return (sum(o != n for o, n in zip(before, after))
                + abs(len(before) - len(after)))

    doc_preprocessor = _section('doc_preprocessor_res')
    layout = _section('layout_det_res')
    ocr = _section('overall_ocr_res')
    source_parsing = pruned_result.get('parsing_res_list') or []
    source_tables = pruned_result.get('table_res_list') or []

    converted_json = {
        "input_path": input_image_path,
        "page_index": None,
        "model_settings": pruned_result.get('model_settings', {}),
        # Deep copies: normalization below rewrites these items in place and
        # must affect neither the caller's data nor the "_original" snapshot.
        "parsing_res_list": copy.deepcopy(source_parsing),
        "doc_preprocessor_res": {
            "input_path": None,
            "page_index": None,
            "model_settings": doc_preprocessor.get('model_settings', {}),
            "angle": doc_preprocessor.get('angle', 0)
        },
        "layout_det_res": {
            "input_path": None,
            "page_index": None,
            "boxes": layout.get('boxes', [])
        },
        "overall_ocr_res": {
            "input_path": None,
            "page_index": None,
            "model_settings": ocr.get('model_settings', {}),
            "dt_polys": ocr.get('dt_polys', []),
            "text_det_params": ocr.get('text_det_params', {}),
            "text_type": ocr.get('text_type', 'general'),
            "textline_orientation_angles": ocr.get('textline_orientation_angles', []),
            "text_rec_score_thresh": ocr.get('text_rec_score_thresh', 0.0),
            "return_word_box": ocr.get('return_word_box', False),
            "rec_texts": ocr.get('rec_texts', []),
            "rec_scores": ocr.get('rec_scores', []),
            "rec_polys": ocr.get('rec_polys', []),
            "rec_boxes": ocr.get('rec_boxes', [])
        },
        "table_res_list": copy.deepcopy(source_tables)
    }

    # Pre-normalization snapshot for "<filename>_original.json". It points at
    # the caller's untouched lists, so in-place edits of converted_json's
    # (deep-copied) items cannot leak into it.
    original_json = dict(converted_json,
                         parsing_res_list=source_parsing,
                         table_res_list=source_tables)
    changes_count = 0

    if normalize_numbers:
        # 1. Normalize the markdown of table blocks in parsing_res_list.
        for item in converted_json['parsing_res_list']:
            if 'block_content' in item and item.get('block_label') == 'table':
                before = item['block_content']
                after = normalize_markdown_table(before)
                if after != before:
                    item['block_content'] = after
                    changes_count += _char_changes(before, after)

        # 2. Normalize the HTML tables in table_res_list.
        for table_item in converted_json['table_res_list']:
            if 'pred_html' in table_item:
                before = table_item['pred_html']
                after = normalize_markdown_table(before)
                if after != before:
                    table_item['pred_html'] = after
                    changes_count += _char_changes(before, after)

        # Table-count consistency is only reported, never repaired.
        parsing_res_tables_count = sum(
            1 for item in converted_json['parsing_res_list']
            if item.get('block_label') == 'table'
        )
        table_res_list_count = len(converted_json['table_res_list'])
        # NOTE: despite its name, "table_consistency_fixed" only flags that a
        # mismatch was detected; the key is kept for output compatibility.
        table_consistency_fixed = parsing_res_tables_count != table_res_list_count
        if table_consistency_fixed:
            warnings.warn(f"⚠️ Warning: {filename} Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
                          f"but table_res_list has {table_res_list_count} tables.")

        converted_json['processing_info'] = {
            "normalize_numbers": normalize_numbers,
            "changes_applied": changes_count > 0,
            "character_changes_count": changes_count,
            "parsing_res_tables_count": parsing_res_tables_count,
            "table_res_list_count": table_res_list_count,
            "table_consistency_fixed": table_consistency_fixed
        }
    else:
        converted_json['processing_info'] = {
            "normalize_numbers": False,
            "changes_applied": False,
            "character_changes_count": 0
        }

    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    json_file_path = output_path / f"{filename}.json"
    with open(json_file_path, 'w', encoding='utf-8') as f:
        json.dump(converted_json, f, ensure_ascii=False, indent=2)

    # Keep the pre-normalization version next to it for comparison.
    if normalize_numbers and changes_count > 0:
        original_output_path = output_path / f"{filename}_original.json"
        with open(original_output_path, 'w', encoding='utf-8') as f:
            json.dump(original_json, f, ensure_ascii=False, indent=2)

    return str(output_path), converted_json
def save_image(image: Union[Image.Image, str, np.ndarray], output_path: str) -> str:
    """
    Save a single image to ``output_path``.

    Args:
        image: A PIL image, a base64-encoded string (an optional
            ``data:...;base64,`` URI header is stripped first), or a numpy
            array accepted by ``Image.fromarray``.
        output_path: Destination file path. For ``.jpg``/``.jpeg`` targets,
            images in modes JPEG cannot store (e.g. RGBA, LA, P) are
            converted to RGB before saving.

    Returns:
        ``str(output_path)`` on success, or ``""`` on failure. Errors are
        printed rather than raised, so callers must check the return value.
    """
    try:
        if isinstance(image, str):
            # b64decode discards non-alphabet characters by default, so an
            # unstripped data-URI header would silently corrupt the bytes.
            payload = image.split(',', 1)[1] if image.startswith('data:') else image
            with open(output_path, 'wb') as f:
                f.write(base64.b64decode(payload))
            return str(output_path)

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            raise ValueError(f"Unsupported image type: {type(image)}")

        # JPEG only stores 1/L/RGB/CMYK; PIL raises for e.g. RGBA -> .jpg.
        if (Path(output_path).suffix.lower() in ('.jpg', '.jpeg')
                and image.mode not in ('1', 'L', 'RGB', 'CMYK')):
            image = image.convert('RGB')
        image.save(output_path)
        return str(output_path)
    except Exception as e:
        print(f"❌ Error saving image {output_path}: {e}")
        return ""
def save_output_images(output_images: Dict[str, Any], output_dir: str, output_filename: str) -> Dict[str, str]:
    """
    Save the output images returned by the API as JPEG-named files.

    Each entry is written to ``<output_dir>/<output_filename>_<name>.jpg``
    through ``save_image``.

    Args:
        output_images: Mapping of image name to image data (any value
            ``save_image`` accepts: PIL image, base64 string, numpy array).
        output_dir: Output directory (created if needed).
        output_filename: Prefix for every generated file name.

    Returns:
        Mapping of image name to saved file path, containing only images
        that were actually written. Empty when ``output_images`` is empty.
    """
    if not output_images:
        return {}

    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    saved_images: Dict[str, str] = {}
    for img_name, img_data in output_images.items():
        img_path = output_path / f"{output_filename}_{img_name}.jpg"
        try:
            saved = save_image(img_data, str(img_path))
        except Exception as e:
            print(f"❌ Error saving image {img_name}: {e}")
            print(f"   Image data type: {type(img_data)}")
            if hasattr(img_data, 'shape'):
                print(f"   Image shape: {img_data.shape}")
            traceback.print_exc()
            continue

        # save_image signals failure by returning "" instead of raising, so
        # only record paths of files that were really written.
        if saved:
            saved_images[img_name] = saved

    return saved_images
def save_markdown_content(markdown_data: Dict[str, Any], output_dir: str,
                          filename: str, normalize_numbers: bool = True, key_text: str = 'text',
                          key_images: str = 'images') -> str:
    """
    Save Markdown text and its embedded images, optionally normalizing tables.

    Args:
        markdown_data: Dict holding the Markdown text under ``key_text`` and a
            mapping of relative image path -> image data under ``key_images``.
            Missing or ``None`` entries are treated as empty.
        output_dir: Output directory (created if needed). Images are saved
            relative to it; keys that would escape it are skipped.
        filename: Base name; the text goes to ``<filename>.md``.
        normalize_numbers: When truthy, pass the text through
            ``normalize_markdown_table``. If that changes anything, the
            un-normalized text is also saved as ``<filename>_original.md``.
        key_text: Key of the Markdown text in ``markdown_data``.
        key_images: Key of the image mapping in ``markdown_data``.

    Returns:
        Path of the written ``.md`` file, or ``""`` if ``markdown_data`` is
        empty.
    """
    if not markdown_data:
        return ""
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    # `or ''` also covers a key that is present but None.
    original_markdown_text = markdown_data.get(key_text) or ''
    markdown_text = original_markdown_text

    changes_count = 0
    if normalize_numbers and markdown_text:
        markdown_text = normalize_markdown_table(original_markdown_text)
        # Positional differences plus the length delta, so insert/delete-only
        # normalizations are still detected.
        changes_count = (sum(o != n for o, n in zip(original_markdown_text, markdown_text))
                         + abs(len(original_markdown_text) - len(markdown_text)))

    md_file_path = output_path / f"{filename}.md"
    with open(md_file_path, 'w', encoding='utf-8') as f:
        f.write(markdown_text)

    # changes_count can only be non-zero when normalization ran; keep the
    # original next to the normalized file for comparison.
    if changes_count > 0:
        original_output_path = output_path / f"{filename}_original.md"
        with open(original_output_path, 'w', encoding='utf-8') as f:
            f.write(original_markdown_text)

    markdown_images = markdown_data.get(key_images) or {}
    for img_path, img_data in markdown_images.items():
        try:
            full_img_path = (output_path / img_path).resolve()
            # Image keys come from the API response; refuse absolute paths or
            # `..` segments that would write outside output_dir.
            if not full_img_path.is_relative_to(output_path):
                print(f"❌ Skipping Markdown image outside output dir: {img_path}")
                continue
            full_img_path.parent.mkdir(parents=True, exist_ok=True)
            save_image(img_data, str(full_img_path))
        except Exception as e:
            print(f"❌ Error saving Markdown image {img_path}: {e}")
    return str(md_file_path)
|