| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468 |
- import json
- import base64
- from pathlib import Path
- from typing import Dict, Any, List
- import requests
- from utils.normalize_financial_numbers import normalize_financial_numbers, normalize_markdown_table, normalize_json_table
def _count_char_changes(before: str, after: str) -> int:
    """Estimate how many characters differ between a text and its normalized form.

    Position-wise mismatches are counted over the shorter length, and the length
    difference is added so normalizations that shrink or grow the text are not
    silently under-counted (a bare ``zip`` stops at the shorter string).
    """
    mismatches = sum(1 for o, n in zip(before, after) if o != n)
    return mismatches + abs(len(before) - len(after))


def _build_standard_json(pruned_result: Dict[str, Any], input_image_path: str) -> Dict[str, Any]:
    """Map the API's ``prunedResult`` onto the standard JSON structure (no normalization).

    Missing sections fall back to empty defaults; ``or {}`` also covers sections
    that are present but explicitly ``null``. The returned dict shares nested
    objects with ``pruned_result`` -- the caller is responsible for copying.
    """
    doc_pre = pruned_result.get('doc_preprocessor_res') or {}
    layout_det = pruned_result.get('layout_det_res') or {}
    ocr = pruned_result.get('overall_ocr_res') or {}
    return {
        "input_path": input_image_path,
        "page_index": None,
        "model_settings": pruned_result.get('model_settings', {}),
        "parsing_res_list": pruned_result.get('parsing_res_list', []),
        "doc_preprocessor_res": {
            "input_path": None,
            "page_index": None,
            "model_settings": doc_pre.get('model_settings', {}),
            "angle": doc_pre.get('angle', 0),
        },
        "layout_det_res": {
            "input_path": None,
            "page_index": None,
            "boxes": layout_det.get('boxes', []),
        },
        "overall_ocr_res": {
            "input_path": None,
            "page_index": None,
            "model_settings": ocr.get('model_settings', {}),
            "dt_polys": ocr.get('dt_polys', []),
            "text_det_params": ocr.get('text_det_params', {}),
            "text_type": ocr.get('text_type', 'general'),
            "textline_orientation_angles": ocr.get('textline_orientation_angles', []),
            "text_rec_score_thresh": ocr.get('text_rec_score_thresh', 0.0),
            "return_word_box": ocr.get('return_word_box', False),
            "rec_texts": ocr.get('rec_texts', []),
            "rec_scores": ocr.get('rec_scores', []),
            "rec_polys": ocr.get('rec_polys', []),
            "rec_boxes": ocr.get('rec_boxes', []),
        },
        "table_res_list": pruned_result.get('table_res_list', []),
    }


def _normalize_converted_json(converted_json: Dict[str, Any]) -> int:
    """Normalize number formats in place in the text fields; return the changed-character count."""
    changes = 0

    # 1. parsing_res_list: table blocks use the table normalizer, everything else the generic one.
    for item in converted_json.get('parsing_res_list') or []:
        if 'block_content' not in item:
            continue
        before = item['block_content']
        if item.get('block_label') == 'table':
            after = normalize_markdown_table(before)
        else:
            after = normalize_financial_numbers(before)
        if after != before:
            item['block_content'] = after
            changes += _count_char_changes(before, after)

    # 2. table_res_list: the predicted HTML goes through the same table normalizer.
    for table_item in converted_json.get('table_res_list') or []:
        if 'pred_html' not in table_item:
            continue
        before = table_item['pred_html']
        after = normalize_markdown_table(before)
        if after != before:
            table_item['pred_html'] = after
            changes += _count_char_changes(before, after)

    # 3. overall_ocr_res: every recognized text line uses the generic normalizer.
    ocr_res = converted_json.get('overall_ocr_res') or {}
    texts = ocr_res.get('rec_texts')
    if texts:
        normalized_texts = []
        for text in texts:
            after = normalize_financial_numbers(text)
            normalized_texts.append(after)
            if after != text:
                changes += _count_char_changes(text, after)
        ocr_res['rec_texts'] = normalized_texts

    return changes


