ppstructurev3_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. """PPStructureV3公共函数"""
  2. import json
  3. import traceback
  4. import warnings
  5. import base64
  6. from pathlib import Path
  7. from PIL import Image
  8. from typing import List, Dict, Any, Union
  9. import numpy as np
  10. from utils import (
  11. load_images_from_pdf,
  12. normalize_markdown_table,
  13. normalize_financial_numbers
  14. )
  15. def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
  16. """
  17. 将PDF转换为图像文件
  18. Args:
  19. pdf_file: PDF文件路径
  20. output_dir: 输出目录
  21. dpi: 图像分辨率
  22. Returns:
  23. 生成的图像文件路径列表
  24. """
  25. pdf_path = Path(pdf_file)
  26. if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
  27. print(f"❌ Invalid PDF file: {pdf_path}")
  28. return []
  29. # 如果没有指定输出目录,使用PDF同名目录
  30. if output_dir is None:
  31. output_path = pdf_path.parent / f"{pdf_path.stem}"
  32. else:
  33. output_path = Path(output_dir) / f"{pdf_path.stem}"
  34. output_path = output_path.resolve()
  35. output_path.mkdir(parents=True, exist_ok=True)
  36. try:
  37. # 使用doc_utils中的函数加载PDF图像
  38. images = load_images_from_pdf(str(pdf_path), dpi=dpi)
  39. image_paths = []
  40. for i, image in enumerate(images):
  41. # 生成图像文件名
  42. image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
  43. image_path = output_path / image_filename
  44. # 保存图像
  45. image.save(str(image_path))
  46. image_paths.append(str(image_path))
  47. print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
  48. return image_paths
  49. except Exception as e:
  50. print(f"❌ Error converting PDF {pdf_path}: {e}")
  51. traceback.print_exc()
  52. return []
  def _count_char_changes(original: str, normalized: str) -> int:
      """Return how many characters differ between two strings.

      Counts differing positions over the common length plus the length
      difference, so an edit that only shortens or lengthens the text still
      registers as a change (a bare ``zip`` comparison would report 0).
      """
      differing = sum(1 for o, n in zip(original, normalized) if o != n)
      return differing + abs(len(original) - len(normalized))


  def _build_standard_json(pruned_result: Dict[str, Any], input_image_path: str) -> Dict[str, Any]:
      """Re-shape an API ``pruned_result`` into the standard PP-StructureV3 result layout.

      Sub-results that are missing or ``None`` are treated as empty dicts/lists
      instead of raising ``AttributeError``.
      """
      doc_pre = pruned_result.get('doc_preprocessor_res') or {}
      layout = pruned_result.get('layout_det_res') or {}
      ocr = pruned_result.get('overall_ocr_res') or {}
      return {
          "input_path": input_image_path,
          "page_index": None,
          "model_settings": pruned_result.get('model_settings', {}),
          "parsing_res_list": pruned_result.get('parsing_res_list') or [],
          "doc_preprocessor_res": {
              "input_path": None,
              "page_index": None,
              "model_settings": doc_pre.get('model_settings', {}),
              "angle": doc_pre.get('angle', 0),
          },
          "layout_det_res": {
              "input_path": None,
              "page_index": None,
              "boxes": layout.get('boxes', []),
          },
          "overall_ocr_res": {
              "input_path": None,
              "page_index": None,
              "model_settings": ocr.get('model_settings', {}),
              "dt_polys": ocr.get('dt_polys', []),
              "text_det_params": ocr.get('text_det_params', {}),
              "text_type": ocr.get('text_type', 'general'),
              "textline_orientation_angles": ocr.get('textline_orientation_angles', []),
              "text_rec_score_thresh": ocr.get('text_rec_score_thresh', 0.0),
              "return_word_box": ocr.get('return_word_box', False),
              "rec_texts": ocr.get('rec_texts', []),
              "rec_scores": ocr.get('rec_scores', []),
              "rec_polys": ocr.get('rec_polys', []),
              "rec_boxes": ocr.get('rec_boxes', []),
          },
          "table_res_list": pruned_result.get('table_res_list') or [],
      }


  def convert_pruned_result_to_json(pruned_result: Dict[str, Any],
                                    input_image_path: str,
                                    output_dir: str,
                                    filename: str,
                                    normalize_numbers: bool = True) -> tuple[str, Dict[str, Any]]:
      """Convert an API ``pruned_result`` to the standard JSON layout and save it.

      When ``normalize_numbers`` is set, ``normalize_markdown_table`` is applied to
      the content of ``parsing_res_list`` blocks labelled ``'table'`` and to every
      ``table_res_list[*]['pred_html']``. Non-table text and OCR texts are left
      unchanged. The result is written to ``<output_dir>/<filename>.json``. If
      normalization changed anything, the un-normalized version is also written
      to ``<output_dir>/<filename>_original.json``.

      Args:
          pruned_result: Result dict returned by the API. It is not modified.
          input_image_path: Value stored under the top-level ``input_path`` key.
          output_dir: Directory for the JSON file(s); created if missing.
          filename: Base name (without extension) of the JSON file(s).
          normalize_numbers: Whether to normalize table content.

      Returns:
          ``(output_dir_path, converted_json)``. The first element is the resolved
          output *directory*, not the JSON file path. Returns ``("", {})`` when
          ``pruned_result`` is empty.
      """
      import copy  # function-scope: only needed for the defensive deep copy below

      if not pruned_result:
          return "", {}

      converted_json = _build_standard_json(pruned_result, input_image_path)
      # Un-normalized view for the "_original" file. It shares nested objects with
      # pruned_result but is never mutated below.
      original_json = converted_json
      changes_count = 0

      if normalize_numbers:
          # Normalize a deep copy. The block/table dicts belong to the caller, and a
          # shallow .copy() let in-place edits leak into pruned_result and into the
          # "_original" snapshot, which then matched the normalized output.
          converted_json = copy.deepcopy(original_json)

          # 1. Table blocks in parsing_res_list (non-table text is deliberately left as-is).
          for item in converted_json['parsing_res_list']:
              if 'block_content' not in item or item.get('block_label') != 'table':
                  continue
              before = item['block_content']
              after = normalize_markdown_table(before)
              if before != after:
                  item['block_content'] = after
                  changes_count += _count_char_changes(before, after)

          # 2. HTML tables in table_res_list.
          for table_item in converted_json['table_res_list']:
              if 'pred_html' not in table_item:
                  continue
              before = table_item['pred_html']
              after = normalize_markdown_table(before)
              if before != after:
                  table_item['pred_html'] = after
                  changes_count += _count_char_changes(before, after)

          # Table-count consistency check: statistics and a warning only.
          parsing_res_tables_count = sum(
              1 for item in converted_json['parsing_res_list'] if item.get('block_label') == 'table'
          )
          table_res_list_count = len(converted_json['table_res_list'])
          # NOTE: despite the key name, "table_consistency_fixed" only flags that a
          # mismatch was detected; no repair is performed.
          table_consistency_fixed = parsing_res_tables_count != table_res_list_count
          if table_consistency_fixed:
              warnings.warn(f"⚠️ Warning: {filename} Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
                            f"but table_res_list has {table_res_list_count} tables.")

          converted_json['processing_info'] = {
              "normalize_numbers": normalize_numbers,
              "changes_applied": changes_count > 0,
              "character_changes_count": changes_count,
              "parsing_res_tables_count": parsing_res_tables_count,
              "table_res_list_count": table_res_list_count,
              "table_consistency_fixed": table_consistency_fixed
          }
      else:
          converted_json['processing_info'] = {
              "normalize_numbers": False,
              "changes_applied": False,
              "character_changes_count": 0
          }

      output_path = Path(output_dir).resolve()
      output_path.mkdir(parents=True, exist_ok=True)
      json_file_path = output_path / f"{filename}.json"
      with open(json_file_path, 'w', encoding='utf-8') as f:
          json.dump(converted_json, f, ensure_ascii=False, indent=2)

      # Keep the un-normalized version next to it for comparison.
      if normalize_numbers and changes_count > 0:
          original_output_path = output_path / f"{filename}_original.json"
          with open(original_output_path, 'w', encoding='utf-8') as f:
              json.dump(original_json, f, ensure_ascii=False, indent=2)
      return str(output_path), converted_json
  def save_image(image: Union[Image.Image, str, np.ndarray], output_path: str) -> str:
      """
      Save a single image to ``output_path``.

      Args:
          image: A PIL ``Image``, a base64 string (raw or ``data:...;base64,`` URI),
              or a numpy array accepted by ``Image.fromarray``.
          output_path: Destination file path.

      Returns:
          ``output_path`` as a string on success, or ``""`` on any failure (the
          error is printed, not raised).
      """
      # Pillow can only write these modes as JPEG; anything else (e.g. RGBA, P)
      # raises, so such images are converted to RGB for .jpg/.jpeg targets.
      jpeg_writable_modes = {'1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'}
      try:
          if isinstance(image, str):
              # Strip a data-URI header if present. b64decode would otherwise silently
              # fold the header's letters into the decoded bytes.
              payload = image.split(',', 1)[1] if image.startswith('data:') and ',' in image else image
              with open(output_path, 'wb') as f:
                  f.write(base64.b64decode(payload))
              return str(output_path)

          if isinstance(image, np.ndarray):
              image = Image.fromarray(image)
          if not isinstance(image, Image.Image):
              raise ValueError(f"Unsupported image type: {type(image)}")

          if Path(output_path).suffix.lower() in ('.jpg', '.jpeg') and image.mode not in jpeg_writable_modes:
              image = image.convert('RGB')
          image.save(output_path)
          return str(output_path)
      except Exception as e:
          # Best-effort: callers detect failure via the empty return value.
          print(f"❌ Error saving image {output_path}: {e}")
          return ""
  def save_output_images(output_images: Dict[str, Any], output_dir: str, output_filename: str) -> Dict[str, str]:
      """
      Save the output images returned by the API as ``.jpg`` files.

      Each entry is written to ``<output_dir>/<output_filename>_<name>.jpg``.

      Args:
          output_images: Mapping of image name to image data (PIL image, base64
              string or numpy array, i.e. whatever ``save_image`` accepts).
          output_dir: Target directory; created if missing.
          output_filename: File-name prefix for every image.

      Returns:
          Mapping of image name to saved file path. Only images that were
          actually written are included.
      """
      if not output_images:
          return {}
      output_path = Path(output_dir).resolve()
      output_path.mkdir(parents=True, exist_ok=True)

      saved_images: Dict[str, str] = {}
      for img_name, img_data in output_images.items():
          img_path = output_path / f"{output_filename}_{img_name}.jpg"
          try:
              # save_image reports failure by returning "" rather than raising, so the
              # return value must be checked; otherwise failed saves were listed as saved.
              saved_path = save_image(img_data, str(img_path))
          except Exception as e:
              print(f"❌ Error saving image {img_name}: {e}")
              traceback.print_exc()
              saved_path = ""

          if saved_path:
              saved_images[img_name] = saved_path
              continue
          # Extra context about the payload that could not be saved.
          print(f" Image data type: {type(img_data)}")
          if hasattr(img_data, 'shape'):
              print(f" Image shape: {img_data.shape}")
      return saved_images
  235. def save_markdown_content(markdown_data: Dict[str, Any], output_dir: str,
  236. filename: str, normalize_numbers: bool = True,
  237. key_text: str = 'text', key_images: str = 'images',
  238. json_data: Dict[str, Any] = None) -> str:
  239. """
  240. 保存Markdown内容,支持数字标准化和表格补全
  241. """
  242. if not markdown_data and not json_data:
  243. return ""
  244. output_path = Path(output_dir).resolve()
  245. output_path.mkdir(parents=True, exist_ok=True)
  246. # 🎯 优先使用json_data生成完整内容
  247. if json_data:
  248. return save_markdown_content_enhanced(json_data, str(output_path), filename, normalize_numbers)
  249. # 原有逻辑保持不变
  250. markdown_text = markdown_data.get(key_text, '')
  251. # 数字标准化处理
  252. changes_count = 0
  253. if normalize_numbers and markdown_text:
  254. original_markdown_text = markdown_text
  255. markdown_text = normalize_markdown_table(markdown_text)
  256. changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
  257. # if changes_count > 0:
  258. # print(f"🔧 Markdown中已标准化 {changes_count} 个字符(全角→半角)")
  259. md_file_path = output_path / f"{filename}.md"
  260. with open(md_file_path, 'w', encoding='utf-8') as f:
  261. f.write(markdown_text)
  262. # 如果启用了标准化且有变化,保存原始版本用于对比
  263. if normalize_numbers and changes_count > 0:
  264. original_output_path = output_path / f"{filename}_original.md"
  265. with open(original_output_path, 'w', encoding='utf-8') as f:
  266. f.write(original_markdown_text)
  267. # 保存Markdown中的图像
  268. markdown_images = markdown_data.get(key_images, {})
  269. for img_path, img_data in markdown_images.items():
  270. try:
  271. full_img_path = output_path / img_path
  272. full_img_path.parent.mkdir(parents=True, exist_ok=True)
  273. save_image(img_data, str(full_img_path))
  274. except Exception as e:
  275. print(f"❌ Error saving Markdown image {img_path}: {e}")
  276. return str(md_file_path)
  def save_markdown_content_enhanced(json_data: Dict[str, Any], output_dir: str,
                                     filename: str, normalize_numbers: bool = True) -> str:
      """
      Build Markdown from ``parsing_res_list`` and ``table_res_list`` and save it.

      Blocks are emitted in ``parsing_res_list`` order. The N-th ``'table'`` block
      is paired positionally with ``table_res_list[N]``, and that entry's
      ``pred_html`` is used when non-empty. Otherwise the block's own content is
      used. Tables in ``table_res_list`` with no matching block are appended at
      the end. Tables are wrapped in a centered ``<div>``. When
      ``normalize_numbers`` is set, tables go through ``normalize_markdown_table``
      and other text through ``normalize_financial_numbers``.

      Returns:
          Path of the written ``<output_dir>/<filename>.md``, or ``""`` when
          ``json_data`` is empty.
      """
      if not json_data:
          return ""
      output_path = Path(output_dir).resolve()
      output_path.mkdir(parents=True, exist_ok=True)

      def _centered_table(html: str) -> str:
          """Optionally normalize table HTML and wrap it in a centered div."""
          if normalize_numbers:
              html = normalize_markdown_table(html)
          return f'<div style="text-align: center;">{html}</div>'

      parsing_res_list = json_data.get('parsing_res_list') or []
      table_res_list = json_data.get('table_res_list') or []
      markdown_content: List[str] = []
      table_index = 0  # next unmatched entry in table_res_list

      for item in parsing_res_list:
          block_label = item.get('block_label', '')
          block_content = item.get('block_content', '')
          if block_label == 'table':
              if table_index < len(table_res_list):
                  # Fall back to the block's content when pred_html is missing OR empty,
                  # matching how the remaining-tables loop below skips empty HTML.
                  html = table_res_list[table_index].get('pred_html') or block_content
                  table_index += 1
              else:
                  html = block_content
              markdown_content.append(_centered_table(html))
          else:
              if normalize_numbers:
                  block_content = normalize_financial_numbers(block_content)
              markdown_content.append(block_content)

      # Tables recognized in table_res_list with no corresponding parsing_res block.
      remaining_tables = table_res_list[table_index:]
      for table_item in remaining_tables:
          detailed_html = table_item.get('pred_html', '')
          if detailed_html:
              markdown_content.append(_centered_table(detailed_html))

      final_markdown = '\n\n'.join(markdown_content)
      md_file_path = output_path / f"{filename}.md"
      with open(md_file_path, 'w', encoding='utf-8') as f:
          f.write(final_markdown)

      print(f"📄 Enhanced Markdown saved: {md_file_path}")
      print(f" - parsing_res_list tables: {sum(1 for item in parsing_res_list if item.get('block_label') == 'table')}")
      print(f" - table_res_list tables: {len(table_res_list)}")
      print(f" - remaining tables added: {len(remaining_tables)}")
      return str(md_file_path)