@@ -0,0 +1,515 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import re
+import uuid
+from typing import Awaitable, Final, List, Literal, Optional, Tuple, Union
+from urllib.parse import parse_qs, urlparse
+
+import cv2
+import numpy as np
+from fastapi import FastAPI, HTTPException
+from numpy.typing import ArrayLike
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, TypeAlias, assert_never
+
+from .....utils import logging
+from ...ppchatocrv3 import PPChatOCRPipeline
+from .. import file_storage
+from .. import utils as serving_utils
+from ..app import AppConfig, create_app
+from ..models import Response, ResultResponse
+
+_DEFAULT_MAX_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
+_DEFAULT_MAX_NUM_IMGS: Final[int] = 10
+
+
+FileType: TypeAlias = Literal[0, 1]
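+# File type codes used in requests: 0 = PDF, 1 = image (see `_infer_file_type` below).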
+
+
+class InferenceParams(BaseModel):
+    maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None
+
+
+class AnalyzeImageRequest(BaseModel):
+    file: str
+    fileType: Optional[FileType] = None
+    useOricls: bool = True
+    useCurve: bool = True
+    useUvdoc: bool = True
+    inferenceParams: Optional[InferenceParams] = None
+
+
+Point: TypeAlias = Annotated[List[int], Field(min_length=2, max_length=2)]
+Polygon: TypeAlias = Annotated[List[Point], Field(min_length=3)]
+BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
+
+
+class Text(BaseModel):
+    poly: Polygon
+    text: str
+    score: float
+
+
+class Table(BaseModel):
+    bbox: BoundingBox
+    html: str
+
+
+class VisionResult(BaseModel):
+    texts: List[Text]
+    tables: List[Table]
+    inputImage: str
+    ocrImage: str
+    layoutImage: str
+
+
+class AnalyzeImageResult(BaseModel):
+    visionResults: List[VisionResult]
+    visionInfo: dict
+
+
+class AIStudioParams(BaseModel):
+    accessToken: str
+    apiType: Literal["aistudio"] = "aistudio"
+
+
+class QianfanParams(BaseModel):
+    apiKey: str
+    secretKey: str
+    apiType: Literal["qianfan"] = "qianfan"
+
+
+LLMName: TypeAlias = Literal[
+    "ernie-3.5",
+    "ernie-3.5-8k",
+    "ernie-lite",
+    "ernie-4.0",
+    "ernie-4.0-turbo-8k",
+    "ernie-speed",
+    "ernie-speed-128k",
+    "ernie-tiny-8k",
+    "ernie-char-8k",
+]
+LLMParams: TypeAlias = Union[AIStudioParams, QianfanParams]
+
+
+class BuildVectorStoreRequest(BaseModel):
+    visionInfo: dict
+    minChars: Optional[int] = None
+    llmRequestInterval: Optional[float] = None
+    llmName: Optional[LLMName] = None
+    llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
+
+
+class BuildVectorStoreResult(BaseModel):
+    vectorStore: dict
+
+
+class RetrieveKnowledgeRequest(BaseModel):
+    keys: List[str]
+    vectorStore: dict
+    visionInfo: dict
+    llmName: Optional[LLMName] = None
+    llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
+
+
+class RetrieveKnowledgeResult(BaseModel):
+    retrievalResult: str
+
+
+class ChatRequest(BaseModel):
+    keys: List[str]
+    visionInfo: dict
+    taskDescription: Optional[str] = None
+    rules: Optional[str] = None
+    fewShot: Optional[str] = None
+    useVectorStore: bool = True
+    vectorStore: Optional[dict] = None
+    retrievalResult: Optional[str] = None
+    returnPrompts: bool = True
+    llmName: Optional[LLMName] = None
+    llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
+
+
+class Prompts(BaseModel):
+    ocr: str
+    table: str
+    html: str
+
+
+class ChatResult(BaseModel):
+    chatResult: str
+    prompts: Optional[Prompts] = None
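+
+
+# The models above chain across endpoints: `visionInfo` from `analyzeImage`
+# feeds `buildVectorStore`, `retrieveKnowledge`, and `chat`, and the opaque
+# `vectorStore` dict produced by `buildVectorStore` is passed back verbatim.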
+
+
+def _generate_request_id() -> str:
+    return str(uuid.uuid4())
+
+
+def _infer_file_type(url: str) -> FileType:
+    # Is it more reliable to guess the file type based on the response headers?
+    SUPPORTED_IMG_EXTS: Final[List[str]] = [".jpg", ".jpeg", ".png"]
+
+    url_parts = urlparse(url)
+    ext = os.path.splitext(url_parts.path)[1]
+    # HACK: The support for BOS URLs with query params is implementation-based,
+    # not interface-based.
+    is_bos_url = (
+        re.fullmatch(r"(?:bj|bd|su|gz|cd|hkg|fwh|fsh)\.bcebos\.com", url_parts.netloc)
+        is not None
+    )
+    if is_bos_url and url_parts.query:
+        params = parse_qs(url_parts.query)
+        if (
+            "responseContentDisposition" not in params
+            or len(params["responseContentDisposition"]) != 1
+        ):
+            raise ValueError("`responseContentDisposition` not found")
+        match_ = re.match(
+            r"attachment;filename=(.*)", params["responseContentDisposition"][0]
+        )
+        if not match_ or match_.groups()[0] is None:
+            raise ValueError(
+                "Failed to extract the filename from `responseContentDisposition`"
+            )
+        ext = os.path.splitext(match_.groups()[0])[1]
+    ext = ext.lower()
+    if ext == ".pdf":
+        return 0
+    elif ext in SUPPORTED_IMG_EXTS:
+        return 1
+    else:
+        raise ValueError("Unsupported file type")
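+# For example, a URL path ending in ".pdf" yields 0, while ".png"/".jpg"/".jpeg"
+# yield 1; any other extension raises `ValueError`.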
+
+
+def _llm_params_to_dict(llm_params: LLMParams) -> dict:
+    if llm_params.apiType == "aistudio":
+        return {"api_type": "aistudio", "access_token": llm_params.accessToken}
+    elif llm_params.apiType == "qianfan":
+        return {
+            "api_type": "qianfan",
+            "ak": llm_params.apiKey,
+            "sk": llm_params.secretKey,
+        }
+    else:
+        assert_never(llm_params.apiType)
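+# E.g., `QianfanParams(apiKey="AK", secretKey="SK")` becomes
+# `{"api_type": "qianfan", "ak": "AK", "sk": "SK"}` for the pipeline layer.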
+
+
+def _bytes_to_arrays(
+    file_bytes: bytes,
+    file_type: FileType,
+    *,
+    max_img_size: Tuple[int, int],
+    max_num_imgs: int,
+) -> List[np.ndarray]:
+    if file_type == 0:
+        images = serving_utils.read_pdf(
+            file_bytes, resize=True, max_num_imgs=max_num_imgs
+        )
+    elif file_type == 1:
+        images = [serving_utils.image_bytes_to_array(file_bytes)]
+    else:
+        assert_never(file_type)
+    h, w = images[0].shape[0:2]
+    if w > max_img_size[0] or h > max_img_size[1]:
+        if w / h > max_img_size[0] / max_img_size[1]:
+            factor = max_img_size[0] / w
+        else:
+            factor = max_img_size[1] / h
+        images = [cv2.resize(img, (int(factor * w), int(factor * h))) for img in images]
+    return images
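+# Note: `max_img_size` is interpreted here as (max width, max height), and the
+# resize factor is computed from the first image only and applied to every
+# page, which assumes all pages of a PDF share the same dimensions.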
+
+
+def _postprocess_image(
+    img: ArrayLike,
+    request_id: str,
+    filename: str,
+    file_storage_config: file_storage.FileStorageConfig,
+) -> str:
+    key = f"{request_id}/{filename}"
+    ext = os.path.splitext(filename)[1]
+    img = np.asarray(img)
+    _, encoded_img = cv2.imencode(ext, img)
+    encoded_img = encoded_img.tobytes()
+    return file_storage.postprocess_file(
+        encoded_img, config=file_storage_config, key=key
+    )
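+# Images are encoded according to the filename extension and stored under a
+# per-request key such as "{request_id}/input_image_0.jpg".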
+
+
+def create_pipeline_app(pipeline: PPChatOCRPipeline, app_config: AppConfig) -> FastAPI:
+    app, ctx = create_app(
+        pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
+    )
+
+    if "file_storage_config" in ctx.extra:
+        ctx.extra["file_storage_config"] = file_storage.parse_file_storage_config(
+            ctx.extra["file_storage_config"]
+        )
+    else:
+        ctx.extra["file_storage_config"] = file_storage.InMemoryStorageConfig()
+    ctx.extra.setdefault("max_img_size", _DEFAULT_MAX_IMG_SIZE)
+    ctx.extra.setdefault("max_num_imgs", _DEFAULT_MAX_NUM_IMGS)
+
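+    # `max_img_size` and `max_num_imgs` can be overridden via the app config;
+    # otherwise the module-level defaults apply. A minimal sketch of an
+    # `analyzeImage` request body (the URL is a hypothetical example):
+    #     {"file": "https://example.com/doc.pdf", "fileType": 0}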
+    @app.post(
+        "/chatocr-vision",
+        operation_id="analyzeImage",
+        responses={422: {"model": Response}},
+    )
+    async def _analyze_image(
+        request: AnalyzeImageRequest,
+    ) -> ResultResponse[AnalyzeImageResult]:
+        pipeline = ctx.pipeline
+        aiohttp_session = ctx.aiohttp_session
+
+        request_id = _generate_request_id()
+
+        if request.fileType is None:
+            if serving_utils.is_url(request.file):
+                try:
+                    file_type = _infer_file_type(request.file)
+                except Exception as e:
+                    logging.exception(e)
+                    raise HTTPException(
+                        status_code=422,
+                        detail="The file type cannot be inferred from the URL. Please specify the file type explicitly.",
+                    )
+            else:
+                raise HTTPException(status_code=422, detail="Unknown file type")
+        else:
+            file_type = request.fileType
+
+        if request.inferenceParams:
+            max_long_side = request.inferenceParams.maxLongSide
+            if max_long_side:
+                raise HTTPException(
+                    status_code=422,
+                    detail="`max_long_side` is currently not supported.",
+                )
+
+        try:
+            file_bytes = await serving_utils.get_raw_bytes(
+                request.file, aiohttp_session
+            )
+            images = await serving_utils.call_async(
+                _bytes_to_arrays,
+                file_bytes,
+                file_type,
+                max_img_size=ctx.extra["max_img_size"],
+                max_num_imgs=ctx.extra["max_num_imgs"],
+            )
+
+            result = await pipeline.infer(
+                images,
+                use_oricls=request.useOricls,
+                use_curve=request.useCurve,
+                use_uvdoc=request.useUvdoc,
+            )
+
+            vision_results: List[VisionResult] = []
+            for i, (img, item) in enumerate(zip(images, result[0])):
+                pp_img_futures: List[Awaitable] = []
+                future = serving_utils.call_async(
+                    _postprocess_image,
+                    img,
+                    request_id=request_id,
+                    filename=f"input_image_{i}.jpg",
+                    file_storage_config=ctx.extra["file_storage_config"],
+                )
+                pp_img_futures.append(future)
+                future = serving_utils.call_async(
+                    _postprocess_image,
+                    item["ocr_result"].img,
+                    request_id=request_id,
+                    filename=f"ocr_image_{i}.jpg",
+                    file_storage_config=ctx.extra["file_storage_config"],
+                )
+                pp_img_futures.append(future)
+                future = serving_utils.call_async(
+                    _postprocess_image,
+                    item["layout_result"].img,
+                    request_id=request_id,
+                    filename=f"layout_image_{i}.jpg",
+                    file_storage_config=ctx.extra["file_storage_config"],
+                )
+                pp_img_futures.append(future)
+                texts: List[Text] = []
+                for poly, text, score in zip(
+                    item["ocr_result"]["dt_polys"],
+                    item["ocr_result"]["rec_text"],
+                    item["ocr_result"]["rec_score"],
+                ):
+                    texts.append(Text(poly=poly, text=text, score=score))
+                tables = [
+                    Table(bbox=r["layout_bbox"], html=r["html"])
+                    for r in item["table_result"]
+                ]
+                input_img, ocr_img, layout_img = await asyncio.gather(*pp_img_futures)
+                vision_result = VisionResult(
+                    texts=texts,
+                    tables=tables,
+                    inputImage=input_img,
+                    ocrImage=ocr_img,
+                    layoutImage=layout_img,
+                )
+                vision_results.append(vision_result)
+
+            return ResultResponse(
+                logId=serving_utils.generate_log_id(),
+                errorCode=0,
+                errorMsg="Success",
+                result=AnalyzeImageResult(
+                    visionResults=vision_results,
+                    visionInfo=result[1],
+                ),
+            )
+
+        except Exception as e:
+            logging.exception(e)
+            raise HTTPException(status_code=500, detail="Internal server error")
+
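+    # `visionInfo` from the response above is the input to the next step; a
+    # sketch of a `buildVectorStore` request (illustrative values):
+    #     {"visionInfo": {...}, "minChars": 200, "llmName": "ernie-3.5"}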
+    @app.post(
+        "/chatocr-vector",
+        operation_id="buildVectorStore",
+        responses={422: {"model": Response}},
+    )
+    async def _build_vector_store(
+        request: BuildVectorStoreRequest,
+    ) -> ResultResponse[BuildVectorStoreResult]:
+        pipeline = ctx.pipeline
+
+        try:
+            kwargs = {"visual_info": request.visionInfo}
+            if request.minChars is not None:
+                kwargs["min_characters"] = request.minChars
+            else:
+                kwargs["min_characters"] = 0
+            if request.llmRequestInterval is not None:
+                kwargs["llm_request_interval"] = request.llmRequestInterval
+            if request.llmName is not None:
+                kwargs["llm_name"] = request.llmName
+            if request.llmParams is not None:
+                kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
+
+            result = await serving_utils.call_async(
+                pipeline.pipeline.get_vector_text, **kwargs
+            )
+
+            return ResultResponse(
+                logId=serving_utils.generate_log_id(),
+                errorCode=0,
+                errorMsg="Success",
+                result=BuildVectorStoreResult(vectorStore=result),
+            )
+
+        except Exception as e:
+            logging.exception(e)
+            raise HTTPException(status_code=500, detail="Internal server error")
+
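+    # `vectorStore` in the request below is the opaque dict returned by
+    # `buildVectorStore`, and `keys` names the fields to retrieve context for.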
+    @app.post(
+        "/chatocr-retrieval",
+        operation_id="retrieveKnowledge",
+        responses={422: {"model": Response}},
+    )
+    async def _retrieve_knowledge(
+        request: RetrieveKnowledgeRequest,
+    ) -> ResultResponse[RetrieveKnowledgeResult]:
+        pipeline = ctx.pipeline
+
+        try:
+            kwargs = {
+                "key_list": request.keys,
+                "vector": request.vectorStore,
+                "visual_info": request.visionInfo,
+            }
+            if request.llmName is not None:
+                kwargs["llm_name"] = request.llmName
+            if request.llmParams is not None:
+                kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
+
+            result = await serving_utils.call_async(
+                pipeline.pipeline.get_retrieval_text, **kwargs
+            )
+
+            return ResultResponse(
+                logId=serving_utils.generate_log_id(),
+                errorCode=0,
+                errorMsg="Success",
+                result=RetrieveKnowledgeResult(retrievalResult=result["retrieval"]),
+            )
+
+        except Exception as e:
+            logging.exception(e)
+            raise HTTPException(status_code=500, detail="Internal server error")
+
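+    # Final step: `chat` combines `visionInfo` with the optional vector store
+    # and retrieval results produced by the endpoints above.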
+    @app.post(
+        "/chatocr-chat", operation_id="chat", responses={422: {"model": Response}}
+    )
+    async def _chat(
+        request: ChatRequest,
+    ) -> ResultResponse[ChatResult]:
+        pipeline = ctx.pipeline
+
+        try:
+            kwargs = {
+                "key_list": request.keys,
+                "visual_info": request.visionInfo,
+            }
+            if request.taskDescription is not None:
+                kwargs["user_task_description"] = request.taskDescription
+            if request.rules is not None:
+                kwargs["rules"] = request.rules
+            if request.fewShot is not None:
+                kwargs["few_shot"] = request.fewShot
+            kwargs["use_vector"] = request.useVectorStore
+            if request.vectorStore is not None:
+                kwargs["vector"] = request.vectorStore
+            if request.retrievalResult is not None:
+                kwargs["retrieval_result"] = request.retrievalResult
+            kwargs["save_prompt"] = request.returnPrompts
+            if request.llmName is not None:
+                kwargs["llm_name"] = request.llmName
+            if request.llmParams is not None:
+                kwargs["llm_params"] = _llm_params_to_dict(request.llmParams)
+
+            result = await serving_utils.call_async(pipeline.pipeline.chat, **kwargs)
+
+            if result["prompt"]:
+                prompts = Prompts(
+                    ocr=result["prompt"]["ocr_prompt"],
+                    table=result["prompt"]["table_prompt"],
+                    html=result["prompt"]["html_prompt"],
+                )
+                chat_result = ChatResult(
+                    chatResult=result["chat_res"],
+                    prompts=prompts,
+                )
+            else:
+                chat_result = ChatResult(
+                    chatResult=result["chat_res"],
+                )
+            return ResultResponse(
+                logId=serving_utils.generate_log_id(),
+                errorCode=0,
+                errorMsg="Success",
+                result=chat_result,
+            )
+
+        except Exception as e:
+            logging.exception(e)
+            raise HTTPException(status_code=500, detail="Internal server error")
+
+    return app
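+
+
+# A minimal usage sketch (assumes an already-constructed pipeline and config;
+# the `uvicorn` invocation is illustrative, not part of this module):
+#
+#     app = create_pipeline_app(pipeline, app_config)
+#     uvicorn.run(app, host="0.0.0.0", port=8080)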