# pdf2json_infer.py
import io
import json
import os
import sys
from pathlib import Path
from typing import Tuple

import boto3
import numpy as np
from botocore.config import Config
from loguru import logger
from tqdm import tqdm

from magic_pdf.layout.bbox_sort import bbox_sort, CONTENT_IDX, CONTENT_TYPE_IDX
from magic_pdf.libs import fitz
from magic_pdf.libs import parse_aws_param, parse_bucket_key, read_file, join_path
from magic_pdf.post_proc.detect_para import process_blocks_per_page
from magic_pdf.pre_proc import parse_images  # figure bboxes
from magic_pdf.pre_proc import parse_equations  # equation bboxes
from magic_pdf.pre_proc.detect_tables import parse_tables  # table bboxes
from validation import cal_edit_distance, format_gt_bbox, label_match, detect_val

# Historical/debug lines kept for reference:
# sys.path.insert(0, "/mnt/petrelfs/ouyanglinke/code-clean/")
# print(sys.path)
# from pdf2text_recogFigure_20231107 import parse_images
# from pdf2text_recogTable_20231107 import parse_tables
# from pdf2text_recogEquation_20231108 import parse_equations
# from pdf2text_recogTitle_20231113 import parse_titles
# from pdf2text_recogPara import parse_blocks_per_page
# from bbox_sort import bbox_sort, CONTENT_IDX, CONTENT_TYPE_IDX
# from pdf2text_recogFootnote import parse_footnotes
  27. def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str, s3_profile: str):
  28. """
  29. 从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
  30. save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
  31. """
  32. # 拼接路径
  33. image_save_path = join_path(save_parent_path, f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}.jpg")
  34. try:
  35. # 将坐标转换为fitz.Rect对象
  36. rect = fitz.Rect(*bbox)
  37. # 配置缩放倍数为3倍
  38. zoom = fitz.Matrix(3, 3)
  39. # 截取图片
  40. pix = page.get_pixmap(clip=rect, matrix=zoom)
  41. # 打印图片文件名
  42. # print(f"Saved {image_save_path}")
  43. if image_save_path.startswith("s3://"):
  44. ak, sk, end_point, addressing_style = parse_aws_param(s3_profile)
  45. cli = boto3.client(service_name="s3", aws_access_key_id=ak, aws_secret_access_key=sk, endpoint_url=end_point,
  46. config=Config(s3={'addressing_style': addressing_style}))
  47. bucket_name, bucket_key = parse_bucket_key(image_save_path)
  48. # 将字节流上传到s3
  49. cli.upload_fileobj(pix.tobytes(output='jpeg', jpg_quality=95), bucket_name, bucket_key)
  50. else:
  51. # 保存图片到本地
  52. # 先检查一下image_save_path的父目录是否存在,如果不存在,就创建
  53. parent_dir = os.path.dirname(image_save_path)
  54. if not os.path.exists(parent_dir):
  55. os.makedirs(parent_dir)
  56. pix.save(image_save_path, jpg_quality=95)
  57. # 为了直接能在markdown里看,这里把地址改为相对于mardown的地址
  58. pth = Path(image_save_path)
  59. image_save_path = f"{pth.parent.name}/{pth.name}"
  60. return image_save_path
  61. except Exception as e:
  62. logger.exception(e)
  63. return image_save_path
  64. def get_images_by_bboxes(book_name:str, page_num:int, page: fitz.Page, save_path:str, s3_profile:str, image_bboxes:list, table_bboxes:list, equation_inline_bboxes:list, equation_interline_bboxes:list) -> dict:
  65. """
  66. 返回一个dict, key为bbox, 值是图片地址
  67. """
  68. ret = {}
  69. # 图片的保存路径组成是这样的: {s3_or_local_path}/{book_name}/{images|tables|equations}/{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg
  70. image_save_path = join_path(save_path, book_name, "images")
  71. table_save_path = join_path(save_path, book_name, "tables")
  72. equation_inline_save_path = join_path(save_path, book_name, "equations_inline")
  73. equation_interline_save_path = join_path(save_path, book_name, "equation_interline")
  74. for bbox in image_bboxes:
  75. image_path = cut_image(bbox, page_num, page, image_save_path, s3_profile)
  76. ret[bbox] = (image_path, "image") # 第二个元素是"image",表示是图片
  77. for bbox in table_bboxes:
  78. image_path = cut_image(bbox, page_num, page, table_save_path, s3_profile)
  79. ret[bbox] = (image_path, "table")
  80. # 对公式目前只截图,不返回
  81. for bbox in equation_inline_bboxes:
  82. cut_image(bbox, page_num, page, equation_inline_save_path, s3_profile)
  83. for bbox in equation_interline_bboxes:
  84. cut_image(bbox, page_num, page, equation_interline_save_path, s3_profile)
  85. return ret
  86. def reformat_bboxes(images_box_path_dict:list, paras_dict:dict):
  87. """
  88. 把bbox重新组装成一个list,每个元素[x0, y0, x1, y1, block_content, idx_x, idx_y], 初始时候idx_x, idx_y都是None. 对于图片、公式来说,block_content是图片的地址, 对于段落来说,block_content是段落的内容
  89. """
  90. all_bboxes = []
  91. for bbox, image_info in images_box_path_dict.items():
  92. all_bboxes.append([bbox[0], bbox[1], bbox[2], bbox[3], image_info, None, None, 'image'])
  93. paras_dict = paras_dict[f"page_{paras_dict['page_id']}"]
  94. for block_id, kvpair in paras_dict.items():
  95. bbox = kvpair['bbox']
  96. content = kvpair
  97. all_bboxes.append([bbox[0], bbox[1], bbox[2], bbox[3], content, None, None, 'text'])
  98. return all_bboxes
  99. def concat2markdown(all_bboxes:list):
  100. """
  101. 对排序后的bboxes拼接内容
  102. """
  103. content_md = ""
  104. for box in all_bboxes:
  105. content_type = box[CONTENT_TYPE_IDX]
  106. if content_type == 'image':
  107. image_type = box[CONTENT_IDX][1]
  108. image_path = box[CONTENT_IDX][0]
  109. content_md += f"![{image_type}]({image_path})"
  110. content_md += "\n\n"
  111. elif content_type == 'text': # 组装文本
  112. paras = box[CONTENT_IDX]['paras']
  113. text_content = ""
  114. for para_id, para in paras.items():# 拼装内部的段落文本
  115. text_content += para['text']
  116. text_content += "\n\n"
  117. content_md += text_content
  118. else:
  119. raise Exception(f"ERROR: {content_type} is not supported!")
  120. return content_md
