app.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import json
  2. import os
  3. from io import StringIO
  4. from typing import Tuple, Union
  5. import uvicorn
  6. from fastapi import FastAPI, HTTPException, UploadFile
  7. from fastapi.responses import JSONResponse
  8. from loguru import logger
  9. import magic_pdf.model as model_config
  10. from magic_pdf.config.enums import SupportedPdfParseMethod
  11. from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
  12. from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
  13. from magic_pdf.data.dataset import PymuDocDataset
  14. from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
  15. from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
  16. from magic_pdf.operators.models import InferenceResult
  17. from magic_pdf.operators.pipes import PipeResult
  18. model_config.__use_inside_model__ = True
  19. app = FastAPI()
  20. class MemoryDataWriter(DataWriter):
  21. def __init__(self):
  22. self.buffer = StringIO()
  23. def write(self, path: str, data: bytes) -> None:
  24. if isinstance(data, str):
  25. self.buffer.write(data)
  26. else:
  27. self.buffer.write(data.decode("utf-8"))
  28. def write_string(self, path: str, data: str) -> None:
  29. self.buffer.write(data)
  30. def get_value(self) -> str:
  31. return self.buffer.getvalue()
  32. def close(self):
  33. self.buffer.close()
  34. def init_writers(
  35. pdf_path: str = None,
  36. pdf_file: UploadFile = None,
  37. output_path: str = None,
  38. output_image_path: str = None,
  39. ) -> Tuple[
  40. Union[S3DataWriter, FileBasedDataWriter],
  41. Union[S3DataWriter, FileBasedDataWriter],
  42. bytes,
  43. ]:
  44. """
  45. Initialize writers based on path type
  46. Args:
  47. pdf_path: PDF file path (local path or S3 path)
  48. pdf_file: Uploaded PDF file object
  49. output_path: Output directory path
  50. output_image_path: Image output directory path
  51. Returns:
  52. Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
  53. file content
  54. """
  55. if pdf_path:
  56. is_s3_path = pdf_path.startswith("s3://")
  57. if is_s3_path:
  58. bucket = get_bucket_name(pdf_path)
  59. ak, sk, endpoint = get_s3_config(bucket)
  60. writer = S3DataWriter(
  61. output_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  62. )
  63. image_writer = S3DataWriter(
  64. output_image_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  65. )
  66. # 临时创建reader读取文件内容
  67. temp_reader = S3DataReader(
  68. "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  69. )
  70. pdf_bytes = temp_reader.read(pdf_path)
  71. else:
  72. writer = FileBasedDataWriter(output_path)
  73. image_writer = FileBasedDataWriter(output_image_path)
  74. os.makedirs(output_image_path, exist_ok=True)
  75. with open(pdf_path, "rb") as f:
  76. pdf_bytes = f.read()
  77. else:
  78. # 处理上传的文件
  79. pdf_bytes = pdf_file.file.read()
  80. writer = FileBasedDataWriter(output_path)
  81. image_writer = FileBasedDataWriter(output_image_path)
  82. os.makedirs(output_image_path, exist_ok=True)
  83. return writer, image_writer, pdf_bytes
  84. def process_pdf(
  85. pdf_bytes: bytes,
  86. parse_method: str,
  87. image_writer: Union[S3DataWriter, FileBasedDataWriter],
  88. ) -> Tuple[InferenceResult, PipeResult]:
  89. """
  90. Process PDF file content
  91. Args:
  92. pdf_bytes: Binary content of PDF file
  93. parse_method: Parse method ('ocr', 'txt', 'auto')
  94. image_writer: Image writer
  95. Returns:
  96. Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
  97. """
  98. ds = PymuDocDataset(pdf_bytes)
  99. infer_result: InferenceResult = None
  100. pipe_result: PipeResult = None
  101. if parse_method == "ocr":
  102. infer_result = ds.apply(doc_analyze, ocr=True)
  103. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  104. elif parse_method == "txt":
  105. infer_result = ds.apply(doc_analyze, ocr=False)
  106. pipe_result = infer_result.pipe_txt_mode(image_writer)
  107. else: # auto
  108. if ds.classify() == SupportedPdfParseMethod.OCR:
  109. infer_result = ds.apply(doc_analyze, ocr=True)
  110. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  111. else:
  112. infer_result = ds.apply(doc_analyze, ocr=False)
  113. pipe_result = infer_result.pipe_txt_mode(image_writer)
  114. return infer_result, pipe_result
  115. @app.post(
  116. "/pdf_parse",
  117. tags=["projects"],
  118. summary="Parse PDF files (supports local files and S3)",
  119. )
  120. async def pdf_parse(
  121. pdf_file: UploadFile = None,
  122. pdf_path: str = None,
  123. parse_method: str = "auto",
  124. is_json_md_dump: bool = True,
  125. output_dir: str = "output",
  126. return_layout: bool = False,
  127. return_info: bool = False,
  128. return_content_list: bool = False,
  129. ):
  130. try:
  131. if pdf_file is None and pdf_path is None:
  132. raise HTTPException(
  133. status_code=400, detail="Must provide either pdf_file or pdf_path"
  134. )
  135. # Get PDF filename
  136. pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
  137. "."
  138. )[0]
  139. output_path = f"{output_dir}/{pdf_name}"
  140. output_image_path = f"{output_path}/images"
  141. # Initialize readers/writers and get PDF content
  142. writer, image_writer, pdf_bytes = init_writers(
  143. pdf_path=pdf_path,
  144. pdf_file=pdf_file,
  145. output_path=output_path,
  146. output_image_path=output_image_path,
  147. )
  148. # Process PDF
  149. infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
  150. # Use MemoryDataWriter to get results
  151. content_list_writer = MemoryDataWriter()
  152. md_content_writer = MemoryDataWriter()
  153. middle_json_writer = MemoryDataWriter()
  154. # Use PipeResult's dump method to get data
  155. pipe_result.dump_content_list(content_list_writer, "", "images")
  156. pipe_result.dump_md(md_content_writer, "", "images")
  157. pipe_result.dump_middle_json(middle_json_writer, "")
  158. # Get content
  159. content_list = json.loads(content_list_writer.get_value())
  160. md_content = md_content_writer.get_value()
  161. middle_json = json.loads(middle_json_writer.get_value())
  162. model_json = infer_result.get_infer_res()
  163. # If results need to be saved
  164. if is_json_md_dump:
  165. writer.write_string(
  166. f"{pdf_name}_content_list.json", content_list_writer.get_value()
  167. )
  168. writer.write_string(f"{pdf_name}.md", md_content)
  169. writer.write_string(
  170. f"{pdf_name}_middle.json", middle_json_writer.get_value()
  171. )
  172. writer.write_string(
  173. f"{pdf_name}_model.json",
  174. json.dumps(model_json, indent=4, ensure_ascii=False),
  175. )
  176. # Save visualization results
  177. pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
  178. pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
  179. pipe_result.draw_line_sort(
  180. os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
  181. )
  182. infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
  183. # Build return data
  184. data = {}
  185. if return_layout:
  186. data["layout"] = model_json
  187. if return_info:
  188. data["info"] = middle_json
  189. if return_content_list:
  190. data["content_list"] = content_list
  191. data["md_content"] = md_content # md_content is always returned
  192. # Clean up memory writers
  193. content_list_writer.close()
  194. md_content_writer.close()
  195. middle_json_writer.close()
  196. return JSONResponse(data, status_code=200)
  197. except Exception as e:
  198. logger.exception(e)
  199. return JSONResponse(content={"error": str(e)}, status_code=500)
  200. if __name__ == "__main__":
  201. uvicorn.run(app, host="0.0.0.0", port=8888)