magic_model.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import json
  2. from magic_pdf.libs.commons import fitz
  3. from loguru import logger
  4. from magic_pdf.libs.commons import join_path
  5. from magic_pdf.libs.coordinate_transform import get_scale_ratio
  6. from magic_pdf.libs.ocr_content_type import ContentType
  7. from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
  8. from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
  class MagicModel:
      """Page-level accessor over layout-model output for an opened PDF.

      Every getter returns an empty list when the page has no matching
      elements.
      """

      def __fix_axis(self):
          """Attach a page-space ``bbox`` to every detection and drop empty ones.

          For each page entry, the ratios returned by ``get_scale_ratio`` are
          used to map the detection's ``poly`` into an integer
          ``[x0, y0, x1, y1]`` box stored under ``layout_det["bbox"]``.
          Detections whose box has zero width or zero height are removed from
          that page's own ``layout_dets`` list, which is filtered in place.
          """
          for model_page_info in self.__model_list:
              # NOTE(review): the PDF page is looked up via page_info.page_no,
              # while the getters index model_list positionally; presumably the
              # two agree — confirm against the model producer.
              page_no = model_page_info['page_info']['page_no']
              horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
                  model_page_info, self.__docs[page_no]
              )
              layout_dets = model_page_info["layout_dets"]
              # Survivors are collected per page so that each page's degenerate
              # detections are removed from that page's list and no other.
              kept_dets = []
              for layout_det in layout_dets:
                  # Of the 8 numbers in poly, positions 0-1 and 4-5 are taken as
                  # the (x0, y0) and (x1, y1) corners of the box.
                  x0, y0, _, _, x1, y1, _, _ = layout_det["poly"]
                  bbox = [
                      int(x0 / horizontal_scale_ratio),
                      int(y0 / vertical_scale_ratio),
                      int(x1 / horizontal_scale_ratio),
                      int(y1 / vertical_scale_ratio),
                  ]
                  layout_det["bbox"] = bbox
                  # Drop boxes with zero width or zero height.
                  if bbox[2] - bbox[0] == 0 or bbox[3] - bbox[1] == 0:
                      continue
                  kept_dets.append(layout_det)
              # Slice assignment keeps the same list object, so the page entry
              # (and anything else holding this list) sees the filtered result.
              layout_dets[:] = kept_dets

      def __init__(self, model_list: list, docs: fitz.Document):
          """Store the model output and PDF, then compute per-detection bboxes.

          Args:
              model_list: per-page model output; each entry is read for
                  ``page_info["page_no"]`` and ``layout_dets``. It is mutated
                  in place (``bbox`` keys are added, empty boxes are removed).
              docs: the opened PDF document, indexed by page number.
          """
          self.__model_list = model_list
          self.__docs = docs
          self.__fix_axis()

      @staticmethod
      def __category_id(layout_det):
          """Return the detection's ``category_id`` as an int when possible.

          Normalizing lets integer comparisons work whether the model JSON
          stores ids as numbers or as numeric strings; non-numeric ids are
          returned unchanged (and so simply match no known category).
          """
          category_id = layout_det["category_id"]
          try:
              return int(category_id)
          except (TypeError, ValueError):
              return category_id

      def __layout_dets(self, page_no):
          """Return the (axis-fixed) layout detections of page ``page_no``."""
          return self.__model_list[page_no]["layout_dets"]

      def get_imgs(self, page_no: int) -> list:  # @许瑞
          """Return the image blocks (category 3) on ``page_no``.

          Each block is a dict with:
              ``bbox``: the whole block's box,
              ``img_body_bbox``: the image body's box,
              ``img_caption_bbox``: the caption's box, or ``None`` when there
              is no caption (the key is always present).

          TODO(许瑞): captions are not paired with images yet, so
          ``img_caption_bbox`` is always ``None`` and ``bbox`` equals the body.
          """
          image_blocks = []
          for layout_det in self.__layout_dets(page_no):
              if self.__category_id(layout_det) != 3:  # 3: image
                  continue
              body_bbox = layout_det["bbox"]
              image_blocks.append({
                  # Copy so a future caption merge cannot alias the body bbox.
                  "bbox": list(body_bbox),
                  "img_body_bbox": body_bbox,
                  "img_caption_bbox": None,
              })
          return image_blocks

      def get_tables(self, page_no: int) -> list:
          """Return the table blocks on ``page_no``.

          Intended to carry three boxes (caption, table body, table note),
          structured like the blocks from ``get_imgs``.

          TODO(许瑞): not implemented yet; always returns an empty list.
          """
          return []

      def get_equations(self, page_no: int) -> tuple:  # @凯文
          """Return ``(inline_equations, interline_equations)`` for ``page_no``.

          Inline equations are category 13 and interline equations are
          category 14. Each entry is ``{"bbox": ..., "content": <latex>}``,
          i.e. both coordinates and text.
          """
          inline_equations = []
          interline_equations = []
          for layout_det in self.__layout_dets(page_no):
              category_id = self.__category_id(layout_det)
              if category_id == 13:
                  target = inline_equations
              elif category_id == 14:
                  target = interline_equations
              else:
                  continue
              target.append({
                  "bbox": layout_det["bbox"],
                  "content": layout_det["latex"],
              })
          return inline_equations, interline_equations

      def get_discarded(self, page_no: int) -> list:
          """Return discarded regions (in-house model, coordinates only).

          TODO(凯文): not implemented yet; always returns an empty list.
          """
          return []

      def get_text_blocks(self, page_no: int) -> list:
          """Return text blocks (in-house model, coordinates only, no text).

          TODO(凯文): not implemented yet; always returns an empty list.
          """
          return []

      def get_title_blocks(self, page_no: int) -> list:
          """Return title blocks (in-house model, coordinates only, no text).

          TODO(凯文): not implemented yet; always returns an empty list.
          """
          return []

      def get_ocr_text(self, page_no: int) -> list:
          """Return the OCR text spans (category 15) of ``page_no``.

          Produced by Paddle OCR, so each span carries both coordinates and
          text: ``{"bbox": ..., "content": <recognized text>}``.
          """
          return [
              {
                  "bbox": layout_det['bbox'],
                  "content": layout_det["text"],
              }
              for layout_det in self.__layout_dets(page_no)
              if self.__category_id(layout_det) == 15
          ]

      def get_all_spans(self, page_no: int) -> list:
          """Return every detection on ``page_no`` that is stitched as a span.

          Categories treated as spans:
              3: image, 5: table, 13: inline equation,
              14: interline equation, 15: OCR text.
          Each span has ``bbox`` and ``type``; equations also carry their
          LaTeX and OCR text its text under ``content``. Page order is kept.
          """
          # category id -> (span type, model key holding the content or None)
          span_specs = {
              3: (ContentType.Image, None),
              5: (ContentType.Table, None),
              13: (ContentType.InlineEquation, "latex"),
              14: (ContentType.InterlineEquation, "latex"),
              15: (ContentType.Text, "text"),
          }
          all_spans = []
          for layout_det in self.__layout_dets(page_no):
              spec = span_specs.get(self.__category_id(layout_det))
              if spec is None:
                  continue
              span_type, content_key = spec
              span = {"bbox": layout_det['bbox']}
              if content_key is not None:
                  span["content"] = layout_det[content_key]
              span["type"] = span_type
              all_spans.append(span)
          return all_spans

      def get_page_size(self, page_no: int):
          """Return ``(width, height)`` of PDF page ``page_no`` from its ``rect``."""
          page_rect = self.__docs[page_no].rect
          return page_rect.width, page_rect.height
  if __name__ == '__main__':
      # Manual smoke test against a developer's local checkout (Windows paths).
      # Loads one PDF plus its model JSON and builds a MagicModel from them;
      # the image writer is created but not written to here.
      project_root = r"D:/project/20231108code-clean"
      output_dir = r"D:\project\20231108code-clean\linshixuqiu\19983-00"

      reader = DiskReaderWriter(project_root)
      pdf_bytes = reader.read(r"linshixuqiu\19983-00.pdf", AbsReaderWriter.MODE_BIN)
      model_list = json.loads(
          reader.read(r"linshixuqiu\19983-00_new.json", AbsReaderWriter.MODE_TXT)
      )
      img_writer = DiskReaderWriter(join_path(output_dir, "imgs"))

      pdf_docs = fitz.open("pdf", pdf_bytes)
      magic_model = MagicModel(model_list, pdf_docs)