layout_parsing.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Literal, Optional
  15. from fastapi import FastAPI, HTTPException
  16. from pydantic import BaseModel, Field
  17. from typing_extensions import Annotated, TypeAlias
  18. from ._common import cv as cv_common, ocr as ocr_common
  19. from .....utils import logging
  20. from ...layout_parsing import LayoutParsingPipeline
  21. from .. import utils as serving_utils
  22. from ..app import AppConfig, create_app
  23. from ..models import NoResultResponse, ResultResponse, DataInfo
  24. class InferRequest(ocr_common.InferRequest):
  25. useImgOrientationCls: bool = True
  26. useImgUnwarping: bool = True
  27. useSealTextDet: bool = True
  28. BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
  29. class LayoutElement(BaseModel):
  30. bbox: BoundingBox
  31. label: str
  32. text: str
  33. layoutType: Literal["single", "double"]
  34. image: Optional[str] = None
  35. class LayoutParsingResult(BaseModel):
  36. layoutElements: List[LayoutElement]
  37. class InferResult(BaseModel):
  38. layoutParsingResults: List[LayoutParsingResult]
  39. dataInfo: DataInfo
  40. def create_pipeline_app(
  41. pipeline: LayoutParsingPipeline, app_config: AppConfig
  42. ) -> FastAPI:
  43. app, ctx = create_app(
  44. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  45. )
  46. ocr_common.update_app_context(ctx)
  47. @app.post(
  48. "/layout-parsing",
  49. operation_id="infer",
  50. responses={422: {"model": NoResultResponse}},
  51. response_model_exclude_none=True,
  52. )
  53. async def _infer(
  54. request: InferRequest,
  55. ) -> ResultResponse[InferResult]:
  56. pipeline = ctx.pipeline
  57. log_id = serving_utils.generate_log_id()
  58. if request.inferenceParams:
  59. max_long_side = request.inferenceParams.maxLongSide
  60. if max_long_side:
  61. raise HTTPException(
  62. status_code=422,
  63. detail="`max_long_side` is currently not supported.",
  64. )
  65. images, data_info = await ocr_common.get_images(request, ctx)
  66. try:
  67. result = await pipeline.infer(
  68. images,
  69. use_doc_image_ori_cls_model=request.useImgOrientationCls,
  70. use_doc_image_unwarp_model=request.useImgUnwarping,
  71. use_seal_text_det_model=request.useSealTextDet,
  72. )
  73. layout_parsing_results: List[LayoutParsingResult] = []
  74. for i, item in enumerate(result):
  75. layout_elements: List[LayoutElement] = []
  76. for j, subitem in enumerate(
  77. item["layout_parsing_result"]["parsing_result"]
  78. ):
  79. dyn_keys = subitem.keys() - {"input_path", "layout_bbox", "layout"}
  80. if len(dyn_keys) != 1:
  81. raise RuntimeError(f"Unexpected result: {subitem}")
  82. label = next(iter(dyn_keys))
  83. if label in ("image", "figure", "img", "fig"):
  84. image_ = await serving_utils.call_async(
  85. cv_common.postprocess_image,
  86. subitem[label]["img"],
  87. log_id=log_id,
  88. filename=f"image_{i}_{j}.jpg",
  89. file_storage=ctx.extra["file_storage"],
  90. return_url=ctx.extra["return_img_urls"],
  91. max_img_size=ctx.extra["max_output_img_size"],
  92. )
  93. text = subitem[label]["image_text"]
  94. else:
  95. image_ = None
  96. text = subitem[label]
  97. layout_elements.append(
  98. LayoutElement(
  99. bbox=subitem["layout_bbox"],
  100. label=label,
  101. text=text,
  102. layoutType=subitem["layout"],
  103. image=image_,
  104. )
  105. )
  106. layout_parsing_results.append(
  107. LayoutParsingResult(layoutElements=layout_elements)
  108. )
  109. return ResultResponse[InferResult](
  110. logId=log_id,
  111. result=InferResult(
  112. layoutParsingResults=layout_parsing_results,
  113. dataInfo=data_info,
  114. ),
  115. )
  116. except Exception:
  117. logging.exception("Unexpected exception")
  118. raise HTTPException(status_code=500, detail="Internal server error")
  119. return app