user_api.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
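"""User input: the model array (each element represents one page), the PDF's path on S3, and the S3 location where screenshots are saved.
Then:
1) From the S3 path, call the Spark cluster API to obtain ak, sk and endpoint, and construct the s3PDFReader.
2) From the user-supplied S3 address, call the Spark cluster API to obtain ak, sk and endpoint, and construct the s3ImageWriter.
Everything else (constructing s3cli, obtaining ak/sk) is implemented in code-clean. Do not introduce a reverse dependency!!!
"""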
  7. from loguru import logger
  8. from magic_pdf.data.data_reader_writer import DataWriter
  9. from magic_pdf.data.dataset import Dataset
  10. from magic_pdf.libs.version import __version__
  11. from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
  12. from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
  13. from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
  14. from magic_pdf.config.constants import PARSE_TYPE_TXT, PARSE_TYPE_OCR
  15. def parse_txt_pdf(
  16. dataset: Dataset,
  17. model_list: list,
  18. imageWriter: DataWriter,
  19. is_debug=False,
  20. start_page_id=0,
  21. end_page_id=None,
  22. lang=None,
  23. *args,
  24. **kwargs
  25. ):
  26. """解析文本类pdf."""
  27. pdf_info_dict = parse_pdf_by_txt(
  28. dataset,
  29. model_list,
  30. imageWriter,
  31. start_page_id=start_page_id,
  32. end_page_id=end_page_id,
  33. debug_mode=is_debug,
  34. lang=lang,
  35. )
  36. pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
  37. pdf_info_dict['_version_name'] = __version__
  38. if lang is not None:
  39. pdf_info_dict['_lang'] = lang
  40. return pdf_info_dict
  41. def parse_ocr_pdf(
  42. dataset: Dataset,
  43. model_list: list,
  44. imageWriter: DataWriter,
  45. is_debug=False,
  46. start_page_id=0,
  47. end_page_id=None,
  48. lang=None,
  49. *args,
  50. **kwargs
  51. ):
  52. """解析ocr类pdf."""
  53. pdf_info_dict = parse_pdf_by_ocr(
  54. dataset,
  55. model_list,
  56. imageWriter,
  57. start_page_id=start_page_id,
  58. end_page_id=end_page_id,
  59. debug_mode=is_debug,
  60. lang=lang,
  61. )
  62. pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
  63. pdf_info_dict['_version_name'] = __version__
  64. if lang is not None:
  65. pdf_info_dict['_lang'] = lang
  66. return pdf_info_dict
  67. def parse_union_pdf(
  68. dataset: Dataset,
  69. model_list: list,
  70. imageWriter: DataWriter,
  71. is_debug=False,
  72. start_page_id=0,
  73. end_page_id=None,
  74. lang=None,
  75. *args,
  76. **kwargs
  77. ):
  78. """ocr和文本混合的pdf,全部解析出来."""
  79. def parse_pdf(method):
  80. try:
  81. return method(
  82. dataset,
  83. model_list,
  84. imageWriter,
  85. start_page_id=start_page_id,
  86. end_page_id=end_page_id,
  87. debug_mode=is_debug,
  88. lang=lang,
  89. )
  90. except Exception as e:
  91. logger.exception(e)
  92. return None
  93. pdf_info_dict = parse_pdf(parse_pdf_by_txt)
  94. if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
  95. logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
  96. if len(model_list) == 0:
  97. layout_model = kwargs.get('layout_model', None)
  98. formula_enable = kwargs.get('formula_enable', None)
  99. table_enable = kwargs.get('table_enable', None)
  100. infer_res = doc_analyze(
  101. dataset,
  102. ocr=True,
  103. start_page_id=start_page_id,
  104. end_page_id=end_page_id,
  105. lang=lang,
  106. layout_model=layout_model,
  107. formula_enable=formula_enable,
  108. table_enable=table_enable,
  109. )
  110. model_list = infer_res.get_infer_res()
  111. pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
  112. if pdf_info_dict is None:
  113. raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
  114. else:
  115. pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
  116. else:
  117. pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
  118. pdf_info_dict['_version_name'] = __version__
  119. if lang is not None:
  120. pdf_info_dict['_lang'] = lang
  121. return pdf_info_dict