ppchatocrv3.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import os
  16. import re
  17. import uuid
  18. from typing import Awaitable, Final, List, Literal, Optional, Tuple, Union
  19. from urllib.parse import parse_qs, urlparse
  20. import cv2
  21. import numpy as np
  22. from fastapi import FastAPI, HTTPException
  23. from numpy.typing import ArrayLike
  24. from pydantic import BaseModel, Field
  25. from typing_extensions import Annotated, TypeAlias, assert_never
  26. from .....utils import logging
  27. from .... import results
  28. from ...ppchatocrv3 import PPChatOCRPipeline
  29. from .. import file_storage
  30. from .. import utils as serving_utils
  31. from ..app import AppConfig, create_app
  32. from ..models import Response, ResultResponse
  33. _DEFAULT_MAX_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
  34. _DEFAULT_MAX_NUM_IMGS: Final[int] = 10
  35. FileType: TypeAlias = Literal[0, 1]
  36. class InferenceParams(BaseModel):
  37. maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None
  38. class AnalyzeImageRequest(BaseModel):
  39. file: str
  40. fileType: Optional[FileType] = None
  41. useImgOrientationCls: bool = True
  42. useImgUnwrapping: bool = True
  43. useSealTextDet: bool = True
  44. inferenceParams: Optional[InferenceParams] = None
  45. Point: TypeAlias = Annotated[List[int], Field(min_length=2, max_length=2)]
  46. Polygon: TypeAlias = Annotated[List[Point], Field(min_length=3)]
  47. BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
  48. class Text(BaseModel):
  49. poly: Polygon
  50. text: str
  51. score: float
  52. class Table(BaseModel):
  53. bbox: BoundingBox
  54. html: str
  55. class VisionResult(BaseModel):
  56. texts: List[Text]
  57. tables: List[Table]
  58. inputImage: str
  59. ocrImage: str
  60. layoutImage: str
  61. class AnalyzeImageResult(BaseModel):
  62. visionResults: List[VisionResult]
  63. visionInfo: dict
  64. class AIStudioParams(BaseModel):
  65. accessToken: str
  66. apiType: Literal["aistudio"] = "aistudio"
  67. class QianfanParams(BaseModel):
  68. apiKey: str
  69. secretKey: str
  70. apiType: Literal["qianfan"] = "qianfan"
  71. LLMName: TypeAlias = Literal[
  72. "ernie-3.5",
  73. "ernie-3.5-8k",
  74. "ernie-lite",
  75. "ernie-4.0",
  76. "ernie-4.0-turbo-8k",
  77. "ernie-speed",
  78. "ernie-speed-128k",
  79. "ernie-tiny-8k",
  80. "ernie-char-8k",
  81. ]
  82. LLMParams: TypeAlias = Union[AIStudioParams, QianfanParams]
  83. class BuildVectorStoreRequest(BaseModel):
  84. visionInfo: dict
  85. minChars: Optional[int] = None
  86. llmRequestInterval: Optional[float] = None
  87. llmName: Optional[LLMName] = None
  88. llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  89. class BuildVectorStoreResult(BaseModel):
  90. vectorStore: str
  91. class RetrieveKnowledgeRequest(BaseModel):
  92. keys: List[str]
  93. vectorStore: str
  94. llmName: Optional[LLMName] = None
  95. llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  96. class RetrieveKnowledgeResult(BaseModel):
  97. retrievalResult: str
  98. class ChatRequest(BaseModel):
  99. keys: List[str]
  100. visionInfo: dict
  101. taskDescription: Optional[str] = None
  102. rules: Optional[str] = None
  103. fewShot: Optional[str] = None
  104. vectorStore: Optional[str] = None
  105. retrievalResult: Optional[str] = None
  106. returnPrompts: bool = True
  107. llmName: Optional[LLMName] = None
  108. llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  109. class Prompts(BaseModel):
  110. ocr: str
  111. table: Optional[str] = None
  112. html: Optional[str] = None
  113. class ChatResult(BaseModel):
  114. chatResult: dict
  115. prompts: Optional[Prompts] = None
  116. def _generate_request_id() -> str:
  117. return str(uuid.uuid4())
  118. def _infer_file_type(url: str) -> FileType:
  119. # Is it more reliable to guess the file type based on the response headers?
  120. SUPPORTED_IMG_EXTS: Final[List[str]] = [".jpg", ".jpeg", ".png"]
  121. url_parts = urlparse(url)
  122. ext = os.path.splitext(url_parts.path)[1]
  123. # HACK: The support for BOS URLs with query params is implementation-based,
  124. # not interface-based.
  125. is_bos_url = (
  126. re.fullmatch(r"(?:bj|bd|su|gz|cd|hkg|fwh|fsh)\.bcebos\.com", url_parts.netloc)
  127. is not None
  128. )
  129. if is_bos_url and url_parts.query:
  130. params = parse_qs(url_parts.query)
  131. if (
  132. "responseContentDisposition" not in params
  133. or len(params["responseContentDisposition"]) != 1
  134. ):
  135. raise ValueError("`responseContentDisposition` not found")
  136. match_ = re.match(
  137. r"attachment;filename=(.*)", params["responseContentDisposition"][0]
  138. )
  139. if not match_ or not match_.groups()[0] is not None:
  140. raise ValueError(
  141. "Failed to extract the filename from `responseContentDisposition`"
  142. )
  143. ext = os.path.splitext(match_.groups()[0])[1]
  144. ext = ext.lower()
  145. if ext == ".pdf":
  146. return 0
  147. elif ext in SUPPORTED_IMG_EXTS:
  148. return 1
  149. else:
  150. raise ValueError("Unsupported file type")
  151. def _llm_params_to_dict(llm_params: LLMParams) -> dict:
  152. if llm_params.apiType == "aistudio":
  153. return {"api_type": "aistudio", "access_token": llm_params.accessToken}
  154. elif llm_params.apiType == "qianfan":
  155. return {
  156. "api_type": "qianfan",
  157. "ak": llm_params.apiKey,
  158. "sk": llm_params.secretKey,
  159. }
  160. else:
  161. assert_never(llm_params.apiType)
  162. def _bytes_to_arrays(
  163. file_bytes: bytes,
  164. file_type: FileType,
  165. *,
  166. max_img_size: Tuple[int, int],
  167. max_num_imgs: int,
  168. ) -> List[np.ndarray]:
  169. if file_type == 0:
  170. images = serving_utils.read_pdf(
  171. file_bytes, resize=True, max_num_imgs=max_num_imgs
  172. )
  173. elif file_type == 1:
  174. images = [serving_utils.image_bytes_to_array(file_bytes)]
  175. else:
  176. assert_never(file_type)
  177. h, w = images[0].shape[0:2]
  178. if w > max_img_size[1] or h > max_img_size[0]:
  179. if w / h > max_img_size[0] / max_img_size[1]:
  180. factor = max_img_size[0] / w
  181. else:
  182. factor = max_img_size[1] / h
  183. images = [cv2.resize(img, (int(factor * w), int(factor * h))) for img in images]
  184. return images
  185. def _postprocess_image(
  186. img: ArrayLike,
  187. request_id: str,
  188. filename: str,
  189. file_storage_config: file_storage.FileStorageConfig,
  190. ) -> str:
  191. key = f"{request_id}/{filename}"
  192. ext = os.path.splitext(filename)[1]
  193. img = np.asarray(img)
  194. _, encoded_img = cv2.imencode(ext, img)
  195. encoded_img = encoded_img.tobytes()
  196. return file_storage.postprocess_file(
  197. encoded_img, config=file_storage_config, key=key
  198. )
  199. def create_pipeline_app(pipeline: PPChatOCRPipeline, app_config: AppConfig) -> FastAPI:
  200. app, ctx = create_app(
  201. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  202. )
  203. if "file_storage_config" in ctx.extra:
  204. ctx.extra["file_storage_config"] = file_storage.parse_file_storage_config(
  205. ctx.extra["file_storage_config"]
  206. )
  207. else:
  208. ctx.extra["file_storage_config"] = file_storage.InMemoryStorageConfig()
  209. ctx.extra.setdefault("max_img_size", _DEFAULT_MAX_IMG_SIZE)
  210. ctx.extra.setdefault("max_num_imgs", _DEFAULT_MAX_NUM_IMGS)
  211. @app.post(
  212. "/chatocr-vision",
  213. operation_id="analyzeImage",
  214. responses={422: {"model": Response}},
  215. )
  216. async def _analyze_image(
  217. request: AnalyzeImageRequest,
  218. ) -> ResultResponse[AnalyzeImageResult]:
  219. pipeline = ctx.pipeline
  220. aiohttp_session = ctx.aiohttp_session
  221. request_id = _generate_request_id()
  222. if request.fileType is None:
  223. if serving_utils.is_url(request.file):
  224. try:
  225. file_type = _infer_file_type(request.file)
  226. except Exception as e:
  227. logging.exception(e)
  228. raise HTTPException(
  229. status_code=422,
  230. detail="The file type cannot be inferred from the URL. Please specify the file type explicitly.",
  231. )
  232. else:
  233. raise HTTPException(status_code=422, detail="Unknown file type")
  234. else:
  235. file_type = request.fileType
  236. if request.inferenceParams:
  237. max_long_side = request.inferenceParams.maxLongSide
  238. if max_long_side:
  239. raise HTTPException(
  240. status_code=422,
  241. detail="`max_long_side` is currently not supported.",
  242. )
  243. try:
  244. file_bytes = await serving_utils.get_raw_bytes(
  245. request.file, aiohttp_session
  246. )
  247. images = await serving_utils.call_async(
  248. _bytes_to_arrays,
  249. file_bytes,
  250. file_type,
  251. max_img_size=ctx.extra["max_img_size"],
  252. max_num_imgs=ctx.extra["max_num_imgs"],
  253. )
  254. result = await pipeline.call(
  255. pipeline.pipeline.visual_predict,
  256. images,
  257. use_doc_image_ori_cls_model=request.useImgOrientationCls,
  258. use_doc_image_unwarp_model=request.useImgUnwrapping,
  259. use_seal_text_det_model=request.useSealTextDet,
  260. )
  261. vision_results: List[VisionResult] = []
  262. for i, (img, item) in enumerate(zip(images, result[0])):
  263. pp_img_futures: List[Awaitable] = []
  264. future = serving_utils.call_async(
  265. _postprocess_image,
  266. img,
  267. request_id=request_id,
  268. filename=f"input_image_{i}.jpg",
  269. file_storage_config=ctx.extra["file_storage_config"],
  270. )
  271. pp_img_futures.append(future)
  272. future = serving_utils.call_async(
  273. _postprocess_image,
  274. item["ocr_result"].img,
  275. request_id=request_id,
  276. filename=f"ocr_image_{i}.jpg",
  277. file_storage_config=ctx.extra["file_storage_config"],
  278. )
  279. pp_img_futures.append(future)
  280. future = serving_utils.call_async(
  281. _postprocess_image,
  282. item["layout_result"].img,
  283. request_id=request_id,
  284. filename=f"layout_image_{i}.jpg",
  285. file_storage_config=ctx.extra["file_storage_config"],
  286. )
  287. pp_img_futures.append(future)
  288. texts: List[Text] = []
  289. for poly, text, score in zip(
  290. item["ocr_result"]["dt_polys"],
  291. item["ocr_result"]["rec_text"],
  292. item["ocr_result"]["rec_score"],
  293. ):
  294. texts.append(Text(poly=poly, text=text, score=score))
  295. tables = [
  296. Table(bbox=r["layout_bbox"], html=r["html"])
  297. for r in item["table_result"]
  298. ]
  299. input_img, ocr_img, layout_img = await asyncio.gather(*pp_img_futures)
  300. vision_result = VisionResult(
  301. texts=texts,
  302. tables=tables,
  303. inputImage=input_img,
  304. ocrImage=ocr_img,
  305. layoutImage=layout_img,
  306. )
  307. vision_results.append(vision_result)
  308. return ResultResponse(
  309. logId=serving_utils.generate_log_id(),
  310. errorCode=0,
  311. errorMsg="Success",
  312. result=AnalyzeImageResult(
  313. visionResults=vision_results,
  314. visionInfo=result[1],
  315. ),
  316. )
  317. except Exception as e:
  318. logging.exception(e)
  319. raise HTTPException(status_code=500, detail="Internal server error")
  320. @app.post(
  321. "/chatocr-vector",
  322. operation_id="buildVectorStore",
  323. responses={422: {"model": Response}},
  324. )
  325. async def _build_vector_store(
  326. request: BuildVectorStoreRequest,
  327. ) -> ResultResponse[BuildVectorStoreResult]:
  328. pipeline = ctx.pipeline
  329. try:
  330. kwargs = {"visual_info": results.VisualInfoResult(request.visionInfo)}
  331. if request.minChars is not None:
  332. kwargs["min_characters"] = request.minChars
  333. else:
  334. kwargs["min_characters"] = 0
  335. if request.llmRequestInterval is not None:
  336. kwargs["llm_request_interval"] = request.llmRequestInterval
  337. if request.llmName is not None:
  338. kwargs["llm_name"] = request.llmName
  339. if request.llmParams is not None:
  340. kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
  341. result = await serving_utils.call_async(
  342. pipeline.pipeline.build_vector, **kwargs
  343. )
  344. return ResultResponse(
  345. logId=serving_utils.generate_log_id(),
  346. errorCode=0,
  347. errorMsg="Success",
  348. result=BuildVectorStoreResult(vectorStore=result["vector"]),
  349. )
  350. except Exception as e:
  351. logging.exception(e)
  352. raise HTTPException(status_code=500, detail="Internal server error")
  353. @app.post(
  354. "/chatocr-retrieval",
  355. operation_id="retrieveKnowledge",
  356. responses={422: {"model": Response}},
  357. )
  358. async def _retrieve_knowledge(
  359. request: RetrieveKnowledgeRequest,
  360. ) -> ResultResponse[RetrieveKnowledgeResult]:
  361. pipeline = ctx.pipeline
  362. try:
  363. kwargs = {
  364. "key_list": request.keys,
  365. "vector": results.VectorResult({"vector": request.vectorStore}),
  366. }
  367. if request.llmName is not None:
  368. kwargs["llm_name"] = request.llmName
  369. if request.llmParams is not None:
  370. kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
  371. result = await serving_utils.call_async(
  372. pipeline.pipeline.retrieval, **kwargs
  373. )
  374. return ResultResponse(
  375. logId=serving_utils.generate_log_id(),
  376. errorCode=0,
  377. errorMsg="Success",
  378. result=RetrieveKnowledgeResult(retrievalResult=result["retrieval"]),
  379. )
  380. except Exception as e:
  381. logging.exception(e)
  382. raise HTTPException(status_code=500, detail="Internal server error")
  383. @app.post(
  384. "/chatocr-chat",
  385. operation_id="chat",
  386. responses={422: {"model": Response}},
  387. response_model_exclude_none=True,
  388. )
  389. async def _chat(
  390. request: ChatRequest,
  391. ) -> ResultResponse[ChatResult]:
  392. pipeline = ctx.pipeline
  393. try:
  394. kwargs = {
  395. "key_list": request.keys,
  396. "visual_info": results.VisualInfoResult(request.visionInfo),
  397. }
  398. if request.taskDescription is not None:
  399. kwargs["user_task_description"] = request.taskDescription
  400. if request.rules is not None:
  401. kwargs["rules"] = request.rules
  402. if request.fewShot is not None:
  403. kwargs["few_shot"] = request.fewShot
  404. if request.vectorStore is not None:
  405. kwargs["vector"] = results.VectorResult({"vector": request.vectorStore})
  406. if request.retrievalResult is not None:
  407. kwargs["retrieval_result"] = results.RetrievalResult(
  408. {"retrieval": request.retrievalResult}
  409. )
  410. kwargs["save_prompt"] = request.returnPrompts
  411. if request.llmName is not None:
  412. kwargs["llm_name"] = request.llmName
  413. if request.llmParams is not None:
  414. kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
  415. result = await serving_utils.call_async(pipeline.pipeline.chat, **kwargs)
  416. if result["prompt"]:
  417. prompts = Prompts(
  418. ocr=result["prompt"]["ocr_prompt"],
  419. table=result["prompt"]["table_prompt"] or None,
  420. html=result["prompt"]["html_prompt"] or None,
  421. )
  422. else:
  423. prompts = None
  424. chat_result = ChatResult(
  425. chatResult=result["chat_res"],
  426. prompts=prompts,
  427. )
  428. return ResultResponse(
  429. logId=serving_utils.generate_log_id(),
  430. errorCode=0,
  431. errorMsg="Success",
  432. result=chat_result,
  433. )
  434. except Exception as e:
  435. logging.exception(e)
  436. raise HTTPException(status_code=500, detail="Internal server error")
  437. return app