app.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import json
  2. import os
  3. from io import StringIO
  4. from typing import Tuple, Union
  5. import uvicorn
  6. from fastapi import FastAPI, HTTPException, UploadFile
  7. from fastapi.responses import JSONResponse
  8. from loguru import logger
  9. import magic_pdf.model as model_config
  10. from magic_pdf.config.enums import SupportedPdfParseMethod
  11. from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
  12. from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
  13. from magic_pdf.data.dataset import PymuDocDataset
  14. from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
  15. from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
  16. from magic_pdf.operators.models import InferenceResult
  17. from magic_pdf.operators.pipes import PipeResult
  18. model_config.__use_inside_model__ = True
  19. app = FastAPI()
  20. class MemoryDataWriter(DataWriter):
  21. def __init__(self):
  22. self.buffer = StringIO()
  23. def write(self, path: str, data: bytes) -> None:
  24. if isinstance(data, str):
  25. self.buffer.write(data)
  26. else:
  27. self.buffer.write(data.decode('utf-8'))
  28. def write_string(self, path: str, data: str) -> None:
  29. self.buffer.write(data)
  30. def get_value(self) -> str:
  31. return self.buffer.getvalue()
  32. def close(self):
  33. self.buffer.close()
  34. def init_writers(
  35. pdf_path: str = None,
  36. pdf_file: UploadFile = None,
  37. output_path: str = None,
  38. output_image_path: str = None,
  39. ) -> Tuple[Union[S3DataWriter, FileBasedDataWriter], Union[S3DataWriter, FileBasedDataWriter], bytes]:
  40. """
  41. Initialize writers based on path type
  42. Args:
  43. pdf_path: PDF file path (local path or S3 path)
  44. pdf_file: Uploaded PDF file object
  45. output_path: Output directory path
  46. output_image_path: Image output directory path
  47. Returns:
  48. Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF file content
  49. """
  50. if pdf_path:
  51. is_s3_path = pdf_path.startswith('s3://')
  52. if is_s3_path:
  53. bucket = get_bucket_name(pdf_path)
  54. ak, sk, endpoint = get_s3_config(bucket)
  55. writer = S3DataWriter(output_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
  56. image_writer = S3DataWriter(output_image_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
  57. # 临时创建reader读取文件内容
  58. temp_reader = S3DataReader("", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
  59. pdf_bytes = temp_reader.read(pdf_path)
  60. else:
  61. writer = FileBasedDataWriter(output_path)
  62. image_writer = FileBasedDataWriter(output_image_path)
  63. os.makedirs(output_image_path, exist_ok=True)
  64. with open(pdf_path, 'rb') as f:
  65. pdf_bytes = f.read()
  66. else:
  67. # 处理上传的文件
  68. pdf_bytes = pdf_file.file.read()
  69. writer = FileBasedDataWriter(output_path)
  70. image_writer = FileBasedDataWriter(output_image_path)
  71. os.makedirs(output_image_path, exist_ok=True)
  72. return writer, image_writer, pdf_bytes
  73. def process_pdf(
  74. pdf_bytes: bytes,
  75. parse_method: str,
  76. image_writer: Union[S3DataWriter, FileBasedDataWriter]
  77. ) -> Tuple[InferenceResult, PipeResult]:
  78. """
  79. Process PDF file content
  80. Args:
  81. pdf_bytes: Binary content of PDF file
  82. parse_method: Parse method ('ocr', 'txt', 'auto')
  83. image_writer: Image writer
  84. Returns:
  85. Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
  86. """
  87. ds = PymuDocDataset(pdf_bytes)
  88. infer_result : InferenceResult = None
  89. pipe_result : PipeResult = None
  90. if parse_method == 'ocr':
  91. infer_result = ds.apply(doc_analyze, ocr=True)
  92. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  93. elif parse_method == 'txt':
  94. infer_result = ds.apply(doc_analyze, ocr=False)
  95. pipe_result = infer_result.pipe_txt_mode(image_writer)
  96. else: # auto
  97. if ds.classify() == SupportedPdfParseMethod.OCR:
  98. infer_result = ds.apply(doc_analyze, ocr=True)
  99. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  100. else:
  101. infer_result = ds.apply(doc_analyze, ocr=False)
  102. pipe_result = infer_result.pipe_txt_mode(image_writer)
  103. return infer_result, pipe_result
  104. @app.post('/pdf_parse', tags=['projects'], summary='Parse PDF files (supports local files and S3)')
  105. async def pdf_parse(
  106. pdf_file: UploadFile = None,
  107. pdf_path: str = None,
  108. parse_method: str = 'auto',
  109. is_json_md_dump: bool = True,
  110. output_dir: str = 'output',
  111. return_layout: bool = False,
  112. return_info: bool = False,
  113. return_content_list: bool = False,
  114. ):
  115. try:
  116. if pdf_file is None and pdf_path is None:
  117. raise HTTPException(status_code=400, detail="Must provide either pdf_file or pdf_path")
  118. # Get PDF filename
  119. pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split('.')[0]
  120. output_path = f"{output_dir}/{pdf_name}"
  121. output_image_path = f"{output_path}/images"
  122. # Initialize readers/writers and get PDF content
  123. writer, image_writer, pdf_bytes = init_writers(
  124. pdf_path=pdf_path,
  125. pdf_file=pdf_file,
  126. output_path=output_path,
  127. output_image_path=output_image_path
  128. )
  129. # Process PDF
  130. infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
  131. # Use MemoryDataWriter to get results
  132. content_list_writer = MemoryDataWriter()
  133. md_content_writer = MemoryDataWriter()
  134. middle_json_writer = MemoryDataWriter()
  135. # Use PipeResult's dump method to get data
  136. pipe_result.dump_content_list(content_list_writer, "", "images")
  137. pipe_result.dump_md(md_content_writer, "", "images")
  138. pipe_result.dump_middle_json(middle_json_writer, "")
  139. # Get content
  140. content_list = json.loads(content_list_writer.get_value())
  141. md_content = md_content_writer.get_value()
  142. middle_json = json.loads(middle_json_writer.get_value())
  143. model_json = infer_result.get_infer_res()
  144. # If results need to be saved
  145. if is_json_md_dump:
  146. writer.write_string(f"{pdf_name}_content_list.json", content_list_writer.get_value())
  147. writer.write_string(f"{pdf_name}.md", md_content)
  148. writer.write_string(f"{pdf_name}_middle.json", middle_json_writer.get_value())
  149. writer.write_string(f"{pdf_name}_model.json", json.dumps(model_json, indent=4, ensure_ascii=False))
  150. # Save visualization results
  151. pipe_result.draw_layout(os.path.join(output_path, f'{pdf_name}_layout.pdf'))
  152. pipe_result.draw_span(os.path.join(output_path, f'{pdf_name}_spans.pdf'))
  153. pipe_result.draw_line_sort(os.path.join(output_path, f'{pdf_name}_line_sort.pdf'))
  154. infer_result.draw_model(os.path.join(output_path, f'{pdf_name}_model.pdf'))
  155. # Build return data
  156. data = {}
  157. if return_layout:
  158. data['layout'] = model_json
  159. if return_info:
  160. data['info'] = middle_json
  161. if return_content_list:
  162. data['content_list'] = content_list
  163. data['md_content'] = md_content # md_content is always returned
  164. # Clean up memory writers
  165. content_list_writer.close()
  166. md_content_writer.close()
  167. middle_json_writer.close()
  168. return JSONResponse(data, status_code=200)
  169. except Exception as e:
  170. logger.exception(e)
  171. return JSONResponse(content={'error': str(e)}, status_code=500)
  172. if __name__ == '__main__':
  173. uvicorn.run(app, host='0.0.0.0', port=8888)