ocr.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from typing import Awaitable, Final, List, Literal, Optional, Tuple
  16. import numpy as np
  17. from fastapi import HTTPException
  18. from numpy.typing import ArrayLike
  19. from pydantic import BaseModel, Field
  20. from typing_extensions import Annotated, TypeAlias, assert_never
  21. from ......utils import logging
  22. from ... import utils as serving_utils
  23. from .cv import postprocess_image
  24. from ...models import DataInfo
  25. from ...storage import create_storage, SupportsGetURL
  26. from ...app import AppContext
# Fallback for the "max_num_input_imgs" setting, used by `update_app_context`
# when the app's extra configuration does not provide that key.
DEFAULT_MAX_NUM_INPUT_IMGS: Final[int] = 10
# Fallback for the "max_output_img_size" setting, used by `update_app_context`
# when the app's extra configuration does not provide that key.
# NOTE(review): the order of the two components (width/height) is not visible
# in this file; confirm against `postprocess_image`.
DEFAULT_MAX_OUTPUT_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
# Numeric file-type code accepted in requests; `get_file_type` maps 0 to "PDF"
# and 1 to "IMAGE".
FileType: TypeAlias = Literal[0, 1]
# Optional inference settings that can be attached to an `InferRequest`.
# Documented with `#` comments rather than a docstring because pydantic copies
# a model's docstring into its generated JSON schema description.
class InferenceParams(BaseModel):
    # Must be strictly positive when given (`gt=0`); None when omitted.
    # NOTE(review): presumably an upper bound on the longer image side;
    # confirm with the code that consumes this field.
    maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None
# Request model consumed by `get_file_type` and `get_images`.
# Documented with `#` comments rather than a docstring because pydantic copies
# a model's docstring into its generated JSON schema description.
class InferRequest(BaseModel):
    # The input file. `get_file_type` tests it with `serving_utils.is_url` and
    # `get_images` hands it to `serving_utils.get_raw_bytes`.
    # NOTE(review): non-URL values are presumably encoded file content
    # (e.g. Base64); confirm against `get_raw_bytes`.
    file: str
    # Explicit file type: 0 means PDF, 1 means image. When None,
    # `get_file_type` tries to infer the type, which only works for URLs.
    fileType: Optional[FileType] = None
    # Optional inference settings; not read by the helpers visible in this
    # part of the file.
    inferenceParams: Optional[InferenceParams] = None
  36. def update_app_context(app_context: AppContext) -> None:
  37. extra_cfg = app_context.config.extra or {}
  38. app_context.extra["file_storage"] = None
  39. if "file_storage" in extra_cfg:
  40. app_context.extra["file_storage"] = create_storage(extra_cfg["file_storage"])
  41. app_context.extra["return_img_urls"] = extra_cfg.get("return_img_urls", False)
  42. if app_context.extra["return_img_urls"]:
  43. file_storage = app_context.extra["file_storage"]
  44. if not file_storage:
  45. raise ValueError(
  46. "The file storage must be properly configured when URLs need to be returned."
  47. )
  48. if not isinstance(file_storage, SupportsGetURL):
  49. raise TypeError(
  50. f"`{type(file_storage).__name__}` does not support getting URLs."
  51. )
  52. app_context.extra["max_num_input_imgs"] = extra_cfg.get(
  53. "max_num_input_imgs", DEFAULT_MAX_NUM_INPUT_IMGS
  54. )
  55. app_context.extra["max_output_img_size"] = extra_cfg.get(
  56. "max_output_img_size", DEFAULT_MAX_OUTPUT_IMG_SIZE
  57. )
  58. def get_file_type(request: InferRequest) -> Literal["IMAGE", "PDF"]:
  59. if request.fileType is None:
  60. if serving_utils.is_url(request.file):
  61. try:
  62. file_type = serving_utils.infer_file_type(request.file)
  63. except Exception:
  64. logging.exception("Failed to infer the file type")
  65. raise HTTPException(
  66. status_code=422,
  67. detail="The file type cannot be inferred from the URL. Please specify the file type explicitly.",
  68. )
  69. else:
  70. raise HTTPException(status_code=422, detail="Unknown file type")
  71. else:
  72. file_type = "PDF" if request.fileType == 0 else "IMAGE"
  73. return file_type
  74. async def get_images(
  75. request: InferRequest, app_context: AppContext
  76. ) -> Tuple[List[np.ndarray], DataInfo]:
  77. file_type = get_file_type(request)
  78. # XXX: Currently, we use 500 for consistency. However, 422 may be more
  79. # appropriate.
  80. try:
  81. file_bytes = await serving_utils.get_raw_bytes(
  82. request.file,
  83. app_context.aiohttp_session,
  84. )
  85. images, data_info = await serving_utils.call_async(
  86. serving_utils.file_to_images,
  87. file_bytes,
  88. file_type,
  89. max_num_imgs=app_context.extra["max_num_input_imgs"],
  90. )
  91. except Exception:
  92. logging.exception("Unexpected exception")
  93. raise HTTPException(status_code=500, detail="Internal server error")
  94. if file_type == "IMAGE":
  95. return images, DataInfo(image=data_info)
  96. elif file_type == "PDF":
  97. return images, DataInfo(pdf=data_info)
  98. else:
  99. assert_never()
  100. async def postprocess_images(
  101. *,
  102. log_id: str,
  103. index: str,
  104. app_context: AppContext,
  105. input_image: Optional[ArrayLike] = None,
  106. layout_image: Optional[ArrayLike] = None,
  107. ocr_image: Optional[ArrayLike] = None,
  108. ) -> List[str]:
  109. if input_image is None and layout_image is None and ocr_image is None:
  110. raise ValueError("At least one of the images must be provided.")
  111. file_storage = app_context.extra["file_storage"]
  112. return_img_urls = app_context.extra["return_img_urls"]
  113. max_img_size = app_context.extra["max_output_img_size"]
  114. futures: List[Awaitable] = []
  115. if input_image is not None:
  116. future = serving_utils.call_async(
  117. postprocess_image,
  118. input_image,
  119. log_id=log_id,
  120. filename=f"input_image_{index}.jpg",
  121. file_storage=file_storage,
  122. return_url=return_img_urls,
  123. max_img_size=max_img_size,
  124. )
  125. futures.append(future)
  126. if layout_image is not None:
  127. future = serving_utils.call_async(
  128. postprocess_image,
  129. layout_image,
  130. log_id=log_id,
  131. filename=f"layout_image_{index}.jpg",
  132. file_storage=file_storage,
  133. return_url=return_img_urls,
  134. max_img_size=max_img_size,
  135. )
  136. futures.append(future)
  137. if ocr_image is not None:
  138. future = serving_utils.call_async(
  139. postprocess_image,
  140. ocr_image,
  141. log_id=log_id,
  142. filename=f"ocr_image_{index}.jpg",
  143. file_storage=file_storage,
  144. max_img_size=max_img_size,
  145. )
  146. futures.append(future)
  147. return await asyncio.gather(*futures)