utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """PaddleX 公共工具函数"""
  2. import json
  3. import traceback
  4. import warnings
  5. import base64
  6. from pathlib import Path
  7. from PIL import Image
  8. from typing import List, Dict, Any, Union
  9. import numpy as np
  10. # 导入 ocr_utils
  11. import sys
  12. ocr_platform_root = Path(__file__).parents[2]
  13. if str(ocr_platform_root) not in sys.path:
  14. sys.path.insert(0, str(ocr_platform_root))
  15. from ocr_utils import (
  16. normalize_markdown_table,
  17. normalize_financial_numbers
  18. )
  19. # 注意:load_images_from_pdf 不再需要,因为 PDF 转图片由 ocr_utils.get_input_files() 统一处理
  20. def convert_pruned_result_to_json(pruned_result: Dict[str, Any],
  21. input_image_path: str,
  22. output_dir: str,
  23. filename: str,
  24. normalize_numbers: bool = True) -> tuple[str, Dict[str, Any]]:
  25. """
  26. 将API返回结果转换为标准JSON格式,并支持数字标准化
  27. """
  28. if not pruned_result:
  29. return "", {}
  30. # 构造标准格式的JSON
  31. converted_json = {
  32. "input_path": input_image_path,
  33. "page_index": None,
  34. "model_settings": pruned_result.get('model_settings', {}),
  35. "parsing_res_list": pruned_result.get('parsing_res_list', []),
  36. "doc_preprocessor_res": {
  37. "input_path": None,
  38. "page_index": None,
  39. "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}),
  40. "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0)
  41. },
  42. "layout_det_res": {
  43. "input_path": None,
  44. "page_index": None,
  45. "boxes": pruned_result.get('layout_det_res', {}).get('boxes', [])
  46. },
  47. "overall_ocr_res": {
  48. "input_path": None,
  49. "page_index": None,
  50. "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}),
  51. "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []),
  52. "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}),
  53. "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'),
  54. "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []),
  55. "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0),
  56. "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False),
  57. "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []),
  58. "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []),
  59. "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []),
  60. "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', [])
  61. },
  62. "table_res_list": pruned_result.get('table_res_list', [])
  63. }
  64. # 数字标准化处理
  65. original_json = converted_json.copy()
  66. changes_count = 0
  67. if normalize_numbers:
  68. # 1. 标准化 parsing_res_list 中的文本内容
  69. for item in converted_json.get('parsing_res_list', []):
  70. if 'block_content' in item:
  71. original_content = item['block_content']
  72. normalized_content = original_content
  73. # 根据block_label类型选择标准化方法
  74. if item.get('block_label') == 'table':
  75. normalized_content = normalize_markdown_table(original_content)
  76. if original_content != normalized_content:
  77. item['block_content'] = normalized_content
  78. changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
  79. # 2. 标准化 table_res_list 中的HTML表格
  80. for table_item in converted_json.get('table_res_list', []):
  81. if 'pred_html' in table_item:
  82. original_html = table_item['pred_html']
  83. normalized_html = normalize_markdown_table(original_html)
  84. if original_html != normalized_html:
  85. table_item['pred_html'] = normalized_html
  86. changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
  87. # 统计表格数量
  88. parsing_res_tables_count = 0
  89. table_res_list_count = 0
  90. if 'parsing_res_list' in converted_json:
  91. parsing_res_tables_count = len([item for item in converted_json['parsing_res_list']
  92. if 'block_label' in item and item['block_label'] == 'table'])
  93. if 'table_res_list' in converted_json:
  94. table_res_list_count = len(converted_json["table_res_list"])
  95. table_consistency_fixed = False
  96. if parsing_res_tables_count != table_res_list_count:
  97. warnings.warn(f"⚠️ Warning: {filename} Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
  98. f"but table_res_list has {table_res_list_count} tables.")
  99. table_consistency_fixed = True
  100. # 添加标准化处理信息
  101. converted_json['processing_info'] = {
  102. "normalize_numbers": normalize_numbers,
  103. "changes_applied": changes_count > 0,
  104. "character_changes_count": changes_count,
  105. "parsing_res_tables_count": parsing_res_tables_count,
  106. "table_res_list_count": table_res_list_count,
  107. "table_consistency_fixed": table_consistency_fixed
  108. }
  109. else:
  110. converted_json['processing_info'] = {
  111. "normalize_numbers": False,
  112. "changes_applied": False,
  113. "character_changes_count": 0
  114. }
  115. # 保存JSON文件
  116. output_path = Path(output_dir).resolve()
  117. output_path.mkdir(parents=True, exist_ok=True)
  118. json_file_path = output_path / f"{filename}.json"
  119. with open(json_file_path, 'w', encoding='utf-8') as f:
  120. json.dump(converted_json, f, ensure_ascii=False, indent=2)
  121. # 如果启用了标准化且有变化,保存原始版本用于对比
  122. if normalize_numbers and changes_count > 0:
  123. original_output_path = output_path / f"{filename}_original.json"
  124. with open(original_output_path, 'w', encoding='utf-8') as f:
  125. json.dump(original_json, f, ensure_ascii=False, indent=2)
  126. return str(output_path), converted_json
  127. def save_image(image: Union[Image.Image, str, np.ndarray], output_path: str) -> str:
  128. """
  129. 保存单个图像到指定路径
  130. Args:
  131. image: 要保存的图像,可以是PIL Image对象、base64字符串或numpy数组
  132. output_path: 输出文件路径
  133. Returns:
  134. 保存的图像文件路径
  135. """
  136. try:
  137. if isinstance(image, Image.Image):
  138. image.save(output_path)
  139. elif isinstance(image, str):
  140. # 处理base64字符串
  141. img_data = base64.b64decode(image)
  142. with open(output_path, 'wb') as f:
  143. f.write(img_data)
  144. elif isinstance(image, np.ndarray):
  145. # 处理numpy数组
  146. pil_image = Image.fromarray(image)
  147. pil_image.save(output_path)
  148. else:
  149. raise ValueError(f"Unsupported image type: {type(image)}")
  150. return str(output_path)
  151. except Exception as e:
  152. print(f"❌ Error saving image {output_path}: {e}")
  153. return ""
  154. def save_output_images(output_images: Dict[str, Any], output_dir: str, output_filename: str) -> Dict[str, str]:
  155. """
  156. 保存API返回的输出图像
  157. Args:
  158. output_images: 图像数组字典或PIL Image对象字典
  159. output_dir: 输出目录
  160. output_filename: 输出文件名前缀
  161. Returns:
  162. 保存的图像文件路径字典
  163. """
  164. if not output_images:
  165. return {}
  166. output_path = Path(output_dir).resolve()
  167. output_path.mkdir(parents=True, exist_ok=True)
  168. saved_images = {}
  169. for img_name, img_data in output_images.items():
  170. try:
  171. # 生成文件名
  172. img_filename = f"{output_filename}_{img_name}.jpg"
  173. img_path = output_path / img_filename
  174. save_image(img_data, str(img_path))
  175. saved_images[img_name] = str(img_path)
  176. except Exception as e:
  177. print(f"❌ Error saving image {img_name}: {e}")
  178. print(f" Image data type: {type(img_data)}")
  179. if hasattr(img_data, 'shape'):
  180. print(f" Image shape: {img_data.shape}")
  181. traceback.print_exc()
  182. return saved_images
  183. def save_markdown_content(markdown_data: Dict[str, Any], output_dir: str,
  184. filename: str, normalize_numbers: bool = True,
  185. key_text: str = 'text', key_images: str = 'images',
  186. json_data: Dict[str, Any] = None) -> str:
  187. """
  188. 保存Markdown内容,支持数字标准化和表格补全
  189. """
  190. if not markdown_data and not json_data:
  191. return ""
  192. output_path = Path(output_dir).resolve()
  193. output_path.mkdir(parents=True, exist_ok=True)
  194. # 🎯 优先使用json_data生成完整内容
  195. if json_data:
  196. return save_markdown_content_enhanced(json_data, str(output_path), filename, normalize_numbers)
  197. # 原有逻辑保持不变
  198. markdown_text = markdown_data.get(key_text, '')
  199. # 数字标准化处理
  200. changes_count = 0
  201. if normalize_numbers and markdown_text:
  202. original_markdown_text = markdown_text
  203. markdown_text = normalize_markdown_table(markdown_text)
  204. changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
  205. md_file_path = output_path / f"{filename}.md"
  206. with open(md_file_path, 'w', encoding='utf-8') as f:
  207. f.write(markdown_text)
  208. # 如果启用了标准化且有变化,保存原始版本用于对比
  209. if normalize_numbers and changes_count > 0:
  210. original_output_path = output_path / f"{filename}_original.md"
  211. with open(original_output_path, 'w', encoding='utf-8') as f:
  212. f.write(original_markdown_text)
  213. # 保存Markdown中的图像
  214. markdown_images = markdown_data.get(key_images, {})
  215. for img_path, img_data in markdown_images.items():
  216. try:
  217. full_img_path = output_path / img_path
  218. full_img_path.parent.mkdir(parents=True, exist_ok=True)
  219. save_image(img_data, str(full_img_path))
  220. except Exception as e:
  221. print(f"❌ Error saving Markdown image {img_path}: {e}")
  222. return str(md_file_path)
  223. def save_markdown_content_enhanced(json_data: Dict[str, Any], output_dir: str,
  224. filename: str, normalize_numbers: bool = True) -> str:
  225. """
  226. 增强版Markdown内容保存,同时处理parsing_res_list和table_res_list
  227. """
  228. if not json_data:
  229. return ""
  230. output_path = Path(output_dir).resolve()
  231. output_path.mkdir(parents=True, exist_ok=True)
  232. markdown_content = []
  233. # 处理 parsing_res_list
  234. parsing_res_list = json_data.get('parsing_res_list', [])
  235. table_res_list = json_data.get('table_res_list', [])
  236. table_index = 0 # 用于匹配table_res_list中的表格
  237. for item in parsing_res_list:
  238. block_label = item.get('block_label', '')
  239. block_content = item.get('block_content', '')
  240. if block_label == 'table':
  241. # 如果是表格,优先使用table_res_list中的详细HTML
  242. if table_index < len(table_res_list):
  243. detailed_html = table_res_list[table_index].get('pred_html', block_content)
  244. if normalize_numbers:
  245. detailed_html = normalize_markdown_table(detailed_html)
  246. # 转换为居中显示的HTML
  247. markdown_content.append(f'<div style="text-align: center;">{detailed_html}</div>')
  248. table_index += 1
  249. else:
  250. # 如果table_res_list中没有对应项,使用parsing_res_list中的内容
  251. if normalize_numbers:
  252. block_content = normalize_markdown_table(block_content)
  253. markdown_content.append(f'<div style="text-align: center;">{block_content}</div>')
  254. else:
  255. # 非表格内容直接添加
  256. if normalize_numbers:
  257. block_content = normalize_financial_numbers(block_content)
  258. markdown_content.append(block_content)
  259. # 🎯 关键修复:处理剩余的table_res_list项目
  260. # 如果table_res_list中还有未处理的表格(比parsing_res_list中的表格多)
  261. remaining_tables = table_res_list[table_index:]
  262. for table_item in remaining_tables:
  263. detailed_html = table_item.get('pred_html', '')
  264. if detailed_html:
  265. if normalize_numbers:
  266. detailed_html = normalize_markdown_table(detailed_html)
  267. markdown_content.append(f'<div style="text-align: center;">{detailed_html}</div>')
  268. # 合并所有内容
  269. final_markdown = '\n\n'.join(markdown_content)
  270. # 保存文件
  271. md_file_path = output_path / f"{filename}.md"
  272. with open(md_file_path, 'w', encoding='utf-8') as f:
  273. f.write(final_markdown)
  274. print(f"📄 Enhanced Markdown saved: {md_file_path}")
  275. print(f" - parsing_res_list tables: {sum(1 for item in parsing_res_list if item.get('block_label') == 'table')}")
  276. print(f" - table_res_list tables: {len(table_res_list)}")
  277. print(f" - remaining tables added: {len(remaining_tables)}")
  278. return str(md_file_path)