utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. import json
  2. import os
  3. from pathlib import Path
  4. from loguru import logger
  5. import magic_pdf.model as model_config
  6. from magic_pdf.config.ocr_content_type import BlockType, ContentType
  7. from magic_pdf.data.data_reader_writer import FileBasedDataReader
  8. from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
  9. from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
  10. ElementRelation, ElementRelType,
  11. LayoutElements,
  12. LayoutElementsExtra, PageInfo)
  13. from magic_pdf.tools.common import do_parse, prepare_env
  14. def convert_middle_json_to_layout_elements(
  15. json_data: dict,
  16. output_dir: str,
  17. ) -> list[LayoutElements]:
  18. uniq_anno_id = 0
  19. res: list[LayoutElements] = []
  20. for page_no, page_data in enumerate(json_data['pdf_info']):
  21. order_id = 0
  22. page_info = PageInfo(
  23. height=int(page_data['page_size'][1]),
  24. width=int(page_data['page_size'][0]),
  25. page_no=page_no,
  26. )
  27. layout_dets: list[ContentObject] = []
  28. extra_element_relation: list[ElementRelation] = []
  29. for para_block in page_data['para_blocks']:
  30. para_text = ''
  31. para_type = para_block['type']
  32. if para_type == BlockType.Text:
  33. para_text = merge_para_with_text(para_block)
  34. x0, y0, x1, y1 = para_block['bbox']
  35. content = ContentObject(
  36. anno_id=uniq_anno_id,
  37. category_type=CategoryType.text,
  38. text=para_text,
  39. order=order_id,
  40. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  41. )
  42. uniq_anno_id += 1
  43. order_id += 1
  44. layout_dets.append(content)
  45. elif para_type == BlockType.Title:
  46. para_text = merge_para_with_text(para_block)
  47. x0, y0, x1, y1 = para_block['bbox']
  48. content = ContentObject(
  49. anno_id=uniq_anno_id,
  50. category_type=CategoryType.title,
  51. text=para_text,
  52. order=order_id,
  53. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  54. )
  55. uniq_anno_id += 1
  56. order_id += 1
  57. layout_dets.append(content)
  58. elif para_type == BlockType.InterlineEquation:
  59. para_text = merge_para_with_text(para_block)
  60. x0, y0, x1, y1 = para_block['bbox']
  61. content = ContentObject(
  62. anno_id=uniq_anno_id,
  63. category_type=CategoryType.interline_equation,
  64. text=para_text,
  65. order=order_id,
  66. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  67. )
  68. uniq_anno_id += 1
  69. order_id += 1
  70. layout_dets.append(content)
  71. elif para_type == BlockType.Image:
  72. body_anno_id = -1
  73. caption_anno_id = -1
  74. for block in para_block['blocks']:
  75. if block['type'] == BlockType.ImageBody:
  76. for line in block['lines']:
  77. for span in line['spans']:
  78. if span['type'] == ContentType.Image:
  79. x0, y0, x1, y1 = block['bbox']
  80. content = ContentObject(
  81. anno_id=uniq_anno_id,
  82. category_type=CategoryType.image_body,
  83. image_path=os.path.join(
  84. output_dir, span['image_path']),
  85. order=order_id,
  86. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  87. )
  88. body_anno_id = uniq_anno_id
  89. uniq_anno_id += 1
  90. order_id += 1
  91. layout_dets.append(content)
  92. for block in para_block['blocks']:
  93. if block['type'] == BlockType.ImageCaption:
  94. para_text += merge_para_with_text(block)
  95. x0, y0, x1, y1 = block['bbox']
  96. content = ContentObject(
  97. anno_id=uniq_anno_id,
  98. category_type=CategoryType.image_caption,
  99. text=para_text,
  100. order=order_id,
  101. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  102. )
  103. caption_anno_id = uniq_anno_id
  104. uniq_anno_id += 1
  105. order_id += 1
  106. layout_dets.append(content)
  107. if body_anno_id > 0 and caption_anno_id > 0:
  108. element_relation = ElementRelation(
  109. relation=ElementRelType.sibling,
  110. source_anno_id=body_anno_id,
  111. target_anno_id=caption_anno_id,
  112. )
  113. extra_element_relation.append(element_relation)
  114. elif para_type == BlockType.Table:
  115. body_anno_id, caption_anno_id, footnote_anno_id = -1, -1, -1
  116. for block in para_block['blocks']:
  117. if block['type'] == BlockType.TableCaption:
  118. para_text += merge_para_with_text(block)
  119. x0, y0, x1, y1 = block['bbox']
  120. content = ContentObject(
  121. anno_id=uniq_anno_id,
  122. category_type=CategoryType.table_caption,
  123. text=para_text,
  124. order=order_id,
  125. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  126. )
  127. caption_anno_id = uniq_anno_id
  128. uniq_anno_id += 1
  129. order_id += 1
  130. layout_dets.append(content)
  131. for block in para_block['blocks']:
  132. if block['type'] == BlockType.TableBody:
  133. for line in block['lines']:
  134. for span in line['spans']:
  135. if span['type'] == ContentType.Table:
  136. x0, y0, x1, y1 = para_block['bbox']
  137. content = ContentObject(
  138. anno_id=uniq_anno_id,
  139. category_type=CategoryType.table_body,
  140. order=order_id,
  141. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  142. )
  143. body_anno_id = uniq_anno_id
  144. uniq_anno_id += 1
  145. order_id += 1
  146. # if processed by table model
  147. if span.get('latex', ''):
  148. content.latex = span['latex']
  149. else:
  150. content.image_path = os.path.join(
  151. output_dir, span['image_path'])
  152. layout_dets.append(content)
  153. for block in para_block['blocks']:
  154. if block['type'] == BlockType.TableFootnote:
  155. para_text += merge_para_with_text(block)
  156. x0, y0, x1, y1 = block['bbox']
  157. content = ContentObject(
  158. anno_id=uniq_anno_id,
  159. category_type=CategoryType.table_footnote,
  160. text=para_text,
  161. order=order_id,
  162. poly=[x0, y0, x1, y0, x1, y1, x0, y1],
  163. )
  164. footnote_anno_id = uniq_anno_id
  165. uniq_anno_id += 1
  166. order_id += 1
  167. layout_dets.append(content)
  168. if caption_anno_id != -1 and body_anno_id != -1:
  169. element_relation = ElementRelation(
  170. relation=ElementRelType.sibling,
  171. source_anno_id=body_anno_id,
  172. target_anno_id=caption_anno_id,
  173. )
  174. extra_element_relation.append(element_relation)
  175. if footnote_anno_id != -1 and body_anno_id != -1:
  176. element_relation = ElementRelation(
  177. relation=ElementRelType.sibling,
  178. source_anno_id=body_anno_id,
  179. target_anno_id=footnote_anno_id,
  180. )
  181. extra_element_relation.append(element_relation)
  182. res.append(
  183. LayoutElements(
  184. page_info=page_info,
  185. layout_dets=layout_dets,
  186. extra=LayoutElementsExtra(
  187. element_relation=extra_element_relation),
  188. ))
  189. return res
  190. def inference(path, output_dir, method):
  191. model_config.__use_inside_model__ = True
  192. model_config.__model_mode__ = 'full'
  193. if output_dir == '':
  194. if os.path.isdir(path):
  195. output_dir = os.path.join(path, 'output')
  196. else:
  197. output_dir = os.path.join(os.path.dirname(path), 'output')
  198. local_image_dir, local_md_dir = prepare_env(output_dir,
  199. str(Path(path).stem), method)
  200. def read_fn(path):
  201. disk_rw = FileBasedDataReader(os.path.dirname(path))
  202. return disk_rw.read(os.path.basename(path))
  203. def parse_doc(doc_path: str):
  204. try:
  205. file_name = str(Path(doc_path).stem)
  206. pdf_data = read_fn(doc_path)
  207. do_parse(
  208. output_dir,
  209. file_name,
  210. pdf_data,
  211. [],
  212. method,
  213. False,
  214. f_draw_span_bbox=False,
  215. f_draw_layout_bbox=False,
  216. f_dump_md=False,
  217. f_dump_middle_json=True,
  218. f_dump_model_json=False,
  219. f_dump_orig_pdf=False,
  220. f_dump_content_list=False,
  221. f_draw_model_bbox=False,
  222. )
  223. middle_json_fn = os.path.join(local_md_dir,
  224. f'{file_name}_middle.json')
  225. with open(middle_json_fn) as fd:
  226. jso = json.load(fd)
  227. os.remove(middle_json_fn)
  228. return convert_middle_json_to_layout_elements(jso, local_image_dir)
  229. except Exception as e:
  230. logger.exception(e)
  231. return parse_doc(path)
  232. if __name__ == '__main__':
  233. import pprint
  234. base_dir = '/opt/data/pdf/resources/samples/'
  235. if 0:
  236. with open(base_dir + 'json_outputs/middle.json') as f:
  237. d = json.load(f)
  238. result = convert_middle_json_to_layout_elements(d, '/tmp')
  239. pprint.pp(result)
  240. if 0:
  241. with open(base_dir + 'json_outputs/middle.3.json') as f:
  242. d = json.load(f)
  243. result = convert_middle_json_to_layout_elements(d, '/tmp')
  244. pprint.pp(result)
  245. if 1:
  246. res = inference(
  247. base_dir + 'samples/pdf/one_page_with_table_image.pdf',
  248. '/tmp/output',
  249. 'ocr',
  250. )
  251. pprint.pp(res)