| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- """
- 编辑器 API 路由
- """
- import json
- from fastapi import APIRouter, UploadFile, File, HTTPException
- from loguru import logger
- from models.schemas import (
- UploadResponse,
- AnalyzeRequest,
- AnalyzeResponse,
- SaveRequest,
- SaveResponse,
- TableStructure,
- ImageSize,
- LoadByPathRequest,
- HealthResponse,
- )
- from services.editor_service import EditorService
- router = APIRouter(prefix="/api", tags=["editor"])
@router.post("/upload", response_model=UploadResponse)
async def upload_files(
    json_file: UploadFile = File(..., description="OCR JSON 文件"),
    image_file: UploadFile = File(..., description="图片文件")
):
    """
    Accept an OCR JSON file and its image, and return the analysis result.

    Both uploads are validated by file extension, read fully into memory and
    passed to ``EditorService.process_upload``. The mapping it returns is
    repackaged into an ``UploadResponse``.

    Per the original endpoint description (behaviour lives in
    ``EditorService`` and is not visible in this handler):
    - the OCR format is auto-detected (PPStructure / MinerU)
    - images larger than 4096x4096 are scaled down
    - the image is returned base64-encoded together with the table structure

    Raises:
        HTTPException: 400 when the JSON file name does not end in ``.json``,
            when the image extension is not in the allowed set, or when
            processing raises ``ValueError``; 500 for any other failure.
    """
    try:
        # Validate file types. ``filename`` may be missing on an upload, so
        # fall back to an empty string and let the checks reject it with 400.
        json_name = json_file.filename or ""
        if not json_name.endswith('.json'):
            raise HTTPException(status_code=400, detail="请上传 JSON 文件")

        allowed_image_types = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff'}
        image_name = image_file.filename or ""
        image_ext = '.' + image_name.split('.')[-1].lower()
        if image_ext not in allowed_image_types:
            raise HTTPException(status_code=400, detail=f"不支持的图片格式: {image_ext}")

        # Read the uploaded payloads.
        json_content = await json_file.read()
        image_content = await image_file.read()

        logger.info(f"收到上传: JSON={json_file.filename}, Image={image_file.filename}")

        # The uploaded JSON file name is forwarded so the service can derive
        # a suggested output file name from it.
        result = EditorService.process_upload(
            json_content,
            image_content,
            json_path=json_name
        )

        return UploadResponse(
            success=True,
            image_base64=result['image_base64'],
            structure=TableStructure(**result['structure']),
            image_size=ImageSize(**result['image_size']),
            scale_factor=result['scale_factor'],
            ocr_data=result['ocr_data'],
            suggested_filename=result.get('suggested_filename'),
            message="上传成功"
        )

    except HTTPException:
        # Deliberate HTTP errors (the 400s above) must reach the client
        # unchanged instead of being rewrapped as 500 by the generic handler.
        raise
    except ValueError as e:
        logger.error(f"上传处理失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"上传处理异常: {e}")
        raise HTTPException(status_code=500, detail=f"服务器错误: {e}")
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_structure(request: AnalyzeRequest):
    """
    Re-run table structure analysis on existing OCR data with new parameters.

    The request's ``params`` model is dumped to a plain dict and handed,
    together with ``ocr_data``, to ``EditorService.analyze_structure``.

    Raises:
        HTTPException: 500 when anything in the analysis fails.
    """
    try:
        ocr_payload = request.ocr_data
        analysis_params = request.params.model_dump()
        raw_structure = EditorService.analyze_structure(ocr_payload, analysis_params)
        table = TableStructure(**raw_structure)
        return AnalyzeResponse(success=True, structure=table, message="分析完成")
    except Exception as exc:
        logger.exception(f"分析失败: {exc}")
        raise HTTPException(status_code=500, detail=f"分析失败: {exc}")
@router.post("/save", response_model=SaveResponse)
async def save_result(request: SaveRequest):
    """
    Persist the edited result: the structure JSON plus an optional image
    with the table lines drawn on it.

    All request fields are forwarded to ``EditorService.save_result``; the
    paths it reports are echoed back in the ``SaveResponse``.

    Raises:
        HTTPException: 500 when saving fails for any reason.
    """
    try:
        save = EditorService.save_result
        # Collect the keyword arguments in the same order the service call
        # lists them, so attribute access order stays unchanged.
        save_kwargs = {
            'structure': request.structure.model_dump(),
            'image_base64': request.image_base64,
            'output_dir': request.output_dir,
            'filename': request.filename,
            'image_filename': request.image_filename,
            'overwrite_mode': request.overwrite_mode,
            'structure_suffix': request.structure_suffix,
            'image_suffix': request.image_suffix,
            'line_width': request.line_width,
            # Only the first three components are passed on, as a tuple.
            'line_color': (
                request.line_color[0],
                request.line_color[1],
                request.line_color[2],
            ),
        }
        saved = save(**save_kwargs)

        structure_path = saved['structure_path']
        image_path = saved['image_path']
        return SaveResponse(
            success=True,
            structure_path=structure_path,
            image_path=image_path,
            message="保存成功",
        )
    except Exception as exc:
        logger.exception(f"保存失败: {exc}")
        raise HTTPException(status_code=500, detail=f"保存失败: {exc}")
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check: report a fixed ``ok`` status and the service name."""
    payload = {"status": "ok", "service": "table-line-editor"}
    return HealthResponse(**payload)
@router.post("/load-by-path", response_model=UploadResponse)
async def load_by_path(request: LoadByPathRequest):
    """
    Load an image and its OCR JSON from server-side paths.

    If ``output_dir`` is given and contains a previously saved structure
    file named ``<json stem><structure_suffix>``, that annotation is loaded
    and passed to ``EditorService.process_upload`` as
    ``annotated_structure``; otherwise only the raw OCR data is processed.

    Raises:
        HTTPException: 404 when the image or JSON path is not an existing
            regular file; 400 when processing raises ``ValueError`` (which
            includes malformed JSON in the structure file); 500 otherwise.
    """
    from pathlib import Path

    image_path = Path(request.image_path)
    json_path = Path(request.json_path)
    output_dir = Path(request.output_dir) if request.output_dir else None

    try:
        # ``is_file()`` rather than ``exists()``: a directory at this path
        # would pass ``exists()`` and then fail in the read below as a 500.
        if not image_path.is_file():
            raise HTTPException(status_code=404, detail=f"图片文件不存在: {image_path}")

        if not json_path.is_file():
            raise HTTPException(status_code=404, detail=f"JSON 文件不存在: {json_path}")

        # The suffix comes from the request (frontend), not a config file.
        # Make sure it ends in ``.json``.
        structure_suffix = request.structure_suffix
        if not structure_suffix.endswith('.json'):
            structure_suffix = structure_suffix + '.json'
        structure_path = (
            output_dir / f"{json_path.stem}{structure_suffix}" if output_dir else None
        )

        json_content = json_path.read_bytes()
        image_content = image_path.read_bytes()

        # Prefer a saved annotation when one exists; it is forwarded as an
        # extra keyword argument so the service call is written only once.
        extra_kwargs = {}
        if structure_path is not None and structure_path.is_file():
            logger.info(f"找到标注结果: {structure_path}")
            with open(structure_path, 'r', encoding='utf-8') as f:
                extra_kwargs['annotated_structure'] = json.load(f)
        else:
            logger.info("未找到标注结果,使用原始OCR数据")

        result = EditorService.process_upload(
            json_content,
            image_content,
            json_path=str(json_path),
            **extra_kwargs
        )

        return UploadResponse(
            success=True,
            image_base64=result['image_base64'],
            structure=TableStructure(**result['structure']),
            image_size=ImageSize(**result['image_size']),
            scale_factor=result['scale_factor'],
            ocr_data=result['ocr_data'],
            suggested_filename=result.get('suggested_filename'),
            message="加载成功"
        )

    except HTTPException:
        # Pass the 404s above through untouched.
        raise
    except ValueError as e:
        logger.error(f"加载处理失败: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"加载处理异常: {e}")
        raise HTTPException(status_code=500, detail=f"服务器错误: {e}")
|