def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path:str, pdf_model_profile:str, save_path: str, page_num: int):
    """Parse one page of a PDF into a flat list of content bboxes.

    Detects figures/tables/equations from the model-output JSON, crops them
    to images, extracts text paragraphs, and returns a list whose elements
    are [x0, y0, x1, y1, block_content, idx_x, idx_y, type].

    NOTE(review): on any exception this logs and falls through, implicitly
    returning None -- callers that iterate the result will raise; confirm
    that is intended.
    """
    pth = Path(s3_pdf_path)
    book_name = pth.name
    #book_name = "".join(os.path.basename(s3_pdf_path).split(".")[0:-1])
    res_dir_path = None
    exclude_bboxes = []
    # text_content_save_path = f"{save_path}/{book_name}/book.md"
    # metadata_save_path = f"{save_path}/{book_name}/metadata.json"
    try:
        pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
        pdf_docs = fitz.open("pdf", pdf_bytes)
        page_id = page_num - 1
        page = pdf_docs[page_id]  # the validation set only needs this one page
        model_output_json = join_path(pdf_model_path, f"page_{page_num}.json")  # model output pages are numbered from 1
        json_from_docx = read_file(model_output_json, pdf_model_profile)  # TODO rename: "docx" is misleading for this reader
        json_from_docx_obj = json.loads(json_from_docx)
        # Detect figures.
        image_bboxes = parse_images(page_id, page, json_from_docx_obj)
        # Detect tables.
        table_bboxes = parse_tables(page_id, page, json_from_docx_obj)
        # Detect equations (interline and inline).
        equations_interline_bboxes, equations_inline_bboxes = parse_equations(page_id, page, json_from_docx_obj)
        # # Detect titles.
        # title_bboxs = parse_titles(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Detect page headers.
        # header_bboxs = parse_headers(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Detect page numbers.
        # pageNo_bboxs = parse_pageNos(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Detect footnotes.
        # footnote_bboxs = parse_footnotes(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Detect page footers.
        # footer_bboxs = parse_footers(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Evaluate whether the layout is simple/regular.
        # isSimpleLayout_flag, fullColumn_cnt, subColumn_cnt, curPage_loss = evaluate_pdf_layout(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # Crop figures/tables/equations to images; only figure and table
        # crops come back as content paths.
        images_box_path_dict = get_images_by_bboxes(book_name, page_id, page, save_path, s3_pdf_profile, image_bboxes, table_bboxes, equations_inline_bboxes,
                                                    equations_interline_bboxes)
        # Extract text paragraphs.
        footer_bboxes = []
        header_bboxes = []
        exclude_bboxes = image_bboxes + table_bboxes
        paras_dict = process_blocks_per_page(page, page_id, image_bboxes, table_bboxes, equations_inline_bboxes, equations_interline_bboxes, footer_bboxes, header_bboxes)
        # paras_dict = postprocess_paras_pipeline(paras_dict)
        # Final step: merge everything into one bbox list for the
        # left-to-right / top-to-bottom ordering step.
        all_bboxes = reformat_bboxes(images_box_path_dict, paras_dict)  # equations are mostly embedded in paragraphs and not returned separately yet
        # Each element: [x0, y0, x1, y1, block_content, idx_x, idx_y, type];
        # idx_x/idx_y start as None.  For images/equations block_content is
        # the image path; for paragraphs it is the paragraph content.
        # sorted_bboxes = bbox_sort(all_bboxes)
        # markdown_text = concat2markdown(sorted_bboxes)
        # parent_dir = os.path.dirname(text_content_save_path)
        # if not os.path.exists(parent_dir):
        #     os.makedirs(parent_dir)
        # with open(text_content_save_path, "a") as f:
        #     f.write(markdown_text)
        #     f.write(chr(12))  # form feed / page break
        # end for
        # Write a small metadata json.
        # metadata = {"book_name": book_name, "pdf_path": s3_pdf_path, "pdf_model_path": pdf_model_path, "save_path": save_path}
        # with open(metadata_save_path, "w") as f:
        #     json.dump(metadata, f, ensure_ascii=False, indent=4)
        return all_bboxes
    except Exception as e:
        print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr)
        logger.exception(e)
