# app.py - FastAPI web service that parses PDF files with magic_pdf.

  1. import json
  2. import os
  3. from base64 import b64encode
  4. from glob import glob
  5. from io import StringIO
  6. from typing import Tuple, Union
  7. import uvicorn
  8. from fastapi import FastAPI, HTTPException, UploadFile
  9. from fastapi.responses import JSONResponse
  10. from loguru import logger
  11. import magic_pdf.model as model_config
  12. from magic_pdf.config.enums import SupportedPdfParseMethod
  13. from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
  14. from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
  15. from magic_pdf.data.dataset import PymuDocDataset
  16. from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
  17. from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
  18. from magic_pdf.operators.models import InferenceResult
  19. from magic_pdf.operators.pipes import PipeResult
# Module-level switch on magic_pdf.model; presumably tells the library to use
# its built-in models -- TODO confirm against the magic_pdf documentation.
model_config.__use_inside_model__ = True

# Application object: the /pdf_parse route is registered on it and uvicorn
# serves it when this file is run as a script.
app = FastAPI()
  22. class MemoryDataWriter(DataWriter):
  23. def __init__(self):
  24. self.buffer = StringIO()
  25. def write(self, path: str, data: bytes) -> None:
  26. if isinstance(data, str):
  27. self.buffer.write(data)
  28. else:
  29. self.buffer.write(data.decode("utf-8"))
  30. def write_string(self, path: str, data: str) -> None:
  31. self.buffer.write(data)
  32. def get_value(self) -> str:
  33. return self.buffer.getvalue()
  34. def close(self):
  35. self.buffer.close()
  36. def init_writers(
  37. pdf_path: str = None,
  38. pdf_file: UploadFile = None,
  39. output_path: str = None,
  40. output_image_path: str = None,
  41. ) -> Tuple[
  42. Union[S3DataWriter, FileBasedDataWriter],
  43. Union[S3DataWriter, FileBasedDataWriter],
  44. bytes,
  45. ]:
  46. """
  47. Initialize writers based on path type
  48. Args:
  49. pdf_path: PDF file path (local path or S3 path)
  50. pdf_file: Uploaded PDF file object
  51. output_path: Output directory path
  52. output_image_path: Image output directory path
  53. Returns:
  54. Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
  55. file content
  56. """
  57. if pdf_path:
  58. is_s3_path = pdf_path.startswith("s3://")
  59. if is_s3_path:
  60. bucket = get_bucket_name(pdf_path)
  61. ak, sk, endpoint = get_s3_config(bucket)
  62. writer = S3DataWriter(
  63. output_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  64. )
  65. image_writer = S3DataWriter(
  66. output_image_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  67. )
  68. # 临时创建reader读取文件内容
  69. temp_reader = S3DataReader(
  70. "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  71. )
  72. pdf_bytes = temp_reader.read(pdf_path)
  73. else:
  74. writer = FileBasedDataWriter(output_path)
  75. image_writer = FileBasedDataWriter(output_image_path)
  76. os.makedirs(output_image_path, exist_ok=True)
  77. with open(pdf_path, "rb") as f:
  78. pdf_bytes = f.read()
  79. else:
  80. # 处理上传的文件
  81. pdf_bytes = pdf_file.file.read()
  82. writer = FileBasedDataWriter(output_path)
  83. image_writer = FileBasedDataWriter(output_image_path)
  84. os.makedirs(output_image_path, exist_ok=True)
  85. return writer, image_writer, pdf_bytes
  86. def process_pdf(
  87. pdf_bytes: bytes,
  88. parse_method: str,
  89. image_writer: Union[S3DataWriter, FileBasedDataWriter],
  90. ) -> Tuple[InferenceResult, PipeResult]:
  91. """
  92. Process PDF file content
  93. Args:
  94. pdf_bytes: Binary content of PDF file
  95. parse_method: Parse method ('ocr', 'txt', 'auto')
  96. image_writer: Image writer
  97. Returns:
  98. Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
  99. """
  100. ds = PymuDocDataset(pdf_bytes)
  101. infer_result: InferenceResult = None
  102. pipe_result: PipeResult = None
  103. if parse_method == "ocr":
  104. infer_result = ds.apply(doc_analyze, ocr=True)
  105. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  106. elif parse_method == "txt":
  107. infer_result = ds.apply(doc_analyze, ocr=False)
  108. pipe_result = infer_result.pipe_txt_mode(image_writer)
  109. else: # auto
  110. if ds.classify() == SupportedPdfParseMethod.OCR:
  111. infer_result = ds.apply(doc_analyze, ocr=True)
  112. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  113. else:
  114. infer_result = ds.apply(doc_analyze, ocr=False)
  115. pipe_result = infer_result.pipe_txt_mode(image_writer)
  116. return infer_result, pipe_result
  117. def encode_image(image_path: str) -> str:
  118. """Encode image using base64"""
  119. with open(image_path, "rb") as f:
  120. return b64encode(f.read()).decode()
  121. @app.post(
  122. "/pdf_parse",
  123. tags=["projects"],
  124. summary="Parse PDF files (supports local files and S3)",
  125. )
  126. async def pdf_parse(
  127. pdf_file: UploadFile = None,
  128. pdf_path: str = None,
  129. parse_method: str = "auto",
  130. is_json_md_dump: bool = False,
  131. output_dir: str = "output",
  132. return_layout: bool = False,
  133. return_info: bool = False,
  134. return_content_list: bool = False,
  135. return_images: bool = False,
  136. ):
  137. """
  138. Execute the process of converting PDF to JSON and MD, outputting MD and JSON files
  139. to the specified directory.
  140. Args:
  141. pdf_file: The PDF file to be parsed. Must not be specified together with
  142. `pdf_path`
  143. pdf_path: The path to the PDF file to be parsed. Must not be specified together
  144. with `pdf_file`
  145. parse_method: Parsing method, can be auto, ocr, or txt. Default is auto. If
  146. results are not satisfactory, try ocr
  147. is_json_md_dump: Whether to write parsed data to .json and .md files. Default
  148. to False. Different stages of data will be written to different .json files
  149. (3 in total), md content will be saved to .md file
  150. output_dir: Output directory for results. A folder named after the PDF file
  151. will be created to store all results
  152. return_layout: Whether to return parsed PDF layout. Default to False
  153. return_info: Whether to return parsed PDF info. Default to False
  154. return_content_list: Whether to return parsed PDF content list. Default to False
  155. """
  156. try:
  157. if (pdf_file is None and pdf_path is None) or (
  158. pdf_file is not None and pdf_path is not None
  159. ):
  160. return JSONResponse(
  161. content={"error": "Must provide either pdf_file or pdf_path"},
  162. status_code=400,
  163. )
  164. # Get PDF filename
  165. pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
  166. "."
  167. )[0]
  168. output_path = f"{output_dir}/{pdf_name}"
  169. output_image_path = f"{output_path}/images"
  170. # Initialize readers/writers and get PDF content
  171. writer, image_writer, pdf_bytes = init_writers(
  172. pdf_path=pdf_path,
  173. pdf_file=pdf_file,
  174. output_path=output_path,
  175. output_image_path=output_image_path,
  176. )
  177. # Process PDF
  178. infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
  179. # Use MemoryDataWriter to get results
  180. content_list_writer = MemoryDataWriter()
  181. md_content_writer = MemoryDataWriter()
  182. middle_json_writer = MemoryDataWriter()
  183. # Use PipeResult's dump method to get data
  184. pipe_result.dump_content_list(content_list_writer, "", "images")
  185. pipe_result.dump_md(md_content_writer, "", "images")
  186. pipe_result.dump_middle_json(middle_json_writer, "")
  187. # Get content
  188. content_list = json.loads(content_list_writer.get_value())
  189. md_content = md_content_writer.get_value()
  190. middle_json = json.loads(middle_json_writer.get_value())
  191. model_json = infer_result.get_infer_res()
  192. # If results need to be saved
  193. if is_json_md_dump:
  194. writer.write_string(
  195. f"{pdf_name}_content_list.json", content_list_writer.get_value()
  196. )
  197. writer.write_string(f"{pdf_name}.md", md_content)
  198. writer.write_string(
  199. f"{pdf_name}_middle.json", middle_json_writer.get_value()
  200. )
  201. writer.write_string(
  202. f"{pdf_name}_model.json",
  203. json.dumps(model_json, indent=4, ensure_ascii=False),
  204. )
  205. # Save visualization results
  206. pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
  207. pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
  208. pipe_result.draw_line_sort(
  209. os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
  210. )
  211. infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
  212. # Build return data
  213. data = {}
  214. if return_layout:
  215. data["layout"] = model_json
  216. if return_info:
  217. data["info"] = middle_json
  218. if return_content_list:
  219. data["content_list"] = content_list
  220. if return_images:
  221. image_paths = glob(f"{output_image_path}/*.jpg")
  222. data["images"] = {
  223. os.path.basename(
  224. image_path
  225. ): f"data:image/jpeg;base64,{encode_image(image_path)}"
  226. for image_path in image_paths
  227. }
  228. data["md_content"] = md_content # md_content is always returned
  229. # Clean up memory writers
  230. content_list_writer.close()
  231. md_content_writer.close()
  232. middle_json_writer.close()
  233. return JSONResponse(data, status_code=200)
  234. except Exception as e:
  235. logger.exception(e)
  236. return JSONResponse(content={"error": str(e)}, status_code=500)
  237. if __name__ == "__main__":
  238. uvicorn.run(app, host="0.0.0.0", port=8888)