pdf2json_infer.py

import sys
from typing import Tuple
import os
import io  # needed to wrap encoded bytes for upload_fileobj
import json
import boto3
from botocore.config import Config
from libs.commons import fitz
from loguru import logger
from pathlib import Path
from tqdm import tqdm
import numpy as np
# sys.path.insert(0, "/mnt/petrelfs/ouyanglinke/code-clean/")
# print(sys.path)
from validation import cal_edit_distance, format_gt_bbox, label_match, detect_val

# from pdf2text_recogFigure_20231107 import parse_images        # get the bboxes of figures
# from pdf2text_recogTable_20231107 import parse_tables         # get the bboxes of tables
# from pdf2text_recogEquation_20231108 import parse_equations   # get the bboxes of equations
# from pdf2text_recogTitle_20231113 import parse_titles         # get the bboxes of titles
# from pdf2text_recogPara import parse_blocks_per_page
# from bbox_sort import bbox_sort, CONTENT_IDX, CONTENT_TYPE_IDX
from layout.bbox_sort import bbox_sort, CONTENT_IDX, CONTENT_TYPE_IDX
from pre_proc.detect_images import parse_images        # get the bboxes of figures
from pdf2text_recogTable import parse_tables           # get the bboxes of tables
from pre_proc.detect_equation import parse_equations   # get the bboxes of equations
# from pdf2text_recogFootnote import parse_footnotes   # get the bboxes of footnotes
from pdf2text_recogPara import process_blocks_per_page
from libs.commons import parse_aws_param, parse_bucket_key, read_file, join_path

def cut_image(bbox: Tuple, page_num: int, page: fitz.Page, save_parent_path: str, s3_profile: str):
    """
    Crop a jpg image out of `page` (page number `page_num`) according to `bbox`, and return the image path.
    save_parent_path must support both s3 and local paths; the image is stored under it with the file name
    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg, with the bbox values truncated to integers.
    (See the illustrative sketch after this function.)
    """
    # Build the target path
    image_save_path = join_path(save_parent_path, f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}.jpg")
    try:
        # Convert the coordinates into a fitz.Rect object
        rect = fitz.Rect(*bbox)
        # Use a 3x zoom factor
        zoom = fitz.Matrix(3, 3)
        # Render the clipped region
        pix = page.get_pixmap(clip=rect, matrix=zoom)
        # print(f"Saved {image_save_path}")
        if image_save_path.startswith("s3://"):
            ak, sk, end_point, addressing_style = parse_aws_param(s3_profile)
            cli = boto3.client(service_name="s3", aws_access_key_id=ak, aws_secret_access_key=sk, endpoint_url=end_point,
                               config=Config(s3={'addressing_style': addressing_style}))
            bucket_name, bucket_key = parse_bucket_key(image_save_path)
            # Upload to s3; upload_fileobj expects a file-like object, so wrap the encoded bytes in BytesIO
            cli.upload_fileobj(io.BytesIO(pix.tobytes(output='jpeg', jpg_quality=95)), bucket_name, bucket_key)
        else:
            # Save the image locally.
            # First make sure the parent directory of image_save_path exists; create it if it does not.
            parent_dir = os.path.dirname(image_save_path)
            if not os.path.exists(parent_dir):
                os.makedirs(parent_dir)
            pix.save(image_save_path, jpg_quality=95)
            # So the image can be viewed directly in markdown, rewrite the path relative to the markdown file
            pth = Path(image_save_path)
            image_save_path = f"{pth.parent.name}/{pth.name}"
        return image_save_path
    except Exception as e:
        logger.exception(e)
        return image_save_path

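# Illustrative usage sketch (not part of the pipeline); the bbox, page object and target path below
# are hypothetical placeholders. For page_num=3 and bbox=(72.5, 100.2, 310.9, 250.4) the crop would
# be written as {save_parent_path}/3_72_100_310_250.jpg:
#   path = cut_image((72.5, 100.2, 310.9, 250.4), 3, page, "s3://some-bucket/book/images", "langchao")
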
def get_images_by_bboxes(book_name: str, page_num: int, page: fitz.Page, save_path: str, s3_profile: str,
                         image_bboxes: list, table_bboxes: list, equation_inline_bboxes: list,
                         equation_interline_bboxes: list) -> dict:
    """
    Return a dict keyed by bbox; each value is a (image_path, content_type) tuple.
    (See the illustrative sketch after this function.)
    """
    ret = {}
    # Each image is saved as: {s3_or_local_path}/{book_name}/{images|tables|equations}/{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg
    image_save_path = join_path(save_path, book_name, "images")
    table_save_path = join_path(save_path, book_name, "tables")
    equation_inline_save_path = join_path(save_path, book_name, "equations_inline")
    equation_interline_save_path = join_path(save_path, book_name, "equation_interline")
    for bbox in image_bboxes:
        image_path = cut_image(bbox, page_num, page, image_save_path, s3_profile)
        ret[bbox] = (image_path, "image")  # the second element marks the content type
    for bbox in table_bboxes:
        image_path = cut_image(bbox, page_num, page, table_save_path, s3_profile)
        ret[bbox] = (image_path, "table")
    # Equations are only cropped for now; they are not added to the returned dict
    for bbox in equation_inline_bboxes:
        cut_image(bbox, page_num, page, equation_inline_save_path, s3_profile)
    for bbox in equation_interline_bboxes:
        cut_image(bbox, page_num, page, equation_interline_save_path, s3_profile)
    return ret

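# Illustrative sketch of the mapping returned above, assuming one figure and one table on page 0
# (keys are the bbox tuples, values are hypothetical relative paths produced by cut_image):
#   {
#       (50.0, 60.0, 300.0, 200.0): ("images/0_50_60_300_200.jpg", "image"),
#       (50.0, 220.0, 300.0, 400.0): ("tables/0_50_220_300_400.jpg", "table"),
#   }
# The bboxes are used as dict keys, so they need to be hashable (e.g. tuples).
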
def reformat_bboxes(images_box_path_dict: dict, paras_dict: dict):
    """
    Reassemble the bboxes into a list whose elements are [x0, y0, x1, y1, block_content, idx_x, idx_y, type];
    idx_x and idx_y start as None. For images and equations block_content is the image path;
    for paragraphs it is the paragraph content. (See the illustrative sketch after this function.)
    """
    all_bboxes = []
    for bbox, image_info in images_box_path_dict.items():
        all_bboxes.append([bbox[0], bbox[1], bbox[2], bbox[3], image_info, None, None, 'image'])
    paras_dict = paras_dict[f"page_{paras_dict['page_id']}"]
    for block_id, kvpair in paras_dict.items():
        bbox = kvpair['bbox']
        content = kvpair
        all_bboxes.append([bbox[0], bbox[1], bbox[2], bbox[3], content, None, None, 'text'])
    return all_bboxes

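# Illustrative sketch of one element of the list built above (values are hypothetical):
#   [50.0, 60.0, 300.0, 200.0, ("images/0_50_60_300_200.jpg", "image"), None, None, 'image']
# idx_x and idx_y stay None here; they are presumably filled in later by bbox_sort.
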
def concat2markdown(all_bboxes: list):
    """
    Concatenate the contents of the sorted bboxes into markdown.
    (See the illustrative sketch after this function.)
    """
    content_md = ""
    for box in all_bboxes:
        content_type = box[CONTENT_TYPE_IDX]
        if content_type == 'image':
            image_type = box[CONTENT_IDX][1]
            image_path = box[CONTENT_IDX][0]
            content_md += f"![{image_type}]({image_path})"
            content_md += "\n\n"
        elif content_type == 'text':  # assemble the text
            paras = box[CONTENT_IDX]['paras']
            text_content = ""
            for para_id, para in paras.items():  # join the paragraphs inside the block
                text_content += para['text']
                text_content += "\n\n"
            content_md += text_content
        else:
            raise Exception(f"ERROR: {content_type} is not supported!")
    return content_md

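# Illustrative sketch of the markdown produced for a sorted list with one figure and one text block
# (path and text are hypothetical):
#   ![image](images/0_50_60_300_200.jpg)
#
#   First paragraph of the text block ...
#
#   Second paragraph of the text block ...
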
def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_profile: str, save_path: str, page_num: int):
    """
    Parse one page of a single pdf using the layout-model output and return the page's bbox list.
    (See the example call after this function.)
    """
    pth = Path(s3_pdf_path)
    book_name = pth.name
    # book_name = "".join(os.path.basename(s3_pdf_path).split(".")[0:-1])
    res_dir_path = None
    exclude_bboxes = []
    # text_content_save_path = f"{save_path}/{book_name}/book.md"
    # metadata_save_path = f"{save_path}/{book_name}/metadata.json"
    try:
        pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
        pdf_docs = fitz.open("pdf", pdf_bytes)
        page_id = page_num - 1
        page = pdf_docs[page_id]  # the validation set only needs this specific page
        model_output_json = join_path(pdf_model_path, f"page_{page_num}.json")  # model output pages are numbered from 1
        json_from_docx = read_file(model_output_json, pdf_model_profile)  # TODO: rename this to avoid the misleading name
        json_from_docx_obj = json.loads(json_from_docx)
        # Parse images
        image_bboxes = parse_images(page_id, page, json_from_docx_obj)
        # Parse tables
        table_bboxes = parse_tables(page_id, page, json_from_docx_obj)
        # Parse equations
        equations_interline_bboxes, equations_inline_bboxes = parse_equations(page_id, page, json_from_docx_obj)
        # # Parse titles
        # title_bboxs = parse_titles(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Parse page headers
        # header_bboxs = parse_headers(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Parse page numbers
        # pageNo_bboxs = parse_pageNos(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Parse footnotes
        # footnote_bboxs = parse_footnotes(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Parse page footers
        # footer_bboxs = parse_footers(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # # Evaluate whether the layout is regular and simple
        # isSimpleLayout_flag, fullColumn_cnt, subColumn_cnt, curPage_loss = evaluate_pdf_layout(page_id, page, res_dir_path, json_from_docx_obj, exclude_bboxes)
        # Crop figures, tables and equations, save the crops, and use the image paths as content
        images_box_path_dict = get_images_by_bboxes(book_name, page_id, page, save_path, s3_pdf_profile,
                                                    image_bboxes, table_bboxes, equations_inline_bboxes,
                                                    equations_interline_bboxes)  # only figure and table crops are returned
        # Parse text paragraphs
        footer_bboxes = []
        header_bboxes = []
        exclude_bboxes = image_bboxes + table_bboxes
        paras_dict = process_blocks_per_page(page, page_id, image_bboxes, table_bboxes, equations_inline_bboxes,
                                             equations_interline_bboxes, footer_bboxes, header_bboxes)
        # paras_dict = postprocess_paras_pipeline(paras_dict)
        # Final step: sort the bboxes left-to-right, top-to-bottom, then concatenate the contents
        all_bboxes = reformat_bboxes(images_box_path_dict, paras_dict)  # equations are not handled yet (equation_bboxes is None); most of them stay inside paragraphs and are not parsed for now
        # The result is a list whose elements are [x0, y0, x1, y1, block_content, idx_x, idx_y, type];
        # idx_x and idx_y start as None. For images and equations block_content is the image path, for paragraphs it is the paragraph content.
        # sorted_bboxes = bbox_sort(all_bboxes)
        # markdown_text = concat2markdown(sorted_bboxes)
        # parent_dir = os.path.dirname(text_content_save_path)
        # if not os.path.exists(parent_dir):
        #     os.makedirs(parent_dir)
        # with open(text_content_save_path, "a") as f:
        #     f.write(markdown_text)
        #     f.write(chr(12))  # form feed (page break)
        # end for
        # Write a small json recording the metadata
        # metadata = {"book_name": book_name, "pdf_path": s3_pdf_path, "pdf_model_path": pdf_model_path, "save_path": save_path}
        # with open(metadata_save_path, "w") as f:
        #     json.dump(metadata, f, ensure_ascii=False, indent=4)
        return all_bboxes
    except Exception as e:
        print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr)
        logger.exception(e)

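# Illustrative, hedged example of calling main() for a single page; the pdf path below is a
# hypothetical placeholder, while the profiles and model dir follow the values used in __main__:
#   bboxes = main("s3://some-bucket/some-book.pdf", "outsider",
#                 "s3://llm-pdf-text/eval_1k/layout_res/some-book.pdf", "langchao",
#                 "/tmp/ocr_val_result", page_num=1)
# Each returned element is [x0, y0, x1, y1, block_content, None, None, type].
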
# @click.command()
# @click.option('--pdf-file-sub-path', help='path of the pdf file on s3')
# @click.option('--save-path', help='parent directory for the extracted images and text')
def validation(validation_dataset: str, pdf_bin_file_profile: str, pdf_model_dir: str, pdf_model_profile: str, save_path: str):
    # pdf_bin_file_path = "s3://llm-raw-snew/llm-raw-scihub/scimag07865000-07865999/10.1007/"
    # pdf_bin_file_parent_path = "s3://llm-raw-snew/llm-raw-scihub/"
    # pdf_model_parent_dir = "s3://llm-pdf-text/layout_det/scihub/"
    # p = Path(pdf_file_sub_path)
    # pdf_parent_path = p.parent
    # pdf_file_name = p.name  # pdf file name, including the extension
    # pdf_bin_file_path = join_path(pdf_bin_file_parent_path, pdf_parent_path)
    with open(validation_dataset, 'r') as f:
        samples = json.load(f)
    labels = []
    det_res = []
    edit_distance_list = []
    for sample in tqdm(samples):
        pdf_name = sample['pdf_name']
        s3_pdf_path = sample['s3_path']
        page_num = sample['page']
        gt_order = sample['order']
        pre = main(s3_pdf_path, pdf_bin_file_profile, join_path(pdf_model_dir, pdf_name), pdf_model_profile, save_path, page_num)
        pre_dict_list = []
        for item in pre:
            pre_sample = {
                'box': [item[0], item[1], item[2], item[3]],
                'type': item[7],
                'score': 1
            }
            pre_dict_list.append(pre_sample)
        det_res.append(pre_dict_list)
        match_change_dict = {  # to be confirmed
            "figure": "image",
            "svg_figure": "image",
            "inline_fomula": "equations_inline",
            "fomula": "equation_interline",
            "figure_caption": "text",
            "table_caption": "text",
            "fomula_caption": "text"
        }
        gt_annos = sample['annotations']
        matched_label = label_match(gt_annos, match_change_dict)
        labels.append(matched_label)
        # Measure the accuracy of the sorting function.
        # Captions that share an index with their figure/table are not handled yet.
        ignore_category = ['abandon', 'figure_caption', 'table_caption', 'formula_caption']
        gt_bboxes = format_gt_bbox(gt_annos, ignore_category)
        sorted_bboxes = bbox_sort(gt_bboxes)
        edit_distance = cal_edit_distance(sorted_bboxes)
        edit_distance_list.append(edit_distance)
    label_classes = ["image", "text", "table", "equation_interline"]
    detect_matrix = detect_val(labels, det_res, label_classes)
    print('detect_matrix', detect_matrix)
    edit_distance_mean = np.mean(edit_distance_list)
    print('edit_distance_mean', edit_distance_mean)

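# Illustrative sketch of one per-page entry appended to det_res above; every predicted box gets a
# fixed score of 1 before detect_val compares it against the matched ground-truth labels
# (coordinates are hypothetical):
#   [{'box': [50.0, 60.0, 300.0, 200.0], 'type': 'image', 'score': 1},
#    {'box': [50.0, 220.0, 300.0, 400.0], 'type': 'text', 'score': 1}]
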
if __name__ == '__main__':
    # A batch of input pdfs can be generated with the following command:
    # aws s3 ls s3://llm-pdf-text/layout_det/scihub/ --profile langchao | tail -n 10 | awk '{print "s3://llm-pdf-text/layout_det/scihub/"$4}' | xargs -I{} aws s3 ls {} --recursive --profile langchao | awk '{print substr($4,19)}' | parallel -j 1 echo {//} | sort -u
    pdf_bin_file_profile = "outsider"
    pdf_model_dir = "s3://llm-pdf-text/eval_1k/layout_res/"
    pdf_model_profile = "langchao"
    # validation_dataset = "/mnt/petrelfs/share_data/ouyanglinke/OCR/OCR_validation_dataset.json"
    validation_dataset = "/mnt/petrelfs/share_data/ouyanglinke/OCR/OCR_validation_dataset_subset.json"  # test subset
    save_path = "/mnt/petrelfs/share_data/ouyanglinke/OCR/OCR_val_result"
    validation(validation_dataset, pdf_bin_file_profile, pdf_model_dir, pdf_model_profile, save_path)