ppstructurev3_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. """PPStructureV3公共函数"""
  2. import json
  3. import traceback
  4. import warnings
  5. import base64
  6. from pathlib import Path
  7. from PIL import Image
  8. from typing import List, Dict, Any, Union
  9. import numpy as np
  10. from utils import (
  11. load_images_from_pdf,
  12. normalize_markdown_table
  13. )
  14. def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
  15. """
  16. 将PDF转换为图像文件
  17. Args:
  18. pdf_file: PDF文件路径
  19. output_dir: 输出目录
  20. dpi: 图像分辨率
  21. Returns:
  22. 生成的图像文件路径列表
  23. """
  24. pdf_path = Path(pdf_file)
  25. if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
  26. print(f"❌ Invalid PDF file: {pdf_path}")
  27. return []
  28. # 如果没有指定输出目录,使用PDF同名目录
  29. if output_dir is None:
  30. output_path = pdf_path.parent / f"{pdf_path.stem}"
  31. else:
  32. output_path = Path(output_dir) / f"{pdf_path.stem}"
  33. output_path = output_path.resolve()
  34. output_path.mkdir(parents=True, exist_ok=True)
  35. try:
  36. # 使用doc_utils中的函数加载PDF图像
  37. images = load_images_from_pdf(str(pdf_path), dpi=dpi)
  38. image_paths = []
  39. for i, image in enumerate(images):
  40. # 生成图像文件名
  41. image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
  42. image_path = output_path / image_filename
  43. # 保存图像
  44. image.save(str(image_path))
  45. image_paths.append(str(image_path))
  46. print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
  47. return image_paths
  48. except Exception as e:
  49. print(f"❌ Error converting PDF {pdf_path}: {e}")
  50. traceback.print_exc()
  51. return []
  52. def convert_pruned_result_to_json(pruned_result: Dict[str, Any],
  53. input_image_path: str,
  54. output_dir: str,
  55. filename: str,
  56. normalize_numbers: bool = True) -> tuple[str, Dict[str, Any]]:
  57. """
  58. 将API返回结果转换为标准JSON格式,并支持数字标准化
  59. """
  60. if not pruned_result:
  61. return "", {}
  62. # 构造标准格式的JSON
  63. converted_json = {
  64. "input_path": input_image_path,
  65. "page_index": None,
  66. "model_settings": pruned_result.get('model_settings', {}),
  67. "parsing_res_list": pruned_result.get('parsing_res_list', []),
  68. "doc_preprocessor_res": {
  69. "input_path": None,
  70. "page_index": None,
  71. "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}),
  72. "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0)
  73. },
  74. "layout_det_res": {
  75. "input_path": None,
  76. "page_index": None,
  77. "boxes": pruned_result.get('layout_det_res', {}).get('boxes', [])
  78. },
  79. "overall_ocr_res": {
  80. "input_path": None,
  81. "page_index": None,
  82. "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}),
  83. "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []),
  84. "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}),
  85. "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'),
  86. "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []),
  87. "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0),
  88. "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False),
  89. "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []),
  90. "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []),
  91. "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []),
  92. "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', [])
  93. },
  94. "table_res_list": pruned_result.get('table_res_list', [])
  95. }
  96. # 数字标准化处理
  97. original_json = converted_json.copy()
  98. changes_count = 0
  99. if normalize_numbers:
  100. # 1. 标准化 parsing_res_list 中的文本内容
  101. for item in converted_json.get('parsing_res_list', []):
  102. if 'block_content' in item:
  103. original_content = item['block_content']
  104. normalized_content = original_content
  105. # 根据block_label类型选择标准化方法
  106. if item.get('block_label') == 'table':
  107. normalized_content = normalize_markdown_table(original_content)
  108. # else:
  109. # normalized_content = normalize_financial_numbers(original_content)
  110. if original_content != normalized_content:
  111. item['block_content'] = normalized_content
  112. changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
  113. # 2. 标准化 table_res_list 中的HTML表格
  114. for table_item in converted_json.get('table_res_list', []):
  115. if 'pred_html' in table_item:
  116. original_html = table_item['pred_html']
  117. normalized_html = normalize_markdown_table(original_html)
  118. if original_html != normalized_html:
  119. table_item['pred_html'] = normalized_html
  120. changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
  121. # 检查是否需要修复表格一致性(这里只做统计,实际修复可能需要更复杂的逻辑)
  122. # 统计表格数量
  123. parsing_res_tables_count = 0
  124. table_res_list_count = 0
  125. if 'parsing_res_list' in converted_json:
  126. parsing_res_tables_count = len([item for item in converted_json['parsing_res_list']
  127. if 'block_label' in item and item['block_label'] == 'table'])
  128. if 'table_res_list' in converted_json:
  129. table_res_list_count = len(converted_json["table_res_list"])
  130. table_consistency_fixed = False
  131. if parsing_res_tables_count != table_res_list_count:
  132. warnings.warn(f"⚠️ Warning: {filename} Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
  133. f"but table_res_list has {table_res_list_count} tables.")
  134. table_consistency_fixed = True
  135. # 这里可以添加实际的修复逻辑,例如根据需要添加或删除表格项
  136. # 但由于缺乏具体规则,暂时只做统计和警告
  137. # 3. 标准化 overall_ocr_res 中的识别文本
  138. # ocr_res = converted_json.get('overall_ocr_res', {})
  139. # if 'rec_texts' in ocr_res:
  140. # original_texts = ocr_res['rec_texts'][:]
  141. # normalized_texts = []
  142. # for text in original_texts:
  143. # normalized_text = normalize_financial_numbers(text)
  144. # normalized_texts.append(normalized_text)
  145. # if text != normalized_text:
  146. # changes_count += len([1 for o, n in zip(text, normalized_text) if o != n])
  147. # ocr_res['rec_texts'] = normalized_texts
  148. # 添加标准化处理信息
  149. converted_json['processing_info'] = {
  150. "normalize_numbers": normalize_numbers,
  151. "changes_applied": changes_count > 0,
  152. "character_changes_count": changes_count,
  153. "parsing_res_tables_count": parsing_res_tables_count,
  154. "table_res_list_count": table_res_list_count,
  155. "table_consistency_fixed": table_consistency_fixed
  156. }
  157. # if changes_count > 0:
  158. # print(f"🔧 已标准化 {changes_count} 个字符(全角→半角)")
  159. else:
  160. converted_json['processing_info'] = {
  161. "normalize_numbers": False,
  162. "changes_applied": False,
  163. "character_changes_count": 0
  164. }
  165. # 保存JSON文件
  166. output_path = Path(output_dir).resolve()
  167. output_path.mkdir(parents=True, exist_ok=True)
  168. json_file_path = output_path / f"{filename}.json"
  169. with open(json_file_path, 'w', encoding='utf-8') as f:
  170. json.dump(converted_json, f, ensure_ascii=False, indent=2)
  171. # 如果启用了标准化且有变化,保存原始版本用于对比
  172. if normalize_numbers and changes_count > 0:
  173. original_output_path = output_path / f"{filename}_original.json"
  174. with open(original_output_path, 'w', encoding='utf-8') as f:
  175. json.dump(original_json, f, ensure_ascii=False, indent=2)
  176. return str(output_path), converted_json
  177. def save_image(image: Union[Image.Image, str, np.ndarray], output_path: str) -> str:
  178. """
  179. 保存单个图像到指定路径
  180. Args:
  181. image: 要保存的图像,可以是PIL Image对象、base64字符串或numpy数组
  182. output_path: 输出文件路径
  183. Returns:
  184. 保存的图像文件路径
  185. """
  186. try:
  187. if isinstance(image, Image.Image):
  188. image.save(output_path)
  189. elif isinstance(image, str):
  190. # 处理base64字符串
  191. img_data = base64.b64decode(image)
  192. with open(output_path, 'wb') as f:
  193. f.write(img_data)
  194. elif isinstance(image, np.ndarray):
  195. # 处理numpy数组
  196. pil_image = Image.fromarray(image)
  197. pil_image.save(output_path)
  198. else:
  199. raise ValueError(f"Unsupported image type: {type(image)}")
  200. # print(f"📷 Saved image: {output_path}")
  201. return str(output_path)
  202. except Exception as e:
  203. print(f"❌ Error saving image {output_path}: {e}")
  204. return ""
  205. def save_output_images(output_images: Dict[str, Any], output_dir: str, output_filename: str) -> Dict[str, str]:
  206. """
  207. 保存API返回的输出图像
  208. Args:
  209. output_images: 图像数组字典或PIL Image对象字典
  210. output_dir: 输出目录
  211. output_filename: 输出文件名前缀
  212. Returns:
  213. 保存的图像文件路径字典
  214. """
  215. if not output_images:
  216. return {}
  217. output_path = Path(output_dir).resolve()
  218. output_path.mkdir(parents=True, exist_ok=True)
  219. saved_images = {}
  220. for img_name, img_data in output_images.items():
  221. try:
  222. # 生成文件名
  223. img_filename = f"{output_filename}_{img_name}.jpg"
  224. img_path = output_path / img_filename
  225. save_image(img_data, str(img_path))
  226. saved_images[img_name] = str(img_path)
  227. except Exception as e:
  228. print(f"❌ Error saving image {img_name}: {e}")
  229. print(f" Image data type: {type(img_data)}")
  230. if hasattr(img_data, 'shape'):
  231. print(f" Image shape: {img_data.shape}")
  232. traceback.print_exc()
  233. return saved_images
  234. def save_markdown_content(markdown_data: Dict[str, Any], output_dir: str,
  235. filename: str, normalize_numbers: bool = True, key_text: str = 'text', key_images: str = 'images') -> str:
  236. """
  237. 保存Markdown内容,支持数字标准化
  238. """
  239. if not markdown_data:
  240. return ""
  241. output_path = Path(output_dir).resolve()
  242. output_path.mkdir(parents=True, exist_ok=True)
  243. # 保存Markdown文本
  244. markdown_text = markdown_data.get(key_text, '')
  245. # 数字标准化处理
  246. changes_count = 0
  247. if normalize_numbers and markdown_text:
  248. original_markdown_text = markdown_text
  249. markdown_text = normalize_markdown_table(markdown_text)
  250. changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
  251. # if changes_count > 0:
  252. # print(f"🔧 Markdown中已标准化 {changes_count} 个字符(全角→半角)")
  253. md_file_path = output_path / f"{filename}.md"
  254. with open(md_file_path, 'w', encoding='utf-8') as f:
  255. f.write(markdown_text)
  256. # 如果启用了标准化且有变化,保存原始版本用于对比
  257. if normalize_numbers and changes_count > 0:
  258. original_output_path = output_path / f"{filename}_original.md"
  259. with open(original_output_path, 'w', encoding='utf-8') as f:
  260. f.write(original_markdown_text)
  261. # 保存Markdown中的图像
  262. markdown_images = markdown_data.get(key_images, {})
  263. for img_path, img_data in markdown_images.items():
  264. try:
  265. full_img_path = output_path / img_path
  266. full_img_path.parent.mkdir(parents=True, exist_ok=True)
  267. save_image(img_data, str(full_img_path))
  268. except Exception as e:
  269. print(f"❌ Error saving Markdown image {img_path}: {e}")
  270. return str(md_file_path)