# app.py
  1. import json
  2. import os
  3. from base64 import b64encode
  4. from glob import glob
  5. from io import StringIO
  6. import tempfile
  7. from typing import Tuple, Union
  8. import uvicorn
  9. from fastapi import FastAPI, HTTPException, UploadFile
  10. from fastapi.responses import JSONResponse
  11. from loguru import logger
  12. from magic_pdf.data.read_api import read_local_images, read_local_office
  13. import magic_pdf.model as model_config
  14. from magic_pdf.config.enums import SupportedPdfParseMethod
  15. from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
  16. from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
  17. from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
  18. from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
  19. from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
  20. from magic_pdf.operators.models import InferenceResult
  21. from magic_pdf.operators.pipes import PipeResult
  22. from fastapi import Form
  23. model_config.__use_inside_model__ = True
  24. app = FastAPI()
  25. pdf_extensions = [".pdf"]
  26. office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
  27. image_extensions = [".png", ".jpg", ".jpeg"]
  28. class MemoryDataWriter(DataWriter):
  29. def __init__(self):
  30. self.buffer = StringIO()
  31. def write(self, path: str, data: bytes) -> None:
  32. if isinstance(data, str):
  33. self.buffer.write(data)
  34. else:
  35. self.buffer.write(data.decode("utf-8"))
  36. def write_string(self, path: str, data: str) -> None:
  37. self.buffer.write(data)
  38. def get_value(self) -> str:
  39. return self.buffer.getvalue()
  40. def close(self):
  41. self.buffer.close()
  42. def init_writers(
  43. file_path: str = None,
  44. file: UploadFile = None,
  45. output_path: str = None,
  46. output_image_path: str = None,
  47. ) -> Tuple[
  48. Union[S3DataWriter, FileBasedDataWriter],
  49. Union[S3DataWriter, FileBasedDataWriter],
  50. bytes,
  51. ]:
  52. """
  53. Initialize writers based on path type
  54. Args:
  55. file_path: file path (local path or S3 path)
  56. file: Uploaded file object
  57. output_path: Output directory path
  58. output_image_path: Image output directory path
  59. Returns:
  60. Tuple[writer, image_writer, file_bytes]: Returns initialized writer tuple and file content
  61. """
  62. file_extension:str = None
  63. if file_path:
  64. is_s3_path = file_path.startswith("s3://")
  65. if is_s3_path:
  66. bucket = get_bucket_name(file_path)
  67. ak, sk, endpoint = get_s3_config(bucket)
  68. writer = S3DataWriter(
  69. output_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  70. )
  71. image_writer = S3DataWriter(
  72. output_image_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  73. )
  74. # 临时创建reader读取文件内容
  75. temp_reader = S3DataReader(
  76. "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
  77. )
  78. file_bytes = temp_reader.read(file_path)
  79. file_extension = os.path.splitext(file_path)[1]
  80. else:
  81. writer = FileBasedDataWriter(output_path)
  82. image_writer = FileBasedDataWriter(output_image_path)
  83. os.makedirs(output_image_path, exist_ok=True)
  84. with open(file_path, "rb") as f:
  85. file_bytes = f.read()
  86. file_extension = os.path.splitext(file_path)[1]
  87. else:
  88. # 处理上传的文件
  89. file_bytes = file.file.read()
  90. file_extension = os.path.splitext(file.filename)[1]
  91. writer = FileBasedDataWriter(output_path)
  92. image_writer = FileBasedDataWriter(output_image_path)
  93. os.makedirs(output_image_path, exist_ok=True)
  94. return writer, image_writer, file_bytes, file_extension
  95. def process_file(
  96. file_bytes: bytes,
  97. file_extension: str,
  98. parse_method: str,
  99. image_writer: Union[S3DataWriter, FileBasedDataWriter],
  100. ) -> Tuple[InferenceResult, PipeResult]:
  101. """
  102. Process PDF file content
  103. Args:
  104. file_bytes: Binary content of file
  105. file_extension: file extension
  106. parse_method: Parse method ('ocr', 'txt', 'auto')
  107. image_writer: Image writer
  108. Returns:
  109. Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
  110. """
  111. ds: Union[PymuDocDataset, ImageDataset] = None
  112. if file_extension in pdf_extensions:
  113. ds = PymuDocDataset(file_bytes)
  114. elif file_extension in office_extensions:
  115. # 需要使用office解析
  116. temp_dir = tempfile.mkdtemp()
  117. with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
  118. f.write(file_bytes)
  119. ds = read_local_office(temp_dir)[0]
  120. elif file_extension in image_extensions:
  121. # 需要使用ocr解析
  122. temp_dir = tempfile.mkdtemp()
  123. with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
  124. f.write(file_bytes)
  125. ds = read_local_images(temp_dir)[0]
  126. infer_result: InferenceResult = None
  127. pipe_result: PipeResult = None
  128. if parse_method == "ocr":
  129. infer_result = ds.apply(doc_analyze, ocr=True)
  130. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  131. elif parse_method == "txt":
  132. infer_result = ds.apply(doc_analyze, ocr=False)
  133. pipe_result = infer_result.pipe_txt_mode(image_writer)
  134. else: # auto
  135. if ds.classify() == SupportedPdfParseMethod.OCR:
  136. infer_result = ds.apply(doc_analyze, ocr=True)
  137. pipe_result = infer_result.pipe_ocr_mode(image_writer)
  138. else:
  139. infer_result = ds.apply(doc_analyze, ocr=False)
  140. pipe_result = infer_result.pipe_txt_mode(image_writer)
  141. return infer_result, pipe_result
  142. def encode_image(image_path: str) -> str:
  143. """Encode image using base64"""
  144. with open(image_path, "rb") as f:
  145. return b64encode(f.read()).decode()
  146. @app.post(
  147. "/file_parse",
  148. tags=["projects"],
  149. summary="Parse files (supports local files and S3)",
  150. )
  151. async def file_parse(
  152. file: UploadFile = None,
  153. file_path: str = Form(None),
  154. parse_method: str = Form("auto"),
  155. is_json_md_dump: bool = Form(False),
  156. output_dir: str = Form("output"),
  157. return_layout: bool = Form(False),
  158. return_info: bool = Form(False),
  159. return_content_list: bool = Form(False),
  160. return_images: bool = Form(False),
  161. ):
  162. """
  163. Execute the process of converting PDF to JSON and MD, outputting MD and JSON files
  164. to the specified directory.
  165. Args:
  166. file: The PDF file to be parsed. Must not be specified together with
  167. `file_path`
  168. file_path: The path to the PDF file to be parsed. Must not be specified together
  169. with `file`
  170. parse_method: Parsing method, can be auto, ocr, or txt. Default is auto. If
  171. results are not satisfactory, try ocr
  172. is_json_md_dump: Whether to write parsed data to .json and .md files. Default
  173. to False. Different stages of data will be written to different .json files
  174. (3 in total), md content will be saved to .md file
  175. output_dir: Output directory for results. A folder named after the PDF file
  176. will be created to store all results
  177. return_layout: Whether to return parsed PDF layout. Default to False
  178. return_info: Whether to return parsed PDF info. Default to False
  179. return_content_list: Whether to return parsed PDF content list. Default to False
  180. """
  181. try:
  182. if (file is None and file_path is None) or (
  183. file is not None and file_path is not None
  184. ):
  185. return JSONResponse(
  186. content={"error": "Must provide either file or file_path"},
  187. status_code=400,
  188. )
  189. # Get PDF filename
  190. file_name = os.path.basename(file_path if file_path else file.filename).split(
  191. "."
  192. )[0]
  193. output_path = f"{output_dir}/{file_name}"
  194. output_image_path = f"{output_path}/images"
  195. # Initialize readers/writers and get PDF content
  196. writer, image_writer, file_bytes, file_extension = init_writers(
  197. file_path=file_path,
  198. file=file,
  199. output_path=output_path,
  200. output_image_path=output_image_path,
  201. )
  202. # Process PDF
  203. infer_result, pipe_result = process_file(file_bytes, file_extension, parse_method, image_writer)
  204. # Use MemoryDataWriter to get results
  205. content_list_writer = MemoryDataWriter()
  206. md_content_writer = MemoryDataWriter()
  207. middle_json_writer = MemoryDataWriter()
  208. # Use PipeResult's dump method to get data
  209. pipe_result.dump_content_list(content_list_writer, "", "images")
  210. pipe_result.dump_md(md_content_writer, "", "images")
  211. pipe_result.dump_middle_json(middle_json_writer, "")
  212. # Get content
  213. content_list = json.loads(content_list_writer.get_value())
  214. md_content = md_content_writer.get_value()
  215. middle_json = json.loads(middle_json_writer.get_value())
  216. model_json = infer_result.get_infer_res()
  217. # If results need to be saved
  218. if is_json_md_dump:
  219. writer.write_string(
  220. f"{file_name}_content_list.json", content_list_writer.get_value()
  221. )
  222. writer.write_string(f"{file_name}.md", md_content)
  223. writer.write_string(
  224. f"{file_name}_middle.json", middle_json_writer.get_value()
  225. )
  226. writer.write_string(
  227. f"{file_name}_model.json",
  228. json.dumps(model_json, indent=4, ensure_ascii=False),
  229. )
  230. # Save visualization results
  231. pipe_result.draw_layout(os.path.join(output_path, f"{file_name}_layout.pdf"))
  232. pipe_result.draw_span(os.path.join(output_path, f"{file_name}_spans.pdf"))
  233. pipe_result.draw_line_sort(
  234. os.path.join(output_path, f"{file_name}_line_sort.pdf")
  235. )
  236. infer_result.draw_model(os.path.join(output_path, f"{file_name}_model.pdf"))
  237. # Build return data
  238. data = {}
  239. if return_layout:
  240. data["layout"] = model_json
  241. if return_info:
  242. data["info"] = middle_json
  243. if return_content_list:
  244. data["content_list"] = content_list
  245. if return_images:
  246. image_paths = glob(f"{output_image_path}/*.jpg")
  247. data["images"] = {
  248. os.path.basename(
  249. image_path
  250. ): f"data:image/jpeg;base64,{encode_image(image_path)}"
  251. for image_path in image_paths
  252. }
  253. data["md_content"] = md_content # md_content is always returned
  254. # Clean up memory writers
  255. content_list_writer.close()
  256. md_content_writer.close()
  257. middle_json_writer.close()
  258. return JSONResponse(data, status_code=200)
  259. except Exception as e:
  260. logger.exception(e)
  261. return JSONResponse(content={"error": str(e)}, status_code=500)
  262. if __name__ == "__main__":
  263. uvicorn.run(app, host="0.0.0.0", port=8888)