user_api.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
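"""User input: the model array (each element represents one page), the S3 path of the PDF, and the S3 location where screenshots are saved.
Then:
1) Based on the S3 path, call the Spark cluster API to obtain ak, sk and endpoint, and construct the s3PDFReader.
2) Based on the S3 address supplied by the user, call the Spark cluster API to obtain ak, sk and endpoint, and construct the s3ImageWriter.
Everything else (constructing the s3cli, obtaining ak and sk) is implemented in code-clean. Do not introduce a reverse dependency!!!
"""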
  7. from loguru import logger
  8. from magic_pdf.data.data_reader_writer import DataWriter
  9. from magic_pdf.data.dataset import Dataset
  10. from magic_pdf.libs.version import __version__
  11. from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
  12. from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
  13. from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
# Values written to the '_parse_type' key of a parse result by the parse_*_pdf
# functions in this module, recording which parser produced the result.
PARSE_TYPE_TXT = 'txt'  # result produced by parse_pdf_by_txt
PARSE_TYPE_OCR = 'ocr'  # result produced by parse_pdf_by_ocr
  16. def parse_txt_pdf(
  17. dataset: Dataset,
  18. model_list: list,
  19. imageWriter: DataWriter,
  20. is_debug=False,
  21. start_page_id=0,
  22. end_page_id=None,
  23. lang=None,
  24. *args,
  25. **kwargs
  26. ):
  27. """解析文本类pdf."""
  28. pdf_info_dict = parse_pdf_by_txt(
  29. dataset,
  30. model_list,
  31. imageWriter,
  32. start_page_id=start_page_id,
  33. end_page_id=end_page_id,
  34. debug_mode=is_debug,
  35. lang=lang,
  36. )
  37. pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
  38. pdf_info_dict['_version_name'] = __version__
  39. if lang is not None:
  40. pdf_info_dict['_lang'] = lang
  41. return pdf_info_dict
  42. def parse_ocr_pdf(
  43. dataset: Dataset,
  44. model_list: list,
  45. imageWriter: DataWriter,
  46. is_debug=False,
  47. start_page_id=0,
  48. end_page_id=None,
  49. lang=None,
  50. *args,
  51. **kwargs
  52. ):
  53. """解析ocr类pdf."""
  54. pdf_info_dict = parse_pdf_by_ocr(
  55. dataset,
  56. model_list,
  57. imageWriter,
  58. start_page_id=start_page_id,
  59. end_page_id=end_page_id,
  60. debug_mode=is_debug,
  61. lang=lang,
  62. )
  63. pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
  64. pdf_info_dict['_version_name'] = __version__
  65. if lang is not None:
  66. pdf_info_dict['_lang'] = lang
  67. return pdf_info_dict
  68. def parse_union_pdf(
  69. dataset: Dataset,
  70. model_list: list,
  71. imageWriter: DataWriter,
  72. is_debug=False,
  73. start_page_id=0,
  74. end_page_id=None,
  75. lang=None,
  76. *args,
  77. **kwargs
  78. ):
  79. """ocr和文本混合的pdf,全部解析出来."""
  80. def parse_pdf(method):
  81. try:
  82. return method(
  83. dataset,
  84. model_list,
  85. imageWriter,
  86. start_page_id=start_page_id,
  87. end_page_id=end_page_id,
  88. debug_mode=is_debug,
  89. lang=lang,
  90. )
  91. except Exception as e:
  92. logger.exception(e)
  93. return None
  94. pdf_info_dict = parse_pdf(parse_pdf_by_txt)
  95. if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
  96. logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
  97. if len(model_list) == 0:
  98. layout_model = kwargs.get('layout_model', None)
  99. formula_enable = kwargs.get('formula_enable', None)
  100. table_enable = kwargs.get('table_enable', None)
  101. infer_res = doc_analyze(
  102. dataset,
  103. ocr=True,
  104. start_page_id=start_page_id,
  105. end_page_id=end_page_id,
  106. lang=lang,
  107. layout_model=layout_model,
  108. formula_enable=formula_enable,
  109. table_enable=table_enable,
  110. )
  111. model_list = infer_res.get_infer_res()
  112. pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
  113. if pdf_info_dict is None:
  114. raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
  115. else:
  116. pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
  117. else:
  118. pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
  119. pdf_info_dict['_version_name'] = __version__
  120. if lang is not None:
  121. pdf_info_dict['_lang'] = lang
  122. return pdf_info_dict