# test_ppstructure_v3_client.py

import base64
import copy
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests

from utils.normalize_financial_numbers import normalize_financial_numbers, normalize_markdown_table, normalize_json_table
  7. def convert_api_result_to_json(api_result: Dict[str, Any],
  8. input_image_path: str,
  9. output_json_path: str,
  10. normalize_numbers: bool = True) -> Dict[str, Any]:
  11. """
  12. 将API返回结果转换为标准JSON格式,并支持数字标准化
  13. Args:
  14. api_result: API返回的结果
  15. input_image_path: 输入图像路径
  16. output_json_path: 输出JSON文件路径
  17. normalize_numbers: 是否标准化数字格式,默认True
  18. Returns:
  19. 转换后的JSON数据
  20. """
  21. # 获取主要数据
  22. layout_parsing_results = api_result.get('layoutParsingResults', [])
  23. data_info = api_result.get('dataInfo', {})
  24. if not layout_parsing_results:
  25. print("⚠️ Warning: No layoutParsingResults found in API response")
  26. return {}
  27. # 取第一个结果(通常只有一个)
  28. main_result = layout_parsing_results[0]
  29. pruned_result = main_result.get('prunedResult', {})
  30. # 构造标准格式的JSON
  31. converted_json = {
  32. "input_path": input_image_path,
  33. "page_index": None,
  34. "model_settings": pruned_result.get('model_settings', {}),
  35. "parsing_res_list": pruned_result.get('parsing_res_list', []),
  36. "doc_preprocessor_res": {
  37. "input_path": None,
  38. "page_index": None,
  39. "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}),
  40. "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0)
  41. },
  42. "layout_det_res": {
  43. "input_path": None,
  44. "page_index": None,
  45. "boxes": pruned_result.get('layout_det_res', {}).get('boxes', [])
  46. },
  47. "overall_ocr_res": {
  48. "input_path": None,
  49. "page_index": None,
  50. "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}),
  51. "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []),
  52. "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}),
  53. "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'),
  54. "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []),
  55. "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0),
  56. "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False),
  57. "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []),
  58. "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []),
  59. "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []),
  60. "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', [])
  61. },
  62. "table_res_list": pruned_result.get('table_res_list', [])
  63. }
  64. # 数字标准化处理
  65. original_json = converted_json.copy()
  66. changes_count = 0
  67. if normalize_numbers:
  68. print("🔧 正在标准化数字格式...")
  69. # 1. 标准化 parsing_res_list 中的文本内容
  70. for item in converted_json.get('parsing_res_list', []):
  71. if 'block_content' in item:
  72. original_content = item['block_content']
  73. # 根据block_label类型选择标准化方法
  74. if item.get('block_label') == 'table':
  75. # 表格内容使用表格专用标准化
  76. normalized_content = normalize_markdown_table(original_content)
  77. else:
  78. # 普通文本使用通用标准化
  79. normalized_content = normalize_financial_numbers(original_content)
  80. if original_content != normalized_content:
  81. item['block_content'] = normalized_content
  82. changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
  83. # 2. 标准化 table_res_list 中的HTML表格
  84. for table_item in converted_json.get('table_res_list', []):
  85. if 'pred_html' in table_item:
  86. original_html = table_item['pred_html']
  87. normalized_html = normalize_markdown_table(original_html)
  88. if original_html != normalized_html:
  89. table_item['pred_html'] = normalized_html
  90. changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
  91. # 3. 标准化 overall_ocr_res 中的识别文本
  92. ocr_res = converted_json.get('overall_ocr_res', {})
  93. if 'rec_texts' in ocr_res:
  94. original_texts = ocr_res['rec_texts'][:]
  95. normalized_texts = []
  96. for text in original_texts:
  97. normalized_text = normalize_financial_numbers(text)
  98. normalized_texts.append(normalized_text)
  99. if text != normalized_text:
  100. changes_count += len([1 for o, n in zip(text, normalized_text) if o != n])
  101. ocr_res['rec_texts'] = normalized_texts
  102. # 添加标准化处理信息
  103. converted_json['processing_info'] = {
  104. "normalize_numbers": normalize_numbers,
  105. "changes_applied": changes_count > 0,
  106. "character_changes_count": changes_count
  107. }
  108. if changes_count > 0:
  109. print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
  110. else:
  111. print("ℹ️ 无需标准化(已是标准格式)")
  112. else:
  113. converted_json['processing_info'] = {
  114. "normalize_numbers": False,
  115. "changes_applied": False,
  116. "character_changes_count": 0
  117. }
  118. # 保存JSON文件
  119. output_path = Path(output_json_path)
  120. output_path.parent.mkdir(parents=True, exist_ok=True)
  121. with open(output_path, 'w', encoding='utf-8') as f:
  122. json.dump(converted_json, f, ensure_ascii=False, indent=4)
  123. # 如果启用了标准化且有变化,保存原始版本用于对比
  124. if normalize_numbers and changes_count > 0:
  125. original_output_path = output_path.parent / f"{output_path.stem}_original.json"
  126. with open(original_output_path, 'w', encoding='utf-8') as f:
  127. json.dump(original_json, f, ensure_ascii=False, indent=4)
  128. print(f"📄 原始JSON已保存到: {original_output_path}")
  129. print(f"✅ Converted JSON saved to: {output_path}")
  130. return converted_json
  131. def save_output_images(api_result: Dict[str, Any], output_dir: str) -> Dict[str, str]:
  132. """
  133. 保存API返回的输出图像
  134. Args:
  135. api_result: API返回的结果
  136. output_dir: 输出目录
  137. Returns:
  138. 保存的图像文件路径字典
  139. """
  140. layout_parsing_results = api_result.get('layoutParsingResults', [])
  141. if not layout_parsing_results:
  142. return {}
  143. main_result = layout_parsing_results[0]
  144. output_images = main_result.get('outputImages', {})
  145. output_dir = Path(output_dir)
  146. output_dir.mkdir(parents=True, exist_ok=True)
  147. saved_images = {}
  148. for img_name, img_base64 in output_images.items():
  149. try:
  150. # 解码base64图像
  151. img_data = base64.b64decode(img_base64)
  152. # 生成文件名
  153. img_filename = f"{img_name}.jpg"
  154. img_path = output_dir / img_filename
  155. # 保存图像
  156. with open(img_path, 'wb') as f:
  157. f.write(img_data)
  158. saved_images[img_name] = str(img_path)
  159. print(f"📷 Saved image: {img_path}")
  160. except Exception as e:
  161. print(f"❌ Error saving image {img_name}: {e}")
  162. return saved_images
  163. def save_markdown_content(api_result: Dict[str, Any], output_dir: str, normalize_numbers: bool = True) -> str:
  164. """
  165. 保存Markdown内容和相关图像,支持数字标准化
  166. Args:
  167. api_result: API返回的结果
  168. output_dir: 输出目录
  169. normalize_numbers: 是否标准化数字格式
  170. Returns:
  171. Markdown文件路径
  172. """
  173. layout_parsing_results = api_result.get('layoutParsingResults', [])
  174. if not layout_parsing_results:
  175. return ""
  176. main_result = layout_parsing_results[0]
  177. markdown_data = main_result.get('markdown', {})
  178. output_dir = Path(output_dir)
  179. output_dir.mkdir(parents=True, exist_ok=True)
  180. # 保存Markdown文本
  181. markdown_text = markdown_data.get('text', '')
  182. original_markdown_text = markdown_text
  183. # 数字标准化处理
  184. if normalize_numbers:
  185. print("🔧 正在标准化Markdown中的数字格式...")
  186. markdown_text = normalize_markdown_table(markdown_text)
  187. changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
  188. if changes_count > 0:
  189. print(f"✅ Markdown中已标准化 {changes_count} 个字符(全角→半角)")
  190. else:
  191. print("ℹ️ Markdown无需标准化(已是标准格式)")
  192. md_file_path = output_dir / 'document.md'
  193. with open(md_file_path, 'w', encoding='utf-8') as f:
  194. f.write(markdown_text)
  195. # 如果启用了标准化且有变化,保存原始版本
  196. if normalize_numbers and markdown_text != original_markdown_text:
  197. original_md_path = output_dir / 'document_original.md'
  198. with open(original_md_path, 'w', encoding='utf-8') as f:
  199. f.write(original_markdown_text)
  200. print(f"📄 原始Markdown已保存到: {original_md_path}")
  201. print(f"📝 Saved Markdown: {md_file_path}")
  202. # 保存Markdown中的图像
  203. markdown_images = markdown_data.get('images', {})
  204. for img_path, img_base64 in markdown_images.items():
  205. try:
  206. img_data = base64.b64decode(img_base64)
  207. full_img_path = output_dir / img_path
  208. full_img_path.parent.mkdir(parents=True, exist_ok=True)
  209. with open(full_img_path, 'wb') as f:
  210. f.write(img_data)
  211. print(f"🖼️ Saved Markdown image: {full_img_path}")
  212. except Exception as e:
  213. print(f"❌ Error saving Markdown image {img_path}: {e}")
  214. return str(md_file_path)
  215. def test_ppstructurev3_api_enhanced(image_path: str,
  216. API_URL: str,
  217. output_dir: str = "./api_output",
  218. normalize_numbers: bool = True) -> Dict[str, Any]:
  219. """
  220. 增强版的PP-StructureV3 API测试,保存为标准格式并支持数字标准化
  221. Args:
  222. image_path: 输入图像路径
  223. API_URL: API URL
  224. output_dir: 输出目录
  225. normalize_numbers: 是否标准化数字格式,默认True
  226. """
  227. # 对本地图像进行Base64编码
  228. with open(image_path, "rb") as file:
  229. image_bytes = file.read()
  230. image_data = base64.b64encode(image_bytes).decode("ascii")
  231. payload = {
  232. "file": image_data,
  233. "fileType": 1,
  234. }
  235. # 调用API
  236. print(f"🚀 Calling API: {API_URL}")
  237. print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
  238. response = requests.post(API_URL, json=payload)
  239. # 处理接口返回数据
  240. assert response.status_code == 200
  241. api_result = response.json()["result"]
  242. # 创建输出目录
  243. output_path = Path(output_dir)
  244. output_path.mkdir(parents=True, exist_ok=True)
  245. # 获取输入文件的基本名称
  246. input_name = Path(image_path).stem
  247. # 1. 转换并保存标准JSON格式(包含数字标准化)
  248. json_output_path = output_path / f"{input_name}.json"
  249. converted_json = convert_api_result_to_json(
  250. api_result,
  251. image_path,
  252. str(json_output_path),
  253. normalize_numbers=normalize_numbers
  254. )
  255. # 2. 保存输出图像
  256. images_dir = output_path / f"{input_name}_images"
  257. saved_images = save_output_images(api_result, str(images_dir))
  258. # 3. 保存Markdown内容(包含数字标准化)
  259. markdown_dir = output_path / f"{input_name}_markdown"
  260. markdown_file = save_markdown_content(api_result, str(markdown_dir), normalize_numbers=normalize_numbers)
  261. # 4. 保存完整的API响应(用于调试)
  262. full_response_path = output_path / f"{input_name}_full_response.json"
  263. with open(full_response_path, 'w', encoding='utf-8') as f:
  264. json.dump(api_result, f, ensure_ascii=False, indent=2)
  265. print(f"📊 Processing completed!")
  266. print(f" Standard JSON: {json_output_path}")
  267. print(f" Images directory: {images_dir}")
  268. print(f" Markdown file: {markdown_file}")
  269. print(f" Full response: {full_response_path}")
  270. # 打印详细统计(仿照ocr_by_vlm.py)
  271. processing_info = converted_json.get('processing_info', {})
  272. print("\n📊 API处理统计")
  273. print(f" 原始图片: {Path(image_path).resolve().as_posix()}")
  274. print(f" 输出路径: {json_output_path.resolve().as_posix()}")
  275. print(f" API地址: {API_URL}")
  276. print(f" 数字标准化: {processing_info.get('normalize_numbers', False)}")
  277. if normalize_numbers:
  278. print(f" 字符变化数: {processing_info.get('character_changes_count', 0)}")
  279. print(f" 应用了标准化: {processing_info.get('changes_applied', False)}")
  280. return {
  281. "standard_json": str(json_output_path),
  282. "images_dir": str(images_dir),
  283. "markdown_file": markdown_file,
  284. "full_response": str(full_response_path),
  285. "converted_data": converted_json,
  286. "processing_info": processing_info
  287. }
  288. def batch_process_api_results(image_list: List[str],
  289. API_URL: str,
  290. output_base_dir: str,
  291. normalize_numbers: bool = True) -> List[Dict[str, Any]]:
  292. """
  293. 批量处理多个图像文件
  294. Args:
  295. image_list: 图像文件路径列表
  296. API_URL: API URL
  297. output_base_dir: 输出基础目录
  298. normalize_numbers: 是否标准化数字格式
  299. """
  300. results = []
  301. print(f"🚀 开始批量处理 {len(image_list)} 个图像文件...")
  302. print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
  303. for i, image_path in enumerate(image_list):
  304. try:
  305. print(f"\n🔄 Processing {i+1}/{len(image_list)}: {Path(image_path).name}")
  306. # 为每个文件创建单独的输出目录
  307. output_dir = Path(output_base_dir) / f"result_{i+1:03d}_{Path(image_path).stem}"
  308. result = test_ppstructurev3_api_enhanced(
  309. image_path,
  310. API_URL,
  311. str(output_dir),
  312. normalize_numbers=normalize_numbers
  313. )
  314. results.append(result)
  315. except Exception as e:
  316. print(f"❌ Error processing {image_path}: {e}")
  317. results.append({"error": str(e), "image_path": image_path})
  318. # 生成批量处理统计
  319. success_count = sum(1 for r in results if 'error' not in r)
  320. total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)
  321. print(f"\n📊 批量处理完成统计")
  322. print(f" 总文件数: {len(image_list)}")
  323. print(f" 成功处理: {success_count}")
  324. print(f" 失败数量: {len(image_list) - success_count}")
  325. if normalize_numbers:
  326. print(f" 总标准化字符数: {total_changes}")
  327. return results
  328. if __name__ == "__main__":
  329. import argparse
  330. parser = argparse.ArgumentParser(description='PP-StructureV3 API客户端工具')
  331. parser.add_argument('image_path', nargs='?', help='图片文件路径')
  332. parser.add_argument('-u', '--url', default='http://localhost:8080/layout-parsing', help='API URL')
  333. parser.add_argument('-o', '--output', default='./api_conversion_output', help='输出目录')
  334. parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
  335. parser.add_argument('--batch', help='批量处理,指定包含图像路径的文本文件')
  336. args = parser.parse_args()
  337. normalize_numbers = not args.no_normalize
  338. try:
  339. if args.batch:
  340. # 批量处理模式
  341. with open(args.batch, 'r', encoding='utf-8') as f:
  342. image_list = [line.strip() for line in f if line.strip()]
  343. results = batch_process_api_results(
  344. image_list,
  345. args.url,
  346. args.output,
  347. normalize_numbers=normalize_numbers
  348. )
  349. print("\n🎉 批量API处理完成!")
  350. elif args.image_path:
  351. # 单文件处理模式
  352. result = test_ppstructurev3_api_enhanced(
  353. args.image_path,
  354. args.url,
  355. args.output,
  356. normalize_numbers=normalize_numbers
  357. )
  358. # 验证转换结果
  359. with open(result["standard_json"], 'r', encoding='utf-8') as f:
  360. converted_data = json.load(f)
  361. print(f"\n📋 转换后的数据包含:")
  362. print(f" - 解析结果块数: {len(converted_data.get('parsing_res_list', []))}")
  363. print(f" - OCR文本数: {len(converted_data.get('overall_ocr_res', {}).get('rec_texts', []))}")
  364. print(f" - 表格数: {len(converted_data.get('table_res_list', []))}")
  365. print("\n🎉 API测试和转换完成!")
  366. else:
  367. # 默认示例
  368. image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/data_DotsOCR_Results/2023年度报告母公司/2023年度报告母公司_page_004.png"
  369. result = test_ppstructurev3_api_enhanced(
  370. image_path,
  371. args.url,
  372. args.output,
  373. normalize_numbers=normalize_numbers
  374. )
  375. print("\n🎉 API测试和转换完成!")
  376. except Exception as e:
  377. print(f"❌ 处理失败: {e}")
  378. import traceback
  379. traceback.print_exc()