| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- import uuid
- import os
- import re
- import tempfile
- import asyncio
- import uvicorn
- import click
- import zipfile
- from pathlib import Path
- import glob
- from fastapi import FastAPI, UploadFile, File, Form
- from fastapi.middleware.gzip import GZipMiddleware
- from fastapi.responses import JSONResponse, FileResponse
- from starlette.background import BackgroundTask
- from typing import List, Optional
- from loguru import logger
- from base64 import b64encode
- from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
- from mineru.utils.cli_parser import arg_parse
- from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
- from mineru.version import __version__
# FastAPI application that the route handlers below are registered on and that
# main() hands to uvicorn.
app = FastAPI()
# Gzip-compress responses; minimum_size=1000 is presumably the size threshold in
# bytes below which responses are sent uncompressed (per Starlette's
# GZipMiddleware) -- confirm against the installed Starlette version.
app.add_middleware(GZipMiddleware, minimum_size=1000)
def sanitize_filename(filename: str) -> str:
    """Turn an arbitrary document name into a safe zip archive member name.

    Runs of two or more path characters (``/``, ``\\``, ``.``) and every single
    ``/`` or ``\\`` are dropped, so path traversal pieces disappear. Any
    character that is not a Unicode word character, ``.`` or ``-`` becomes
    ``_``. A leading ``.`` is replaced with ``_`` so no hidden entries are
    produced. An empty result falls back to ``'unnamed'``.
    """
    # Step 1: drop traversal sequences and directory separators.
    without_separators = re.sub(r'[/\\\.]{2,}|[/\\]', '', filename)
    # Step 2: whitelist Unicode letters/digits/underscore plus '.' and '-'.
    cleaned = re.sub(r'[^\w.-]', '_', without_separators, flags=re.UNICODE)
    # Step 3: never emit a dot-file name.
    if cleaned[:1] == '.':
        cleaned = f"_{cleaned[1:]}"
    return cleaned if cleaned else 'unnamed'
def cleanup_file(file_path: str) -> None:
    """Best-effort removal of a temporary file (e.g. the generated result zip).

    A file that is already gone is not an error. Any other failure is logged as
    a warning and swallowed. This function is scheduled as a response
    background task, so there is no caller left to report an error to.

    Args:
        file_path: Path of the file to delete.
    """
    try:
        # EAFP: a separate exists() check would race with concurrent removal.
        os.remove(file_path)
    except FileNotFoundError:
        # Nothing to clean up.
        pass
    except Exception as e:
        logger.warning(f"fail clean file {file_path}: {e}")
def encode_image(image_path: str) -> str:
    """Read *image_path* as raw bytes and return them base64-encoded as a str."""
    raw_bytes = Path(image_path).read_bytes()
    encoded = b64encode(raw_bytes)
    return encoded.decode()
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
    """Return the UTF-8 text of ``<parse_dir>/<pdf_name><file_suffix_identifier>``.

    Args:
        file_suffix_identifier: Suffix appended to the document name, e.g. ``".md"``.
        pdf_name: Document stem used as the file-name prefix.
        parse_dir: Directory holding the result files.

    Returns:
        The file contents, or ``None`` when the file does not exist.
    """
    candidate = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
    if not os.path.exists(candidate):
        return None
    with open(candidate, encoding="utf-8") as handle:
        text = handle.read()
    return text
def _resolve_parse_dir(unique_dir: str, pdf_name: str, backend: str, parse_method: str) -> str:
    """Return the per-document directory that results are read back from.

    Backends whose name starts with ``"pipeline"`` use
    ``<unique_dir>/<pdf_name>/<parse_method>``; every other backend uses
    ``<unique_dir>/<pdf_name>/vlm``.
    """
    sub_dir = parse_method if backend.startswith("pipeline") else "vlm"
    return os.path.join(unique_dir, pdf_name, sub_dir)


