ppchatocrv3.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import os
  16. from typing import Awaitable, Final, List, Literal, Optional, Tuple, Union
  17. import numpy as np
  18. from fastapi import FastAPI, HTTPException
  19. from numpy.typing import ArrayLike
  20. from pydantic import BaseModel, Field
  21. from typing_extensions import Annotated, TypeAlias, assert_never
  22. from .....utils import logging
  23. from .... import results
  24. from ...ppchatocrv3 import PPChatOCRPipeline
  25. from ..storage import SupportsGetURL, Storage, create_storage
  26. from .. import utils as serving_utils
  27. from ..app import AppConfig, create_app
  28. from ..models import Response, ResultResponse
  29. _DEFAULT_MAX_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
  30. _DEFAULT_MAX_NUM_IMGS: Final[int] = 10
  31. FileType: TypeAlias = Literal[0, 1]
  32. class InferenceParams(BaseModel):
  33. maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None
  34. class AnalyzeImagesRequest(BaseModel):
  35. file: str
  36. fileType: Optional[FileType] = None
  37. useImgOrientationCls: bool = True
  38. useImgUnwrapping: bool = True
  39. useSealTextDet: bool = True
  40. inferenceParams: Optional[InferenceParams] = None
  41. Point: TypeAlias = Annotated[List[int], Field(min_length=2, max_length=2)]
  42. Polygon: TypeAlias = Annotated[List[Point], Field(min_length=3)]
  43. BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
  44. class Text(BaseModel):
  45. poly: Polygon
  46. text: str
  47. score: float
  48. class Table(BaseModel):
  49. bbox: BoundingBox
  50. html: str
  51. class VisionResult(BaseModel):
  52. texts: List[Text]
  53. tables: List[Table]
  54. inputImage: str
  55. ocrImage: str
  56. layoutImage: str
  57. class AnalyzeImagesResult(BaseModel):
  58. visionResults: List[VisionResult]
  59. visionInfo: dict
  60. class QianfanParams(BaseModel):
  61. apiKey: str
  62. secretKey: str
  63. apiType: Literal["qianfan"] = "qianfan"
  64. class AIStudioParams(BaseModel):
  65. accessToken: str
  66. apiType: Literal["aistudio"] = "aistudio"
  67. LLMName: TypeAlias = Literal[
  68. "ernie-3.5",
  69. "ernie-3.5-8k",
  70. "ernie-lite",
  71. "ernie-4.0",
  72. "ernie-4.0-turbo-8k",
  73. "ernie-speed",
  74. "ernie-speed-128k",
  75. "ernie-tiny-8k",
  76. "ernie-char-8k",
  77. ]
  78. LLMParams: TypeAlias = Union[QianfanParams, AIStudioParams]
  79. class BuildVectorStoreRequest(BaseModel):
  80. visionInfo: dict
  81. minChars: Optional[int] = None
  82. llmRequestInterval: Optional[float] = None
  83. llmName: Optional[LLMName] = None
  84. llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  85. class BuildVectorStoreResult(BaseModel):
  86. vectorStore: str
  87. class RetrieveKnowledgeRequest(BaseModel):
  88. keys: List[str]
  89. vectorStore: str
  90. llmName: Optional[LLMName] = None
  91. llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  92. class RetrieveKnowledgeResult(BaseModel):
  93. retrievalResult: str
  94. class ChatRequest(BaseModel):
  95. keys: List[str]
  96. visionInfo: dict
  97. vectorStore: Optional[str] = None
  98. retrievalResult: Optional[str] = None
  99. taskDescription: Optional[str] = None
  100. rules: Optional[str] = None
  101. fewShot: Optional[str] = None
  102. llmName: Optional[LLMName] = None
  103. llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  104. returnPrompts: bool = False
  105. class Prompts(BaseModel):
  106. ocr: str
  107. table: Optional[str] = None
  108. html: Optional[str] = None
  109. class ChatResult(BaseModel):
  110. chatResult: dict
  111. prompts: Optional[Prompts] = None
  112. def _llm_params_to_dict(llm_params: LLMParams) -> dict:
  113. if llm_params.apiType == "qianfan":
  114. return {
  115. "api_type": "qianfan",
  116. "ak": llm_params.apiKey,
  117. "sk": llm_params.secretKey,
  118. }
  119. if llm_params.apiType == "aistudio":
  120. return {"api_type": "aistudio", "access_token": llm_params.accessToken}
  121. else:
  122. assert_never(llm_params.apiType)
  123. def _postprocess_image(
  124. img: ArrayLike,
  125. request_id: str,
  126. filename: str,
  127. file_storage: Optional[Storage],
  128. ) -> str:
  129. key = f"{request_id}/{filename}"
  130. ext = os.path.splitext(filename)[1]
  131. img = np.asarray(img)
  132. img_bytes = serving_utils.image_array_to_bytes(img, ext=ext)
  133. if file_storage is not None:
  134. file_storage.set(key, img_bytes)
  135. if isinstance(file_storage, SupportsGetURL):
  136. return file_storage.get_url(key)
  137. return serving_utils.base64_encode(img_bytes)
  138. def create_pipeline_app(pipeline: PPChatOCRPipeline, app_config: AppConfig) -> FastAPI:
  139. app, ctx = create_app(
  140. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  141. )
  142. if ctx.config.extra and "file_storage" in ctx.config.extra:
  143. ctx.extra["file_storage"] = create_storage(ctx.config.extra["file_storage"])
  144. else:
  145. ctx.extra["file_storage"] = None
  146. ctx.extra.setdefault("max_img_size", _DEFAULT_MAX_IMG_SIZE)
  147. ctx.extra.setdefault("max_num_imgs", _DEFAULT_MAX_NUM_IMGS)
  148. @app.post(
  149. "/chatocr-vision",
  150. operation_id="analyzeImages",
  151. responses={422: {"model": Response}},
  152. )
  153. async def _analyze_images(
  154. request: AnalyzeImagesRequest,
  155. ) -> ResultResponse[AnalyzeImagesResult]:
  156. pipeline = ctx.pipeline
  157. aiohttp_session = ctx.aiohttp_session
  158. request_id = serving_utils.generate_request_id()
  159. if request.fileType is None:
  160. if serving_utils.is_url(request.file):
  161. try:
  162. file_type = serving_utils.infer_file_type(request.file)
  163. except Exception as e:
  164. logging.exception(e)
  165. raise HTTPException(
  166. status_code=422,
  167. detail="The file type cannot be inferred from the URL. Please specify the file type explicitly.",
  168. )
  169. else:
  170. raise HTTPException(status_code=422, detail="Unknown file type")
  171. else:
  172. file_type = "PDF" if request.fileType == 0 else "IMAGE"
  173. if request.inferenceParams:
  174. max_long_side = request.inferenceParams.maxLongSide
  175. if max_long_side:
  176. raise HTTPException(
  177. status_code=422,
  178. detail="`max_long_side` is currently not supported.",
  179. )
  180. try:
  181. file_bytes = await serving_utils.get_raw_bytes(
  182. request.file, aiohttp_session
  183. )
  184. images = await serving_utils.call_async(
  185. serving_utils.file_to_images,
  186. file_bytes,
  187. file_type,
  188. max_img_size=ctx.extra["max_img_size"],
  189. max_num_imgs=ctx.extra["max_num_imgs"],
  190. )
  191. result = await pipeline.call(
  192. pipeline.pipeline.visual_predict,
  193. images,
  194. use_doc_image_ori_cls_model=request.useImgOrientationCls,
  195. use_doc_image_unwarp_model=request.useImgUnwrapping,
  196. use_seal_text_det_model=request.useSealTextDet,
  197. )
  198. vision_results: List[VisionResult] = []
  199. for i, (img, item) in enumerate(zip(images, result[0])):
  200. pp_img_futures: List[Awaitable] = []
  201. future = serving_utils.call_async(
  202. _postprocess_image,
  203. img,
  204. request_id=request_id,
  205. filename=f"input_image_{i}.jpg",
  206. file_storage=ctx.extra["file_storage"],
  207. )
  208. pp_img_futures.append(future)
  209. future = serving_utils.call_async(
  210. _postprocess_image,
  211. item["ocr_result"].img,
  212. request_id=request_id,
  213. filename=f"ocr_image_{i}.jpg",
  214. file_storage=ctx.extra["file_storage"],
  215. )
  216. pp_img_futures.append(future)
  217. future = serving_utils.call_async(
  218. _postprocess_image,
  219. item["layout_result"].img,
  220. request_id=request_id,
  221. filename=f"layout_image_{i}.jpg",
  222. file_storage=ctx.extra["file_storage"],
  223. )
  224. pp_img_futures.append(future)
  225. texts: List[Text] = []
  226. for poly, text, score in zip(
  227. item["ocr_result"]["dt_polys"],
  228. item["ocr_result"]["rec_text"],
  229. item["ocr_result"]["rec_score"],
  230. ):
  231. texts.append(Text(poly=poly, text=text, score=score))
  232. tables = [
  233. Table(bbox=r["layout_bbox"], html=r["html"])
  234. for r in item["table_result"]
  235. ]
  236. input_img, ocr_img, layout_img = await asyncio.gather(*pp_img_futures)
  237. vision_result = VisionResult(
  238. texts=texts,
  239. tables=tables,
  240. inputImage=input_img,
  241. ocrImage=ocr_img,
  242. layoutImage=layout_img,
  243. )
  244. vision_results.append(vision_result)
  245. return ResultResponse(
  246. logId=serving_utils.generate_log_id(),
  247. errorCode=0,
  248. errorMsg="Success",
  249. result=AnalyzeImagesResult(
  250. visionResults=vision_results,
  251. visionInfo=result[1],
  252. ),
  253. )
  254. except Exception as e:
  255. logging.exception(e)
  256. raise HTTPException(status_code=500, detail="Internal server error")
  257. @app.post(
  258. "/chatocr-vector",
  259. operation_id="buildVectorStore",
  260. responses={422: {"model": Response}},
  261. )
  262. async def _build_vector_store(
  263. request: BuildVectorStoreRequest,
  264. ) -> ResultResponse[BuildVectorStoreResult]:
  265. pipeline = ctx.pipeline
  266. try:
  267. kwargs = {"visual_info": results.VisualInfoResult(request.visionInfo)}
  268. if request.minChars is not None:
  269. kwargs["min_characters"] = request.minChars
  270. if request.llmRequestInterval is not None:
  271. kwargs["llm_request_interval"] = request.llmRequestInterval
  272. if request.llmName is not None:
  273. kwargs["llm_name"] = request.llmName
  274. if request.llmParams is not None:
  275. kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
  276. result = await serving_utils.call_async(
  277. pipeline.pipeline.build_vector, **kwargs
  278. )
  279. return ResultResponse(
  280. logId=serving_utils.generate_log_id(),
  281. errorCode=0,
  282. errorMsg="Success",
  283. result=BuildVectorStoreResult(vectorStore=result["vector"]),
  284. )
  285. except Exception as e:
  286. logging.exception(e)
  287. raise HTTPException(status_code=500, detail="Internal server error")
  288. @app.post(
  289. "/chatocr-retrieval",
  290. operation_id="retrieveKnowledge",
  291. responses={422: {"model": Response}},
  292. )
  293. async def _retrieve_knowledge(
  294. request: RetrieveKnowledgeRequest,
  295. ) -> ResultResponse[RetrieveKnowledgeResult]:
  296. pipeline = ctx.pipeline
  297. try:
  298. kwargs = {
  299. "key_list": request.keys,
  300. "vector": results.VectorResult({"vector": request.vectorStore}),
  301. }
  302. if request.llmName is not None:
  303. kwargs["llm_name"] = request.llmName
  304. if request.llmParams is not None:
  305. kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
  306. result = await serving_utils.call_async(
  307. pipeline.pipeline.retrieval, **kwargs
  308. )
  309. return ResultResponse(
  310. logId=serving_utils.generate_log_id(),
  311. errorCode=0,
  312. errorMsg="Success",
  313. result=RetrieveKnowledgeResult(retrievalResult=result["retrieval"]),
  314. )
  315. except Exception as e:
  316. logging.exception(e)
  317. raise HTTPException(status_code=500, detail="Internal server error")
  318. @app.post(
  319. "/chatocr-chat",
  320. operation_id="chat",
  321. responses={422: {"model": Response}},
  322. response_model_exclude_none=True,
  323. )
  324. async def _chat(
  325. request: ChatRequest,
  326. ) -> ResultResponse[ChatResult]:
  327. pipeline = ctx.pipeline
  328. try:
  329. kwargs = {
  330. "key_list": request.keys,
  331. "visual_info": results.VisualInfoResult(request.visionInfo),
  332. }
  333. if request.vectorStore is not None:
  334. kwargs["vector"] = results.VectorResult({"vector": request.vectorStore})
  335. if request.retrievalResult is not None:
  336. kwargs["retrieval_result"] = results.RetrievalResult(
  337. {"retrieval": request.retrievalResult}
  338. )
  339. if request.taskDescription is not None:
  340. kwargs["user_task_description"] = request.taskDescription
  341. if request.rules is not None:
  342. kwargs["rules"] = request.rules
  343. if request.fewShot is not None:
  344. kwargs["few_shot"] = request.fewShot
  345. if request.llmName is not None:
  346. kwargs["llm_name"] = request.llmName
  347. if request.llmParams is not None:
  348. kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
  349. kwargs["save_prompt"] = request.returnPrompts
  350. result = await serving_utils.call_async(pipeline.pipeline.chat, **kwargs)
  351. if result["prompt"]:
  352. prompts = Prompts(
  353. ocr=result["prompt"]["ocr_prompt"],
  354. table=result["prompt"]["table_prompt"] or None,
  355. html=result["prompt"]["html_prompt"] or None,
  356. )
  357. else:
  358. prompts = None
  359. chat_result = ChatResult(
  360. chatResult=result["chat_res"],
  361. prompts=prompts,
  362. )
  363. return ResultResponse(
  364. logId=serving_utils.generate_log_id(),
  365. errorCode=0,
  366. errorMsg="Success",
  367. result=chat_result,
  368. )
  369. except Exception as e:
  370. logging.exception(e)
  371. raise HTTPException(status_code=500, detail="Internal server error")
  372. return app