ppchatocrv3.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import os
  16. import re
  17. import uuid
  18. from typing import Awaitable, Final, List, Literal, Optional, Tuple, Union
  19. from urllib.parse import parse_qs, urlparse
  20. import cv2
  21. import numpy as np
  22. from fastapi import FastAPI, HTTPException
  23. from numpy.typing import ArrayLike
  24. from pydantic import BaseModel, Field
  25. from typing_extensions import Annotated, TypeAlias, assert_never
  26. from .....utils import logging
  27. from ...ppchatocrv3 import PPChatOCRPipeline
  28. from .. import file_storage
  29. from .. import utils as serving_utils
  30. from ..app import AppConfig, create_app
  31. from ..models import Response, ResultResponse
  32. _DEFAULT_MAX_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
  33. _DEFAULT_MAX_NUM_IMGS: Final[int] = 10
  34. FileType: TypeAlias = Literal[0, 1]
  class InferenceParams(BaseModel):
      # Optional per-request inference settings.

      # Validated as a strictly positive integer when present. Note that the
      # `/chatocr-vision` endpoint in this file currently answers 422 for any
      # truthy value.
      maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None
  class AnalyzeImageRequest(BaseModel):
      # Request body of the `/chatocr-vision` (analyzeImage) endpoint.

      # Reference to the input file. If `fileType` is omitted this has to be a
      # URL, because the type is then inferred from it.
      file: str
      # 0 selects the PDF reader, 1 the single-image decoder, None means "infer".
      fileType: Optional[FileType] = None
      # The three switches below are forwarded to the pipeline as `use_oricls`,
      # `use_curve` and `use_uvdoc` respectively.
      useOricls: bool = True
      useCurve: bool = True
      useUvdoc: bool = True
      inferenceParams: Optional[InferenceParams] = None
  # A point is a list of exactly two integers.
  Point: TypeAlias = Annotated[List[int], Field(min_length=2, max_length=2)]

  # A polygon is a list of at least three points.
  Polygon: TypeAlias = Annotated[List[Point], Field(min_length=3)]

  # A bounding box is a list of exactly four floats.
  BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
  class Text(BaseModel):
      # One recognised text item of the OCR result.

      # Filled from the OCR result's `dt_polys` entry.
      poly: Polygon
      # Filled from the OCR result's `rec_text` entry.
      text: str
      # Filled from the OCR result's `rec_score` entry.
      score: float
  class Table(BaseModel):
      # One table found by the pipeline.

      # Filled from the table result's `layout_bbox` entry.
      bbox: BoundingBox
      # Filled from the table result's `html` entry.
      html: str
  class VisionResult(BaseModel):
      # Vision output for a single input image (or PDF page).

      texts: List[Text]
      tables: List[Table]
      # The three fields below hold whatever string the file-storage layer
      # returns for the stored JPEG (input, OCR visualisation, layout
      # visualisation).
      inputImage: str
      ocrImage: str
      layoutImage: str
  class AnalyzeImageResult(BaseModel):
      # Payload returned by the `/chatocr-vision` endpoint.

      # One entry per processed image.
      visionResults: List[VisionResult]
      # Second element of the pipeline's inference result, passed through
      # unchanged so that clients can send it back to the other endpoints.
      visionInfo: dict
  class AIStudioParams(BaseModel):
      # Credentials for the "aistudio" LLM backend.

      # Passed to the pipeline as `access_token`.
      accessToken: str
      # Discriminator value that selects this model inside `LLMParams`.
      apiType: Literal["aistudio"] = "aistudio"
  class QianfanParams(BaseModel):
      # Credentials for the "qianfan" LLM backend.

      # Passed to the pipeline as `ak`.
      apiKey: str
      # Passed to the pipeline as `sk`.
      secretKey: str
      # Discriminator value that selects this model inside `LLMParams`.
      apiType: Literal["qianfan"] = "qianfan"
  70. LLMName: TypeAlias = Literal[
  71. "ernie-3.5",
  72. "ernie-3.5-8k",
  73. "ernie-lite",
  74. "ernie-4.0",
  75. "ernie-4.0-turbo-8k",
  76. "ernie-speed",
  77. "ernie-speed-128k",
  78. "ernie-tiny-8k",
  79. "ernie-char-8k",
  80. ]
  # Either credential model; request fields discriminate on `apiType`.
  LLMParams: TypeAlias = Union[AIStudioParams, QianfanParams]
  class BuildVectorStoreRequest(BaseModel):
      # Request body of the `/chatocr-vector` (buildVectorStore) endpoint.

      # Passed to the pipeline as `visual_info`.
      visionInfo: dict
      # Passed as `min_characters`; the endpoint substitutes 0 when omitted.
      minChars: Optional[int] = None
      # Passed as `llm_request_interval` only when provided.
      llmRequestInterval: Optional[float] = None
      # Passed as `llm_name` only when provided.
      llmName: Optional[LLMName] = None
      # Converted to a plain dict and passed as `llm_params` when provided.
      llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  class BuildVectorStoreResult(BaseModel):
      # Payload returned by the `/chatocr-vector` endpoint.

      # Value returned by the pipeline's `get_vector_text` call.
      vectorStore: dict
  class RetrieveKnowledgeRequest(BaseModel):
      # Request body of the `/chatocr-retrieval` (retrieveKnowledge) endpoint.

      # Passed to the pipeline as `key_list`.
      keys: List[str]
      # Passed to the pipeline as `vector`.
      vectorStore: dict
      # Passed to the pipeline as `visual_info`.
      visionInfo: dict
      # Passed as `llm_name` only when provided.
      llmName: Optional[LLMName] = None
      # Converted to a plain dict and passed as `llm_params` when provided.
      llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  class RetrieveKnowledgeResult(BaseModel):
      # Payload returned by the `/chatocr-retrieval` endpoint.

      # The `retrieval` entry of the pipeline's `get_retrieval_text` result.
      retrievalResult: str
  class ChatRequest(BaseModel):
      # Request body of the `/chatocr-chat` (chat) endpoint.

      # Passed to the pipeline as `key_list`.
      keys: List[str]
      # Passed to the pipeline as `visual_info`.
      visionInfo: dict
      # Passed as `user_task_description` only when provided.
      taskDescription: Optional[str] = None
      # Passed as `rules` only when provided.
      rules: Optional[str] = None
      # Passed as `few_shot` only when provided.
      fewShot: Optional[str] = None
      # Always passed, as `use_vector`.
      useVectorStore: bool = True
      # Passed as `vector` only when provided.
      vectorStore: Optional[dict] = None
      # Passed as `retrieval_result` only when provided.
      retrievalResult: Optional[str] = None
      # Always passed, as `save_prompt`.
      returnPrompts: bool = True
      # Passed as `llm_name` only when provided.
      llmName: Optional[LLMName] = None
      # Converted to a plain dict and passed as `llm_params` when provided.
      llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
  class Prompts(BaseModel):
      # Prompts reported back by the chat endpoint.

      # Filled from the pipeline's `ocr_prompt` entry.
      ocr: str
      # Filled from the pipeline's `table_prompt` entry.
      table: str
      # Filled from the pipeline's `html_prompt` entry.
      html: str
  class ChatResult(BaseModel):
      # Payload returned by the `/chatocr-chat` endpoint.

      # The `chat_res` entry of the pipeline's chat result.
      chatResult: str
      # Only set when the pipeline result carries a truthy `prompt` entry.
      prompts: Optional[Prompts] = None
  def _generate_request_id() -> str:
      """Return a freshly generated random (version 4) UUID as a string."""
      request_uuid = uuid.uuid4()
      return str(request_uuid)
  def _infer_file_type(url: str) -> FileType:
      """Guess the file type of the resource behind ``url`` from its extension.

      The extension is normally taken from the URL path. For BOS URLs that carry
      a query string it is taken from the filename advertised in the
      ``responseContentDisposition`` query parameter instead.

      Args:
          url: URL of the file to classify.

      Returns:
          ``0`` for a ``.pdf`` extension, ``1`` for ``.jpg``, ``.jpeg`` or
          ``.png``. The comparison is case-insensitive.

      Raises:
          ValueError: If the extension is not supported, or if a BOS URL with a
              query string has no usable ``responseContentDisposition`` value.
      """
      # Is it more reliable to guess the file type based on the response headers?
      SUPPORTED_IMG_EXTS: Final[Tuple[str, ...]] = (".jpg", ".jpeg", ".png")

      url_parts = urlparse(url)
      ext = os.path.splitext(url_parts.path)[1]

      # HACK: The support for BOS URLs with query params is implementation-based,
      # not interface-based.
      is_bos_url = (
          re.fullmatch(r"(?:bj|bd|su|gz|cd|hkg|fwh|fsh)\.bcebos\.com", url_parts.netloc)
          is not None
      )
      if is_bos_url and url_parts.query:
          dispositions = parse_qs(url_parts.query).get("responseContentDisposition", [])
          if len(dispositions) != 1:
              raise ValueError("`responseContentDisposition` not found")
          # Tolerate optional whitespace after the semicolon, as in the usual
          # `attachment; filename=...` spelling of the header value.
          match_ = re.match(r"attachment;\s*filename=(.*)", dispositions[0])
          if match_ is None:
              raise ValueError(
                  "Failed to extract the filename from `responseContentDisposition`"
              )
          # The filename may be a quoted string; drop surrounding quotes so they
          # do not end up inside the extension.
          filename = match_.group(1).strip().strip("\"'")
          ext = os.path.splitext(filename)[1]

      ext = ext.lower()
      if ext == ".pdf":
          return 0
      if ext in SUPPORTED_IMG_EXTS:
          return 1
      raise ValueError("Unsupported file type")
  def _llm_params_to_dict(llm_params: LLMParams) -> dict:
      """Translate request-level LLM credentials into the pipeline's dict form.

      Args:
          llm_params: Either an AI Studio or a Qianfan credential model,
              distinguished by its ``apiType`` attribute.

      Returns:
          ``{"api_type": "aistudio", "access_token": ...}`` for AI Studio, or
          ``{"api_type": "qianfan", "ak": ..., "sk": ...}`` for Qianfan.
      """
      if llm_params.apiType == "aistudio":
          return dict(api_type="aistudio", access_token=llm_params.accessToken)
      if llm_params.apiType == "qianfan":
          return dict(
              api_type="qianfan",
              ak=llm_params.apiKey,
              sk=llm_params.secretKey,
          )
      # Exhaustiveness guard: every member of the union is handled above.
      assert_never(llm_params.apiType)
  def _bytes_to_arrays(
      file_bytes: bytes,
      file_type: FileType,
      *,
      max_img_size: Tuple[int, int],
      max_num_imgs: int,
  ) -> List[np.ndarray]:
      """Decode an uploaded file into images and shrink them if they are too big.

      Args:
          file_bytes: Raw content of the file.
          file_type: ``0`` to read the bytes as a PDF, ``1`` to decode them as a
              single image.
          max_img_size: Size limit as ``(max_height, max_width)``, i.e. in the
              same order as ``ndarray.shape[:2]``.
          max_num_imgs: Forwarded to the PDF reader as ``max_num_imgs``.

      Returns:
          The decoded images. If the first image exceeds ``max_img_size``, every
          image is resized to the first image's dimensions scaled (aspect ratio
          preserved) to fit within the limit.

      Raises:
          ValueError: If no image could be obtained from the file.
      """
      if file_type == 0:
          images = serving_utils.read_pdf(
              file_bytes, resize=True, max_num_imgs=max_num_imgs
          )
      elif file_type == 1:
          images = [serving_utils.image_bytes_to_array(file_bytes)]
      else:
          assert_never(file_type)

      # Fail with a clear message instead of an IndexError on `images[0]`.
      if len(images) == 0:
          raise ValueError("No images could be obtained from the input file")

      max_h, max_w = max_img_size[0], max_img_size[1]
      h, w = images[0].shape[0:2]
      if w > max_w or h > max_h:
          # The tighter of the two ratios makes both sides fit. The previous
          # code mixed up the height/width indices here, which only happened
          # to work for square limits.
          factor = min(max_w / w, max_h / h)
          # `cv2.resize` expects (width, height); never let a side drop to 0.
          target_size = (max(1, int(factor * w)), max(1, int(factor * h)))
          # NOTE: as before, the target size is derived from the first image
          # and applied to all images.
          images = [cv2.resize(img, target_size) for img in images]
      return images
  def _postprocess_image(
      img: ArrayLike,
      request_id: str,
      filename: str,
      file_storage_config: file_storage.FileStorageConfig,
  ) -> str:
      """Encode an image and hand it to the configured file storage.

      Args:
          img: Image data; converted with ``np.asarray`` before encoding.
          request_id: Identifier used as the first component of the storage key.
          filename: File name used as the second component of the storage key;
              its extension selects the encoding format.
          file_storage_config: Storage configuration forwarded to
              ``file_storage.postprocess_file``.

      Returns:
          The string returned by ``file_storage.postprocess_file``.

      Raises:
          RuntimeError: If OpenCV reports that the image could not be encoded.
      """
      key = f"{request_id}/{filename}"
      ext = os.path.splitext(filename)[1]
      img = np.asarray(img)
      success, encoded_img = cv2.imencode(ext, img)
      # `imencode` signals failure through its first return value; do not store
      # a bogus buffer in that case.
      if not success:
          raise RuntimeError(f"Failed to encode image as {ext!r}")
      return file_storage.postprocess_file(
          encoded_img.tobytes(), config=file_storage_config, key=key
      )
  def create_pipeline_app(pipeline: PPChatOCRPipeline, app_config: AppConfig) -> FastAPI:
      """Build the FastAPI application that serves the PP-ChatOCR pipeline.

      Four POST endpoints are registered on the application:

      * ``/chatocr-vision``: run the vision part of the pipeline on an image
        or PDF.
      * ``/chatocr-vector``: build a vector store from vision info.
      * ``/chatocr-retrieval``: retrieve knowledge for a list of keys.
      * ``/chatocr-chat``: query the LLM.

      Every endpoint logs unexpected exceptions and converts them into an
      HTTP 500 response.

      Args:
          pipeline: Pipeline wrapper handed to ``create_app``.
          app_config: Application configuration handed to ``create_app``.

      Returns:
          The configured FastAPI application.
      """
      app, ctx = create_app(
          pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
      )

      # Normalise the storage configuration, falling back to in-memory storage.
      if "file_storage_config" not in ctx.extra:
          ctx.extra["file_storage_config"] = file_storage.InMemoryStorageConfig()
      else:
          ctx.extra["file_storage_config"] = file_storage.parse_file_storage_config(
              ctx.extra["file_storage_config"]
          )
      ctx.extra.setdefault("max_img_size", _DEFAULT_MAX_IMG_SIZE)
      ctx.extra.setdefault("max_num_imgs", _DEFAULT_MAX_NUM_IMGS)

      # NOTE: the endpoint functions below are documented with comments rather
      # than docstrings so that the generated OpenAPI descriptions stay as is.

      @app.post(
          "/chatocr-vision",
          operation_id="analyzeImage",
          responses={422: {"model": Response}},
      )
      async def _analyze_image(
          request: AnalyzeImageRequest,
      ) -> ResultResponse[AnalyzeImageResult]:
          # Runs the vision pipeline on the requested file and stores the
          # input / OCR / layout images through the file storage.
          pipeline = ctx.pipeline
          aiohttp_session = ctx.aiohttp_session

          request_id = _generate_request_id()

          # Resolve the file type: explicit value first, URL-based guess second.
          file_type = request.fileType
          if file_type is None:
              if not serving_utils.is_url(request.file):
                  raise HTTPException(status_code=422, detail="Unknown file type")
              try:
                  file_type = _infer_file_type(request.file)
              except Exception as e:
                  logging.exception(e)
                  raise HTTPException(
                      status_code=422,
                      detail="The file type cannot be inferred from the URL. Please specify the file type explicitly.",
                  )

          if request.inferenceParams:
              if request.inferenceParams.maxLongSide:
                  raise HTTPException(
                      status_code=422,
                      detail="`max_long_side` is currently not supported.",
                  )

          try:
              file_bytes = await serving_utils.get_raw_bytes(
                  request.file, aiohttp_session
              )
              images = await serving_utils.call_async(
                  _bytes_to_arrays,
                  file_bytes,
                  file_type,
                  max_img_size=ctx.extra["max_img_size"],
                  max_num_imgs=ctx.extra["max_num_imgs"],
              )

              result = await pipeline.infer(
                  images,
                  use_oricls=request.useOricls,
                  use_curve=request.useCurve,
                  use_uvdoc=request.useUvdoc,
              )

              vision_results: List[VisionResult] = []
              for i, (img, item) in enumerate(zip(images, result[0])):
                  # Schedule the three image uploads for this page; they are
                  # awaited together once texts and tables have been collected.
                  uploads = (
                      (f"input_image_{i}.jpg", img),
                      (f"ocr_image_{i}.jpg", item["ocr_result"].img),
                      (f"layout_image_{i}.jpg", item["layout_result"].img),
                  )
                  pp_img_futures: List[Awaitable] = [
                      serving_utils.call_async(
                          _postprocess_image,
                          image,
                          request_id=request_id,
                          filename=name,
                          file_storage_config=ctx.extra["file_storage_config"],
                      )
                      for name, image in uploads
                  ]

                  ocr_result = item["ocr_result"]
                  texts: List[Text] = [
                      Text(poly=poly, text=text, score=score)
                      for poly, text, score in zip(
                          ocr_result["dt_polys"],
                          ocr_result["rec_text"],
                          ocr_result["rec_score"],
                      )
                  ]

                  tables: List[Table] = []
                  for region in item["table_result"]:
                      tables.append(
                          Table(bbox=region["layout_bbox"], html=region["html"])
                      )

                  input_img, ocr_img, layout_img = await asyncio.gather(*pp_img_futures)
                  vision_results.append(
                      VisionResult(
                          texts=texts,
                          tables=tables,
                          inputImage=input_img,
                          ocrImage=ocr_img,
                          layoutImage=layout_img,
                      )
                  )

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=AnalyzeImageResult(
                      visionResults=vision_results,
                      visionInfo=result[1],
                  ),
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      @app.post(
          "/chatocr-vector",
          operation_id="buildVectorStore",
          responses={422: {"model": Response}},
      )
      async def _build_vector_store(
          request: BuildVectorStoreRequest,
      ) -> ResultResponse[BuildVectorStoreResult]:
          # Builds a vector store from previously returned vision info.
          pipeline = ctx.pipeline

          try:
              kwargs = {
                  "visual_info": request.visionInfo,
                  # An omitted `minChars` is sent to the pipeline as 0.
                  "min_characters": (
                      0 if request.minChars is None else request.minChars
                  ),
              }
              if request.llmRequestInterval is not None:
                  kwargs["llm_request_interval"] = request.llmRequestInterval
              if request.llmName is not None:
                  kwargs["llm_name"] = request.llmName
              if request.llmParams is not None:
                  kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)

              result = await serving_utils.call_async(
                  pipeline.pipeline.get_vector_text, **kwargs
              )

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=BuildVectorStoreResult(vectorStore=result),
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      @app.post(
          "/chatocr-retrieval",
          operation_id="retrieveKnowledge",
          responses={422: {"model": Response}},
      )
      async def _retrieve_knowledge(
          request: RetrieveKnowledgeRequest,
      ) -> ResultResponse[RetrieveKnowledgeResult]:
          # Looks up the requested keys in the supplied vector store.
          pipeline = ctx.pipeline

          try:
              kwargs = dict(
                  key_list=request.keys,
                  vector=request.vectorStore,
                  visual_info=request.visionInfo,
              )
              if request.llmName is not None:
                  kwargs["llm_name"] = request.llmName
              if request.llmParams is not None:
                  kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)

              result = await serving_utils.call_async(
                  pipeline.pipeline.get_retrieval_text, **kwargs
              )

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=RetrieveKnowledgeResult(retrievalResult=result["retrieval"]),
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      @app.post(
          "/chatocr-chat", operation_id="chat", responses={422: {"model": Response}}
      )
      async def _chat(
          request: ChatRequest,
      ) -> ResultResponse[ChatResult]:
          # Sends the chat request to the pipeline and optionally reports the
          # prompts that were used.
          pipeline = ctx.pipeline

          try:
              kwargs = dict(key_list=request.keys, visual_info=request.visionInfo)

              # Optional settings are only forwarded when the client set them.
              for name, value in (
                  ("user_task_description", request.taskDescription),
                  ("rules", request.rules),
                  ("few_shot", request.fewShot),
              ):
                  if value is not None:
                      kwargs[name] = value

              kwargs["use_vector"] = request.useVectorStore
              for name, value in (
                  ("vector", request.vectorStore),
                  ("retrieval_result", request.retrievalResult),
              ):
                  if value is not None:
                      kwargs[name] = value

              kwargs["save_prompt"] = request.returnPrompts
              if request.llmName is not None:
                  kwargs["llm_name"] = request.llmName
              if request.llmParams is not None:
                  kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)

              result = await serving_utils.call_async(pipeline.pipeline.chat, **kwargs)

              prompt_info = result["prompt"]
              if not prompt_info:
                  chat_result = ChatResult(
                      chatResult=result["chat_res"],
                  )
              else:
                  prompts = Prompts(
                      ocr=prompt_info["ocr_prompt"],
                      table=prompt_info["table_prompt"],
                      html=prompt_info["html_prompt"],
                  )
                  chat_result = ChatResult(
                      chatResult=result["chat_res"],
                      prompts=prompts,
                  )

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=chat_result,
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      return app