def _write_results_zip(
    zip_path: str,
    unique_dir: str,
    pdf_file_names: List[str],
    backend: str,
    parse_method: str,
    text_suffixes: List[str],
    include_images: bool,
) -> None:
    """Pack the requested result files of every parsed document into *zip_path*.

    Each document gets a folder named after its sanitized stem. Documents whose
    result directory is missing are skipped, and so are individual missing files.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for pdf_name in pdf_file_names:
            parse_dir = _resolve_parse_dir(unique_dir, pdf_name, backend, parse_method)
            if not os.path.exists(parse_dir):
                continue
            # On-disk paths use the real stem. Every archive member name
            # (including *_model.json) is built only from the sanitized stem.
            safe_pdf_name = sanitize_filename(pdf_name)
            for suffix in text_suffixes:
                path = os.path.join(parse_dir, f"{pdf_name}{suffix}")
                if os.path.exists(path):
                    zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}{suffix}"))
            if include_images:
                images_dir = os.path.join(parse_dir, "images")
                # glob.escape: the directory name derives from client-supplied file names.
                for image_path in glob.glob(os.path.join(glob.escape(images_dir), "*.jpg")):
                    zf.write(
                        image_path,
                        arcname=os.path.join(safe_pdf_name, "images", os.path.basename(image_path)),
                    )


def _collect_json_results(
    unique_dir: str,
    pdf_file_names: List[str],
    backend: str,
    parse_method: str,
    text_fields: List[tuple],
    include_images: bool,
) -> dict:
    """Build ``{pdf_name: {json_key: value}}`` from the result files on disk.

    *text_fields* holds ``(file_suffix, json_key)`` pairs. A requested file that
    is missing yields ``None`` for its key. A document whose result directory is
    missing maps to an empty dict.
    """
    result_dict = {}
    for pdf_name in pdf_file_names:
        data = {}
        result_dict[pdf_name] = data
        parse_dir = _resolve_parse_dir(unique_dir, pdf_name, backend, parse_method)
        if not os.path.exists(parse_dir):
            continue
        for suffix, key in text_fields:
            data[key] = get_infer_result(suffix, pdf_name, parse_dir)
        if include_images:
            images_dir = os.path.join(parse_dir, "images")
            image_paths = glob.glob(os.path.join(glob.escape(images_dir), "*.jpg"))
            # Images are inlined as data URIs so the JSON response is self-contained.
            data["images"] = {
                os.path.basename(image_path): f"data:image/jpeg;base64,{encode_image(image_path)}"
                for image_path in image_paths
            }
    return result_dict


@app.post(path="/file_parse",)
async def parse_pdf(
    files: List[UploadFile] = File(...),
    output_dir: str = Form("./output"),
    lang_list: List[str] = Form(["ch"]),
    backend: str = Form("pipeline"),
    parse_method: str = Form("auto"),
    formula_enable: bool = Form(True),
    table_enable: bool = Form(True),
    server_url: Optional[str] = Form(None),
    return_md: bool = Form(True),
    return_middle_json: bool = Form(False),
    return_model_output: bool = Form(False),
    return_content_list: bool = Form(False),
    return_images: bool = Form(False),
    response_format_zip: bool = Form(False),
    start_page_id: int = Form(0),
    end_page_id: int = Form(99999),
):
    """Parse uploaded PDF/image files and return the results as JSON or a zip.

    Uploads are staged in a fresh ``<output_dir>/<uuid4>`` directory, loaded with
    ``read_fn`` and passed to ``aio_do_parse`` together with the CLI config
    stored on ``app.state.config``. The ``return_*`` flags choose which
    artifacts are produced and returned.

    Responses:
        200: JSON ``{"backend", "version", "results"}``, or ``results.zip``
            when ``response_format_zip`` is true.
        400: JSON ``{"error"}`` for an unsupported or unreadable upload.
        500: JSON ``{"error"}`` for any other failure (also logged).
    """
    # NOTE(review): output_dir is taken from the request form, so the client
    # chooses where on the server filesystem results are written (and they are
    # left there). Consider restricting it to a server-side base directory.
    # Configuration passed on the command line (empty dict when not set).
    config = getattr(app.state, "config", {})
    # (request flag, result-file suffix, JSON key), in response order.
    text_artifacts = [
        (return_md, ".md", "md_content"),
        (return_middle_json, "_middle.json", "middle_json"),
        (return_model_output, "_model.json", "model_output"),
        (return_content_list, "_content_list.json", "content_list"),
    ]
    requested = [(suffix, key) for enabled, suffix, key in text_artifacts if enabled]
    try:
        # A fresh directory per request keeps concurrent requests apart.
        unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
        os.makedirs(unique_dir, exist_ok=True)

        supported_suffixes = pdf_suffixes + image_suffixes
        pdf_file_names = []
        pdf_bytes_list = []
        for file in files:
            content = await file.read()
            file_path = Path(file.filename)
            # Keep only the final path component so the upload cannot be
            # written outside unique_dir.
            temp_path = Path(unique_dir) / file_path.name
            with open(temp_path, "wb") as f:
                f.write(content)
            try:
                file_suffix = guess_suffix_by_path(temp_path)
                if file_suffix not in supported_suffixes:
                    return JSONResponse(
                        status_code=400,
                        content={"error": f"Unsupported file type: {file_suffix}"}
                    )
                try:
                    pdf_bytes = read_fn(temp_path)
                except Exception as e:
                    return JSONResponse(
                        status_code=400,
                        content={"error": f"Failed to load file: {str(e)}"}
                    )
            finally:
                # The staged copy is only needed to detect the type and load the
                # bytes. Remove it on every path (success, rejection, or load
                # failure) instead of leaking it.
                cleanup_file(str(temp_path))
            pdf_bytes_list.append(pdf_bytes)
            pdf_file_names.append(file_path.stem)

        # One language per document. On a length mismatch, use the first
        # supplied language (or "ch") for every document.
        actual_lang_list = lang_list
        if len(actual_lang_list) != len(pdf_file_names):
            actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)

        await aio_do_parse(
            output_dir=unique_dir,
            pdf_file_names=pdf_file_names,
            pdf_bytes_list=pdf_bytes_list,
            p_lang_list=actual_lang_list,
            backend=backend,
            parse_method=parse_method,
            formula_enable=formula_enable,
            table_enable=table_enable,
            server_url=server_url,
            f_draw_layout_bbox=False,
            f_draw_span_bbox=False,
            f_dump_md=return_md,
            f_dump_middle_json=return_middle_json,
            f_dump_model_output=return_model_output,
            f_dump_orig_pdf=False,
            f_dump_content_list=return_content_list,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            **config
        )

        if response_format_zip:
            zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
            os.close(zip_fd)
            try:
                _write_results_zip(
                    zip_path,
                    unique_dir,
                    pdf_file_names,
                    backend,
                    parse_method,
                    [suffix for suffix, _ in requested],
                    return_images,
                )
            except Exception:
                # Don't leave a half-written archive behind. The outer handler
                # turns the error into a 500 response.
                cleanup_file(zip_path)
                raise
            return FileResponse(
                path=zip_path,
                media_type="application/zip",
                filename="results.zip",
                # Delete the archive after the response has been sent.
                background=BackgroundTask(cleanup_file, zip_path)
            )

        result_dict = _collect_json_results(
            unique_dir,
            pdf_file_names,
            backend,
            parse_method,
            requested,
            return_images,
        )
        return JSONResponse(
            status_code=200,
            content={
                "backend": backend,
                "version": __version__,
                "results": result_dict
            }
        )
    except Exception as e:
        logger.exception(e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process file: {str(e)}"}
        )
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
def main(ctx, host, port, reload, **kwargs):
    """Start the MinerU FastAPI server from the command line."""
    # Extra arguments click does not know (allowed by the context settings
    # above) are parsed by arg_parse and stored on app.state.config. parse_pdf
    # reads that config and forwards it to aio_do_parse.
    kwargs.update(arg_parse(ctx))
    app.state.config = kwargs

    print(f"Start MinerU FastAPI Service: http://{host}:{port}")
    print("The API documentation can be accessed at the following address:")
    print(f"- Swagger UI: http://{host}:{port}/docs")
    print(f"- ReDoc: http://{host}:{port}/redoc")

    # app.state.config only exists on this in-memory `app` object. Given an
    # import string, uvicorn imports "mineru.cli.fast_api" itself. When this
    # file was started as __main__, that import creates a second, unconfigured
    # module object, and the CLI config would be silently dropped.
    # So serve the configured object directly. Use the import string only for
    # --reload, which uvicorn supports only with an import string.
    # NOTE(review): with --reload the config likely still does not reach the
    # reloaded worker process -- TODO confirm / pass it via environment.
    app_target = "mineru.cli.fast_api:app" if reload else app
    uvicorn.run(
        app_target,
        host=host,
        port=port,
        reload=reload
    )
# Allow running this module directly (e.g. `python -m mineru.cli.fast_api`).
if __name__ == "__main__":
    # click supplies ctx and the option values from sys.argv.
    main()
|