|
|
@@ -0,0 +1,468 @@
|
|
|
+import json
|
|
|
+import base64
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, Any, List
|
|
|
+import requests
|
|
|
+from utils.normalize_financial_numbers import normalize_financial_numbers, normalize_markdown_table, normalize_json_table
|
|
|
+
|
|
|
+def convert_api_result_to_json(api_result: Dict[str, Any],
|
|
|
+ input_image_path: str,
|
|
|
+ output_json_path: str,
|
|
|
+ normalize_numbers: bool = True) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 将API返回结果转换为标准JSON格式,并支持数字标准化
|
|
|
+
|
|
|
+ Args:
|
|
|
+ api_result: API返回的结果
|
|
|
+ input_image_path: 输入图像路径
|
|
|
+ output_json_path: 输出JSON文件路径
|
|
|
+ normalize_numbers: 是否标准化数字格式,默认True
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 转换后的JSON数据
|
|
|
+ """
|
|
|
+
|
|
|
+ # 获取主要数据
|
|
|
+ layout_parsing_results = api_result.get('layoutParsingResults', [])
|
|
|
+ data_info = api_result.get('dataInfo', {})
|
|
|
+
|
|
|
+ if not layout_parsing_results:
|
|
|
+ print("⚠️ Warning: No layoutParsingResults found in API response")
|
|
|
+ return {}
|
|
|
+
|
|
|
+ # 取第一个结果(通常只有一个)
|
|
|
+ main_result = layout_parsing_results[0]
|
|
|
+ pruned_result = main_result.get('prunedResult', {})
|
|
|
+
|
|
|
+ # 构造标准格式的JSON
|
|
|
+ converted_json = {
|
|
|
+ "input_path": input_image_path,
|
|
|
+ "page_index": None,
|
|
|
+ "model_settings": pruned_result.get('model_settings', {}),
|
|
|
+ "parsing_res_list": pruned_result.get('parsing_res_list', []),
|
|
|
+ "doc_preprocessor_res": {
|
|
|
+ "input_path": None,
|
|
|
+ "page_index": None,
|
|
|
+ "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}),
|
|
|
+ "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0)
|
|
|
+ },
|
|
|
+ "layout_det_res": {
|
|
|
+ "input_path": None,
|
|
|
+ "page_index": None,
|
|
|
+ "boxes": pruned_result.get('layout_det_res', {}).get('boxes', [])
|
|
|
+ },
|
|
|
+ "overall_ocr_res": {
|
|
|
+ "input_path": None,
|
|
|
+ "page_index": None,
|
|
|
+ "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}),
|
|
|
+ "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []),
|
|
|
+ "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}),
|
|
|
+ "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'),
|
|
|
+ "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []),
|
|
|
+ "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0),
|
|
|
+ "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False),
|
|
|
+ "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []),
|
|
|
+ "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []),
|
|
|
+ "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []),
|
|
|
+ "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', [])
|
|
|
+ },
|
|
|
+ "table_res_list": pruned_result.get('table_res_list', [])
|
|
|
+ }
|
|
|
+
|
|
|
+ # 数字标准化处理
|
|
|
+ original_json = converted_json.copy()
|
|
|
+ changes_count = 0
|
|
|
+
|
|
|
+ if normalize_numbers:
|
|
|
+ print("🔧 正在标准化数字格式...")
|
|
|
+
|
|
|
+ # 1. 标准化 parsing_res_list 中的文本内容
|
|
|
+ for item in converted_json.get('parsing_res_list', []):
|
|
|
+ if 'block_content' in item:
|
|
|
+ original_content = item['block_content']
|
|
|
+
|
|
|
+ # 根据block_label类型选择标准化方法
|
|
|
+ if item.get('block_label') == 'table':
|
|
|
+ # 表格内容使用表格专用标准化
|
|
|
+ normalized_content = normalize_markdown_table(original_content)
|
|
|
+ else:
|
|
|
+ # 普通文本使用通用标准化
|
|
|
+ normalized_content = normalize_financial_numbers(original_content)
|
|
|
+
|
|
|
+ if original_content != normalized_content:
|
|
|
+ item['block_content'] = normalized_content
|
|
|
+ changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
|
|
|
+
|
|
|
+ # 2. 标准化 table_res_list 中的HTML表格
|
|
|
+ for table_item in converted_json.get('table_res_list', []):
|
|
|
+ if 'pred_html' in table_item:
|
|
|
+ original_html = table_item['pred_html']
|
|
|
+ normalized_html = normalize_markdown_table(original_html)
|
|
|
+
|
|
|
+ if original_html != normalized_html:
|
|
|
+ table_item['pred_html'] = normalized_html
|
|
|
+ changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
|
|
|
+
|
|
|
+ # 3. 标准化 overall_ocr_res 中的识别文本
|
|
|
+ ocr_res = converted_json.get('overall_ocr_res', {})
|
|
|
+ if 'rec_texts' in ocr_res:
|
|
|
+ original_texts = ocr_res['rec_texts'][:]
|
|
|
+ normalized_texts = []
|
|
|
+
|
|
|
+ for text in original_texts:
|
|
|
+ normalized_text = normalize_financial_numbers(text)
|
|
|
+ normalized_texts.append(normalized_text)
|
|
|
+ if text != normalized_text:
|
|
|
+ changes_count += len([1 for o, n in zip(text, normalized_text) if o != n])
|
|
|
+
|
|
|
+ ocr_res['rec_texts'] = normalized_texts
|
|
|
+
|
|
|
+ # 添加标准化处理信息
|
|
|
+ converted_json['processing_info'] = {
|
|
|
+ "normalize_numbers": normalize_numbers,
|
|
|
+ "changes_applied": changes_count > 0,
|
|
|
+ "character_changes_count": changes_count
|
|
|
+ }
|
|
|
+
|
|
|
+ if changes_count > 0:
|
|
|
+ print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
|
|
|
+ else:
|
|
|
+ print("ℹ️ 无需标准化(已是标准格式)")
|
|
|
+ else:
|
|
|
+ converted_json['processing_info'] = {
|
|
|
+ "normalize_numbers": False,
|
|
|
+ "changes_applied": False,
|
|
|
+ "character_changes_count": 0
|
|
|
+ }
|
|
|
+
|
|
|
+ # 保存JSON文件
|
|
|
+ output_path = Path(output_json_path)
|
|
|
+ output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ with open(output_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(converted_json, f, ensure_ascii=False, indent=4)
|
|
|
+
|
|
|
+ # 如果启用了标准化且有变化,保存原始版本用于对比
|
|
|
+ if normalize_numbers and changes_count > 0:
|
|
|
+ original_output_path = output_path.parent / f"{output_path.stem}_original.json"
|
|
|
+ with open(original_output_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(original_json, f, ensure_ascii=False, indent=4)
|
|
|
+ print(f"📄 原始JSON已保存到: {original_output_path}")
|
|
|
+
|
|
|
+ print(f"✅ Converted JSON saved to: {output_path}")
|
|
|
+ return converted_json
|
|
|
+
|
|
|
+def save_output_images(api_result: Dict[str, Any], output_dir: str) -> Dict[str, str]:
|
|
|
+ """
|
|
|
+ 保存API返回的输出图像
|
|
|
+
|
|
|
+ Args:
|
|
|
+ api_result: API返回的结果
|
|
|
+ output_dir: 输出目录
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 保存的图像文件路径字典
|
|
|
+ """
|
|
|
+ layout_parsing_results = api_result.get('layoutParsingResults', [])
|
|
|
+ if not layout_parsing_results:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ main_result = layout_parsing_results[0]
|
|
|
+ output_images = main_result.get('outputImages', {})
|
|
|
+
|
|
|
+ output_dir = Path(output_dir)
|
|
|
+ output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ saved_images = {}
|
|
|
+
|
|
|
+ for img_name, img_base64 in output_images.items():
|
|
|
+ try:
|
|
|
+ # 解码base64图像
|
|
|
+ img_data = base64.b64decode(img_base64)
|
|
|
+
|
|
|
+ # 生成文件名
|
|
|
+ img_filename = f"{img_name}.jpg"
|
|
|
+ img_path = output_dir / img_filename
|
|
|
+
|
|
|
+ # 保存图像
|
|
|
+ with open(img_path, 'wb') as f:
|
|
|
+ f.write(img_data)
|
|
|
+
|
|
|
+ saved_images[img_name] = str(img_path)
|
|
|
+ print(f"📷 Saved image: {img_path}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Error saving image {img_name}: {e}")
|
|
|
+
|
|
|
+ return saved_images
|
|
|
+
|
|
|
+def save_markdown_content(api_result: Dict[str, Any], output_dir: str, normalize_numbers: bool = True) -> str:
|
|
|
+ """
|
|
|
+ 保存Markdown内容和相关图像,支持数字标准化
|
|
|
+
|
|
|
+ Args:
|
|
|
+ api_result: API返回的结果
|
|
|
+ output_dir: 输出目录
|
|
|
+ normalize_numbers: 是否标准化数字格式
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Markdown文件路径
|
|
|
+ """
|
|
|
+ layout_parsing_results = api_result.get('layoutParsingResults', [])
|
|
|
+ if not layout_parsing_results:
|
|
|
+ return ""
|
|
|
+
|
|
|
+ main_result = layout_parsing_results[0]
|
|
|
+ markdown_data = main_result.get('markdown', {})
|
|
|
+
|
|
|
+ output_dir = Path(output_dir)
|
|
|
+ output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ # 保存Markdown文本
|
|
|
+ markdown_text = markdown_data.get('text', '')
|
|
|
+ original_markdown_text = markdown_text
|
|
|
+
|
|
|
+ # 数字标准化处理
|
|
|
+ if normalize_numbers:
|
|
|
+ print("🔧 正在标准化Markdown中的数字格式...")
|
|
|
+ markdown_text = normalize_markdown_table(markdown_text)
|
|
|
+
|
|
|
+ changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
|
|
|
+ if changes_count > 0:
|
|
|
+ print(f"✅ Markdown中已标准化 {changes_count} 个字符(全角→半角)")
|
|
|
+ else:
|
|
|
+ print("ℹ️ Markdown无需标准化(已是标准格式)")
|
|
|
+
|
|
|
+ md_file_path = output_dir / 'document.md'
|
|
|
+
|
|
|
+ with open(md_file_path, 'w', encoding='utf-8') as f:
|
|
|
+ f.write(markdown_text)
|
|
|
+
|
|
|
+ # 如果启用了标准化且有变化,保存原始版本
|
|
|
+ if normalize_numbers and markdown_text != original_markdown_text:
|
|
|
+ original_md_path = output_dir / 'document_original.md'
|
|
|
+ with open(original_md_path, 'w', encoding='utf-8') as f:
|
|
|
+ f.write(original_markdown_text)
|
|
|
+ print(f"📄 原始Markdown已保存到: {original_md_path}")
|
|
|
+
|
|
|
+ print(f"📝 Saved Markdown: {md_file_path}")
|
|
|
+
|
|
|
+ # 保存Markdown中的图像
|
|
|
+ markdown_images = markdown_data.get('images', {})
|
|
|
+ for img_path, img_base64 in markdown_images.items():
|
|
|
+ try:
|
|
|
+ img_data = base64.b64decode(img_base64)
|
|
|
+ full_img_path = output_dir / img_path
|
|
|
+ full_img_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ with open(full_img_path, 'wb') as f:
|
|
|
+ f.write(img_data)
|
|
|
+
|
|
|
+ print(f"🖼️ Saved Markdown image: {full_img_path}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Error saving Markdown image {img_path}: {e}")
|
|
|
+
|
|
|
+ return str(md_file_path)
|
|
|
+
|
|
|
+def test_ppstructurev3_api_enhanced(image_path: str,
|
|
|
+ API_URL: str,
|
|
|
+ output_dir: str = "./api_output",
|
|
|
+ normalize_numbers: bool = True) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 增强版的PP-StructureV3 API测试,保存为标准格式并支持数字标准化
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image_path: 输入图像路径
|
|
|
+ API_URL: API URL
|
|
|
+ output_dir: 输出目录
|
|
|
+ normalize_numbers: 是否标准化数字格式,默认True
|
|
|
+ """
|
|
|
+ # 对本地图像进行Base64编码
|
|
|
+ with open(image_path, "rb") as file:
|
|
|
+ image_bytes = file.read()
|
|
|
+ image_data = base64.b64encode(image_bytes).decode("ascii")
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "file": image_data,
|
|
|
+ "fileType": 1,
|
|
|
+ }
|
|
|
+
|
|
|
+ # 调用API
|
|
|
+ print(f"🚀 Calling API: {API_URL}")
|
|
|
+ print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
|
|
|
+
|
|
|
+ response = requests.post(API_URL, json=payload)
|
|
|
+
|
|
|
+ # 处理接口返回数据
|
|
|
+ assert response.status_code == 200
|
|
|
+ api_result = response.json()["result"]
|
|
|
+
|
|
|
+ # 创建输出目录
|
|
|
+ output_path = Path(output_dir)
|
|
|
+ output_path.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ # 获取输入文件的基本名称
|
|
|
+ input_name = Path(image_path).stem
|
|
|
+
|
|
|
+ # 1. 转换并保存标准JSON格式(包含数字标准化)
|
|
|
+ json_output_path = output_path / f"{input_name}.json"
|
|
|
+ converted_json = convert_api_result_to_json(
|
|
|
+ api_result,
|
|
|
+ image_path,
|
|
|
+ str(json_output_path),
|
|
|
+ normalize_numbers=normalize_numbers
|
|
|
+ )
|
|
|
+
|
|
|
+ # 2. 保存输出图像
|
|
|
+ images_dir = output_path / f"{input_name}_images"
|
|
|
+ saved_images = save_output_images(api_result, str(images_dir))
|
|
|
+
|
|
|
+ # 3. 保存Markdown内容(包含数字标准化)
|
|
|
+ markdown_dir = output_path / f"{input_name}_markdown"
|
|
|
+ markdown_file = save_markdown_content(api_result, str(markdown_dir), normalize_numbers=normalize_numbers)
|
|
|
+
|
|
|
+ # 4. 保存完整的API响应(用于调试)
|
|
|
+ full_response_path = output_path / f"{input_name}_full_response.json"
|
|
|
+ with open(full_response_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(api_result, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ print(f"📊 Processing completed!")
|
|
|
+ print(f" Standard JSON: {json_output_path}")
|
|
|
+ print(f" Images directory: {images_dir}")
|
|
|
+ print(f" Markdown file: {markdown_file}")
|
|
|
+ print(f" Full response: {full_response_path}")
|
|
|
+
|
|
|
+ # 打印详细统计(仿照ocr_by_vlm.py)
|
|
|
+ processing_info = converted_json.get('processing_info', {})
|
|
|
+ print("\n📊 API处理统计")
|
|
|
+ print(f" 原始图片: {Path(image_path).resolve().as_posix()}")
|
|
|
+ print(f" 输出路径: {json_output_path.resolve().as_posix()}")
|
|
|
+ print(f" API地址: {API_URL}")
|
|
|
+ print(f" 数字标准化: {processing_info.get('normalize_numbers', False)}")
|
|
|
+ if normalize_numbers:
|
|
|
+ print(f" 字符变化数: {processing_info.get('character_changes_count', 0)}")
|
|
|
+ print(f" 应用了标准化: {processing_info.get('changes_applied', False)}")
|
|
|
+
|
|
|
+ return {
|
|
|
+ "standard_json": str(json_output_path),
|
|
|
+ "images_dir": str(images_dir),
|
|
|
+ "markdown_file": markdown_file,
|
|
|
+ "full_response": str(full_response_path),
|
|
|
+ "converted_data": converted_json,
|
|
|
+ "processing_info": processing_info
|
|
|
+ }
|
|
|
+
|
|
|
+def batch_process_api_results(image_list: List[str],
|
|
|
+ API_URL: str,
|
|
|
+ output_base_dir: str,
|
|
|
+ normalize_numbers: bool = True) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 批量处理多个图像文件
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image_list: 图像文件路径列表
|
|
|
+ API_URL: API URL
|
|
|
+ output_base_dir: 输出基础目录
|
|
|
+ normalize_numbers: 是否标准化数字格式
|
|
|
+ """
|
|
|
+ results = []
|
|
|
+
|
|
|
+ print(f"🚀 开始批量处理 {len(image_list)} 个图像文件...")
|
|
|
+ print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
|
|
|
+
|
|
|
+ for i, image_path in enumerate(image_list):
|
|
|
+ try:
|
|
|
+ print(f"\n🔄 Processing {i+1}/{len(image_list)}: {Path(image_path).name}")
|
|
|
+
|
|
|
+ # 为每个文件创建单独的输出目录
|
|
|
+ output_dir = Path(output_base_dir) / f"result_{i+1:03d}_{Path(image_path).stem}"
|
|
|
+
|
|
|
+ result = test_ppstructurev3_api_enhanced(
|
|
|
+ image_path,
|
|
|
+ API_URL,
|
|
|
+ str(output_dir),
|
|
|
+ normalize_numbers=normalize_numbers
|
|
|
+ )
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ Error processing {image_path}: {e}")
|
|
|
+ results.append({"error": str(e), "image_path": image_path})
|
|
|
+
|
|
|
+ # 生成批量处理统计
|
|
|
+ success_count = sum(1 for r in results if 'error' not in r)
|
|
|
+ total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)
|
|
|
+
|
|
|
+ print(f"\n📊 批量处理完成统计")
|
|
|
+ print(f" 总文件数: {len(image_list)}")
|
|
|
+ print(f" 成功处理: {success_count}")
|
|
|
+ print(f" 失败数量: {len(image_list) - success_count}")
|
|
|
+ if normalize_numbers:
|
|
|
+ print(f" 总标准化字符数: {total_changes}")
|
|
|
+
|
|
|
+ return results
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import argparse
|
|
|
+
|
|
|
+ parser = argparse.ArgumentParser(description='PP-StructureV3 API客户端工具')
|
|
|
+ parser.add_argument('image_path', nargs='?', help='图片文件路径')
|
|
|
+ parser.add_argument('-u', '--url', default='http://localhost:8080/layout-parsing', help='API URL')
|
|
|
+ parser.add_argument('-o', '--output', default='./api_conversion_output', help='输出目录')
|
|
|
+ parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
|
|
|
+ parser.add_argument('--batch', help='批量处理,指定包含图像路径的文本文件')
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ normalize_numbers = not args.no_normalize
|
|
|
+
|
|
|
+ try:
|
|
|
+ if args.batch:
|
|
|
+ # 批量处理模式
|
|
|
+ with open(args.batch, 'r', encoding='utf-8') as f:
|
|
|
+ image_list = [line.strip() for line in f if line.strip()]
|
|
|
+
|
|
|
+ results = batch_process_api_results(
|
|
|
+ image_list,
|
|
|
+ args.url,
|
|
|
+ args.output,
|
|
|
+ normalize_numbers=normalize_numbers
|
|
|
+ )
|
|
|
+ print("\n🎉 批量API处理完成!")
|
|
|
+
|
|
|
+ elif args.image_path:
|
|
|
+ # 单文件处理模式
|
|
|
+ result = test_ppstructurev3_api_enhanced(
|
|
|
+ args.image_path,
|
|
|
+ args.url,
|
|
|
+ args.output,
|
|
|
+ normalize_numbers=normalize_numbers
|
|
|
+ )
|
|
|
+
|
|
|
+ # 验证转换结果
|
|
|
+ with open(result["standard_json"], 'r', encoding='utf-8') as f:
|
|
|
+ converted_data = json.load(f)
|
|
|
+
|
|
|
+ print(f"\n📋 转换后的数据包含:")
|
|
|
+ print(f" - 解析结果块数: {len(converted_data.get('parsing_res_list', []))}")
|
|
|
+ print(f" - OCR文本数: {len(converted_data.get('overall_ocr_res', {}).get('rec_texts', []))}")
|
|
|
+ print(f" - 表格数: {len(converted_data.get('table_res_list', []))}")
|
|
|
+
|
|
|
+ print("\n🎉 API测试和转换完成!")
|
|
|
+
|
|
|
+ else:
|
|
|
+ # 默认示例
|
|
|
+ image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/data_DotsOCR_Results/2023年度报告母公司/2023年度报告母公司_page_004.png"
|
|
|
+ result = test_ppstructurev3_api_enhanced(
|
|
|
+ image_path,
|
|
|
+ args.url,
|
|
|
+ args.output,
|
|
|
+ normalize_numbers=normalize_numbers
|
|
|
+ )
|
|
|
+ print("\n🎉 API测试和转换完成!")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ 处理失败: {e}")
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|