def convert_api_result_to_json(api_result: Dict[str, Any],
                               input_image_path: str,
                               output_json_path: str,
                               normalize_numbers: bool = True) -> Dict[str, Any]:
    """
    Convert an API result into the standard JSON format, optionally normalizing numbers.

    Only the first entry of ``layoutParsingResults`` is converted.

    Args:
        api_result: The ``result`` payload returned by the layout-parsing API.
        input_image_path: Input image path; stored as the top-level ``input_path``.
        output_json_path: Where the converted JSON is written (parent dirs are created).
        normalize_numbers: When True (default), normalize the number formatting in
            ``parsing_res_list[*].block_content``, ``table_res_list[*].pred_html``
            and ``overall_ocr_res.rec_texts``.

    Returns:
        The converted JSON dict, including a ``processing_info`` summary. Returns
        ``{}`` (and writes nothing) when the response has no ``layoutParsingResults``.

    Side effects:
        Writes ``output_json_path``. When normalization changed something, also
        writes ``<stem>_original.json`` alongside it with the pre-normalization data.
        ``api_result`` itself is never modified.
    """
    import copy  # deep copies keep api_result and the "original" snapshot untouched

    layout_parsing_results = api_result.get('layoutParsingResults', [])
    if not layout_parsing_results:
        print("⚠️ Warning: No layoutParsingResults found in API response")
        return {}

    # Use the first result (there is usually only one).
    pruned_result = layout_parsing_results[0].get('prunedResult') or {}

    # Deep copy: normalization below mutates nested items in place, and those
    # items would otherwise be shared with api_result (dumped later as the raw response).
    converted_json = copy.deepcopy(_build_standard_json(pruned_result, input_image_path))

    changes_count = 0
    original_json = None
    if normalize_numbers:
        print("🔧 正在标准化数字格式...")
        # Independent snapshot; a shallow .copy() would see the in-place edits.
        original_json = copy.deepcopy(converted_json)
        changes_count = _normalize_converted_json(converted_json)
        if changes_count > 0:
            print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
        else:
            print("ℹ️ 无需标准化(已是标准格式)")

    converted_json['processing_info'] = {
        "normalize_numbers": bool(normalize_numbers),
        "changes_applied": changes_count > 0,
        "character_changes_count": changes_count,
    }

    output_path = Path(output_json_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(converted_json, f, ensure_ascii=False, indent=4)

    # Keep the un-normalized version next to it for comparison when something changed.
    if original_json is not None and changes_count > 0:
        original_output_path = output_path.parent / f"{output_path.stem}_original.json"
        with open(original_output_path, 'w', encoding='utf-8') as f:
            json.dump(original_json, f, ensure_ascii=False, indent=4)
        print(f"📄 原始JSON已保存到: {original_output_path}")

    print(f"✅ Converted JSON saved to: {output_path}")
    return converted_json
def save_output_images(api_result: Dict[str, Any], output_dir: str) -> Dict[str, str]:
    """
    Decode and save the base64 output images returned by the API.

    Only the first entry of ``layoutParsingResults`` is used. Every image is
    written as ``<output_dir>/<name>.jpg``; the suffix is always ``.jpg``.

    Args:
        api_result: The ``result`` payload returned by the layout-parsing API.
        output_dir: Directory to write the images into (created if missing).

    Returns:
        Mapping of image name -> saved file path. Images that fail to decode or
        write, or whose name would place them outside ``output_dir``, are reported
        and skipped. Returns ``{}`` when the response has no layout results.
    """
    layout_parsing_results = api_result.get('layoutParsingResults', [])
    if not layout_parsing_results:
        return {}

    # `or {}` also covers an explicit null "outputImages" in the response.
    output_images = layout_parsing_results[0].get('outputImages') or {}

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    resolved_dir = output_dir.resolve()

    saved_images: Dict[str, str] = {}
    for img_name, img_base64 in output_images.items():
        try:
            img_path = output_dir / f"{img_name}.jpg"
            # Names come from the response: an absolute name or "../" would make
            # the `/` join escape output_dir, so refuse anything outside it.
            if resolved_dir not in img_path.resolve().parents:
                raise ValueError("path escapes the output directory")
            # b64decode raises binascii.Error (a ValueError) or TypeError on bad input.
            img_path.write_bytes(base64.b64decode(img_base64))
        except (ValueError, TypeError, OSError) as e:
            print(f"❌ Error saving image {img_name}: {e}")
            continue

        saved_images[img_name] = str(img_path)
        print(f"📷 Saved image: {img_path}")

    return saved_images
def save_markdown_content(api_result: Dict[str, Any], output_dir: str, normalize_numbers: bool = True) -> str:
    """
    Save the Markdown text and its embedded images, optionally normalizing numbers.

    Only the first entry of ``layoutParsingResults`` is used. The text goes to
    ``<output_dir>/document.md``. Each entry of ``markdown.images`` is decoded from
    base64 and written to ``<output_dir>/<key>``.

    Args:
        api_result: The ``result`` payload returned by the layout-parsing API.
        output_dir: Output directory (created if missing).
        normalize_numbers: When True (default), pass the Markdown text through
            ``normalize_markdown_table``. If that changes it, also save the
            untouched text as ``document_original.md``.

    Returns:
        Path of the written Markdown file, or ``""`` when the response has no
        layout results.
    """
    layout_parsing_results = api_result.get('layoutParsingResults', [])
    if not layout_parsing_results:
        return ""

    # `or {}` / `or ''` also cover explicit nulls in the response.
    markdown_data = layout_parsing_results[0].get('markdown') or {}

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    original_markdown_text = markdown_data.get('text') or ''
    markdown_text = original_markdown_text

    if normalize_numbers:
        print("🔧 正在标准化Markdown中的数字格式...")
        markdown_text = normalize_markdown_table(markdown_text)

        # Position-wise mismatches plus the length delta, so length-changing edits count too.
        changes_count = sum(1 for o, n in zip(original_markdown_text, markdown_text) if o != n)
        changes_count += abs(len(original_markdown_text) - len(markdown_text))
        if changes_count > 0:
            print(f"✅ Markdown中已标准化 {changes_count} 个字符(全角→半角)")
        else:
            print("ℹ️ Markdown无需标准化(已是标准格式)")

    md_file_path = output_dir / 'document.md'
    with open(md_file_path, 'w', encoding='utf-8') as f:
        f.write(markdown_text)

    # Keep the pre-normalization text for comparison when normalization changed it.
    if normalize_numbers and markdown_text != original_markdown_text:
        original_md_path = output_dir / 'document_original.md'
        with open(original_md_path, 'w', encoding='utf-8') as f:
            f.write(original_markdown_text)
        print(f"📄 原始Markdown已保存到: {original_md_path}")

    print(f"📝 Saved Markdown: {md_file_path}")

    markdown_images = markdown_data.get('images') or {}
    resolved_dir = output_dir.resolve()
    for img_path, img_base64 in markdown_images.items():
        try:
            full_img_path = output_dir / img_path
            # Keys come from the response: an absolute key makes the `/` join discard
            # output_dir, and "../" climbs out of it -- refuse both.
            if resolved_dir not in full_img_path.resolve().parents:
                raise ValueError("path escapes the output directory")
            # b64decode raises binascii.Error (a ValueError) or TypeError on bad input.
            img_data = base64.b64decode(img_base64)
            full_img_path.parent.mkdir(parents=True, exist_ok=True)
            with open(full_img_path, 'wb') as f:
                f.write(img_data)
        except (ValueError, TypeError, OSError) as e:
            print(f"❌ Error saving Markdown image {img_path}: {e}")
            continue

        print(f"🖼️ Saved Markdown image: {full_img_path}")

    return str(md_file_path)
def test_ppstructurev3_api_enhanced(image_path: str,
                                    API_URL: str,
                                    output_dir: str = "./api_output",
                                    normalize_numbers: bool = True,
                                    timeout: float = 600.0) -> Dict[str, Any]:
    """
    Send one image to the PP-StructureV3 API and save the results in standard formats.

    Writes, under ``output_dir``:
    ``<stem>_full_response.json`` (the raw ``result`` payload),
    ``<stem>.json`` (standard JSON), ``<stem>_images/`` (output images) and
    ``<stem>_markdown/`` (Markdown plus its images).

    Args:
        image_path: Path of the input image.
        API_URL: Layout-parsing endpoint URL.
        output_dir: Output directory (created if missing).
        normalize_numbers: Whether to normalize number formats. Defaults to True.
        timeout: Seconds to wait for the HTTP request (passed to ``requests``) so a
            stalled server cannot block forever.

    Returns:
        Dict with the output paths, the converted data and its ``processing_info``.

    Raises:
        RuntimeError: If the API answers with a non-200 status or without ``result``.
        requests.RequestException: On connection errors or timeouts.
    """
    # Base64-encode the local image for the JSON payload.
    with open(image_path, "rb") as file:
        image_data = base64.b64encode(file.read()).decode("ascii")
    # NOTE(review): fileType 1 presumably means "image" for this API -- confirm before sending PDFs.
    payload = {
        "file": image_data,
        "fileType": 1,
    }

    print(f"🚀 Calling API: {API_URL}")
    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")

    response = requests.post(API_URL, json=payload, timeout=timeout)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if response.status_code != 200:
        raise RuntimeError(
            f"API request failed with HTTP {response.status_code}: {response.text[:500]}"
        )
    response_body = response.json()
    if not isinstance(response_body, dict) or "result" not in response_body:
        raise RuntimeError(f"API response has no 'result' field: {str(response_body)[:500]}")
    api_result = response_body["result"]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    input_name = Path(image_path).stem

    # 1. Save the raw API response first (for debugging): it is untouched by the
    #    conversion steps and survives even if a later step fails.
    full_response_path = output_path / f"{input_name}_full_response.json"
    with open(full_response_path, 'w', encoding='utf-8') as f:
        json.dump(api_result, f, ensure_ascii=False, indent=2)

    # 2. Convert and save the standard JSON (with optional number normalization).
    json_output_path = output_path / f"{input_name}.json"
    converted_json = convert_api_result_to_json(
        api_result,
        image_path,
        str(json_output_path),
        normalize_numbers=normalize_numbers
    )

    # 3. Save the output images.
    images_dir = output_path / f"{input_name}_images"
    save_output_images(api_result, str(images_dir))

    # 4. Save the Markdown content (with optional number normalization).
    markdown_dir = output_path / f"{input_name}_markdown"
    markdown_file = save_markdown_content(api_result, str(markdown_dir), normalize_numbers=normalize_numbers)

    print("📊 Processing completed!")
    print(f" Standard JSON: {json_output_path}")
    print(f" Images directory: {images_dir}")
    print(f" Markdown file: {markdown_file}")
    print(f" Full response: {full_response_path}")

    # Detailed statistics (modeled on ocr_by_vlm.py).
    processing_info = converted_json.get('processing_info', {})
    print("\n📊 API处理统计")
    print(f" 原始图片: {Path(image_path).resolve().as_posix()}")
    print(f" 输出路径: {json_output_path.resolve().as_posix()}")
    print(f" API地址: {API_URL}")
    print(f" 数字标准化: {processing_info.get('normalize_numbers', False)}")
    if normalize_numbers:
        print(f" 字符变化数: {processing_info.get('character_changes_count', 0)}")
        print(f" 应用了标准化: {processing_info.get('changes_applied', False)}")

    return {
        "standard_json": str(json_output_path),
        "images_dir": str(images_dir),
        "markdown_file": markdown_file,
        "full_response": str(full_response_path),
        "converted_data": converted_json,
        "processing_info": processing_info
    }
def batch_process_api_results(image_list: List[str],
                              API_URL: str,
                              output_base_dir: str,
                              normalize_numbers: bool = True) -> List[Dict[str, Any]]:
    """
    Process several images through the API, one output sub-directory per image.

    Args:
        image_list: Image paths to process, in order.
        API_URL: Layout-parsing endpoint forwarded to every per-image call.
        output_base_dir: Parent directory; image number N (1-based) is written to
            ``result_<NNN>_<stem>`` beneath it.
        normalize_numbers: Forwarded to every per-image call.

    Returns:
        One entry per input image, in order. Successful images produce their
        per-image result dict. Failed images produce
        ``{"error": <message>, "image_path": <path>}``, and processing continues
        with the next image.
    """
    total = len(image_list)
    print(f"🚀 开始批量处理 {total} 个图像文件...")
    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")

    results: List[Dict[str, Any]] = []
    for index, image_path in enumerate(image_list, start=1):
        try:
            source = Path(image_path)
            print(f"\n🔄 Processing {index}/{total}: {source.name}")
            # Each image gets its own numbered output directory.
            per_image_dir = Path(output_base_dir) / f"result_{index:03d}_{source.stem}"
            outcome = test_ppstructurev3_api_enhanced(
                image_path,
                API_URL,
                str(per_image_dir),
                normalize_numbers=normalize_numbers
            )
        except Exception as exc:
            # Best effort: record the failure and keep going with the rest of the batch.
            print(f"❌ Error processing {image_path}: {exc}")
            outcome = {"error": str(exc), "image_path": image_path}
        results.append(outcome)

    success_count = len([r for r in results if 'error' not in r])
    total_changes = sum(
        r['processing_info'].get('character_changes_count', 0)
        for r in results
        if 'processing_info' in r
    )

    print("\n📊 批量处理完成统计")
    print(f" 总文件数: {total}")
    print(f" 成功处理: {success_count}")
    print(f" 失败数量: {total - success_count}")
    if normalize_numbers:
        print(f" 总标准化字符数: {total_changes}")

    return results
if __name__ == "__main__":
    import argparse
    import sys
    import traceback

    # Command-line entry point: single image, batch file, or a built-in sample image.
    parser = argparse.ArgumentParser(description='PP-StructureV3 API客户端工具')
    parser.add_argument('image_path', nargs='?', help='图片文件路径')
    parser.add_argument('-u', '--url', default='http://localhost:8080/layout-parsing', help='API URL')
    parser.add_argument('-o', '--output', default='./api_conversion_output', help='输出目录')
    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
    parser.add_argument('--batch', help='批量处理,指定包含图像路径的文本文件')

    args = parser.parse_args()
    normalize_numbers = not args.no_normalize

    # Sample image on the original author's machine; only used when neither
    # image_path nor --batch is given.
    default_image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/data_DotsOCR_Results/2023年度报告母公司/2023年度报告母公司_page_004.png"
    if not args.batch and not args.image_path and not Path(default_image_path).is_file():
        # Exit with a usage message (status 2) instead of a FileNotFoundError traceback.
        parser.error(f"需要提供 image_path 或 --batch(默认示例图片不存在: {default_image_path})")

    try:
        if args.batch:
            # Batch mode: one image path per non-blank line of the given file.
            with open(args.batch, 'r', encoding='utf-8') as f:
                image_list = [line.strip() for line in f if line.strip()]

            batch_process_api_results(
                image_list,
                args.url,
                args.output,
                normalize_numbers=normalize_numbers
            )
            print("\n🎉 批量API处理完成!")

        elif args.image_path:
            # Single-file mode.
            result = test_ppstructurev3_api_enhanced(
                args.image_path,
                args.url,
                args.output,
                normalize_numbers=normalize_numbers
            )

            # Re-read the written standard JSON to verify the conversion result.
            with open(result["standard_json"], 'r', encoding='utf-8') as f:
                converted_data = json.load(f)

            print("\n📋 转换后的数据包含:")
            print(f" - 解析结果块数: {len(converted_data.get('parsing_res_list', []))}")
            print(f" - OCR文本数: {len(converted_data.get('overall_ocr_res', {}).get('rec_texts', []))}")
            print(f" - 表格数: {len(converted_data.get('table_res_list', []))}")

            print("\n🎉 API测试和转换完成!")

        else:
            # Default example.
            result = test_ppstructurev3_api_enhanced(
                default_image_path,
                args.url,
                args.output,
                normalize_numbers=normalize_numbers
            )
            print("\n🎉 API测试和转换完成!")

    except Exception as e:
        print(f"❌ 处理失败: {e}")
        traceback.print_exc()
        # Non-zero exit status so shell scripts / CI can detect the failure.
        sys.exit(1)
|