# @click.command()
# @click.option('--pdf-file-sub-path', help='s3 path of the pdf file')
# @click.option('--save-path', help='parent directory for extracted images and text')
  189. def validation(validation_dataset: str, pdf_bin_file_profile: str, pdf_model_dir: str, pdf_model_profile: str, save_path: str):
  190. #pdf_bin_file_path = "s3://llm-raw-snew/llm-raw-scihub/scimag07865000-07865999/10.1007/"
  191. # pdf_bin_file_parent_path = "s3://llm-raw-snew/llm-raw-scihub/"
  192. # pdf_model_parent_dir = "s3://llm-pdf-text/layout_det/scihub/"
  193. # p = Path(pdf_file_sub_path)
  194. # pdf_parent_path = p.parent
  195. # pdf_file_name = p.name # pdf文件名字,含后缀
  196. # pdf_bin_file_path = join_path(pdf_bin_file_parent_path, pdf_parent_path)
  197. with open(validation_dataset, 'r') as f:
  198. samples = json.load(f)
  199. labels = []
  200. det_res = []
  201. edit_distance_list = []
  202. for sample in tqdm(samples):
  203. pdf_name = sample['pdf_name']
  204. s3_pdf_path = sample['s3_path']
  205. page_num = sample['page']
  206. gt_order = sample['order']
  207. pre = main(s3_pdf_path, pdf_bin_file_profile, join_path(pdf_model_dir, pdf_name), pdf_model_profile, save_path, page_num)
  208. pre_dict_list = []
  209. for item in pre:
  210. pre_sample = {
  211. 'box': [item[0],item[1],item[2],item[3]],
  212. 'type': item[7],
  213. 'score': 1
  214. }
  215. pre_dict_list.append(pre_sample)
  216. det_res.append(pre_dict_list)
  217. match_change_dict = { # 待确认
  218. "figure": "image",
  219. "svg_figure": "image",
  220. "inline_fomula": "equations_inline",
  221. "fomula": "equation_interline",
  222. "figure_caption": "text",
  223. "table_caption": "text",
  224. "fomula_caption": "text"
  225. }
  226. gt_annos = sample['annotations']
  227. matched_label = label_match(gt_annos, match_change_dict)
  228. labels.append(matched_label)
  229. # 判断排序函数的精度
  230. # 目前不考虑caption与图表相同序号的问题
  231. ignore_category = ['abandon', 'figure_caption', 'table_caption', 'formula_caption']
  232. gt_bboxes = format_gt_bbox(gt_annos, ignore_category)
  233. sorted_bboxes = bbox_sort(gt_bboxes)
  234. edit_distance = cal_edit_distance(sorted_bboxes)
  235. edit_distance_list.append(edit_distance)
  236. label_classes = ["image", "text", "table", "equation_interline"]
  237. detect_matrix = detect_val(labels, det_res, label_classes)
  238. print('detect_matrix', detect_matrix)
  239. edit_distance_mean = np.mean(edit_distance_list)
  240. print('edit_distance_mean', edit_distance_mean)
  241. if __name__ == '__main__':
  242. # 输入可以用以下命令生成批量pdf
  243. # aws s3 ls s3://llm-pdf-text/layout_det/scihub/ --profile langchao | tail -n 10 | awk '{print "s3://llm-pdf-text/layout_det/scihub/"$4}' | xargs -I{} aws s3 ls {} --recursive --profile langchao | awk '{print substr($4,19)}' | parallel -j 1 echo {//} | sort -u
  244. pdf_bin_file_profile = "outsider"
  245. pdf_model_dir = "s3://llm-pdf-text/eval_1k/layout_res/"
  246. pdf_model_profile = "langchao"
  247. # validation_dataset = "/mnt/petrelfs/share_data/ouyanglinke/OCR/OCR_validation_dataset.json"
  248. validation_dataset = "/mnt/petrelfs/share_data/ouyanglinke/OCR/OCR_validation_dataset_subset.json" # 测试
  249. save_path = "/mnt/petrelfs/share_data/ouyanglinke/OCR/OCR_val_result"
  250. validation(validation_dataset, pdf_bin_file_profile, pdf_model_dir, pdf_model_profile, save_path)