
[Fix] Update OCR service schemas and FastAPI apps to accommodate the new API (#2946)

* Update OCR FastAPI apps

* Update PP-ChatOCRv3-doc

* Add table_recognition_v2

* Update doc preprocessor schema
Lin Manhui · 9 months ago · commit f8b16e6a6d

+ 23 - 8
paddlex/inference/serving/basic_serving/_app.py

@@ -24,6 +24,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    TypedDict,
     TypeVar,
 )
 
@@ -33,7 +34,7 @@ from fastapi.encoders import jsonable_encoder
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 from starlette.exceptions import HTTPException
-from typing_extensions import ParamSpec
+from typing_extensions import ParamSpec, TypeGuard
 
 from ....utils import logging
 from ...pipelines import BasePipeline
@@ -46,6 +47,18 @@ _P = ParamSpec("_P")
 _R = TypeVar("_R")
 
 
+class _Error(TypedDict):
+    error: str
+
+
+def _is_error(obj: object) -> TypeGuard[_Error]:
+    return (
+        isinstance(obj, dict)
+        and obj.keys() == {"error"}
+        and isinstance(obj["error"], str)
+    )
+
+
 # XXX: Since typing info (e.g., the pipeline class) cannot be easily obtained
 # without abstraction leaks, generic classes do not offer additional benefits
 # for type hinting. However, I would stick with the current design, as it does
@@ -63,13 +76,15 @@ class PipelineWrapper(Generic[_PipelineT]):
 
     async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
         def _infer() -> List[Any]:
-            output = list(self._pipeline(*args, **kwargs))
-            if (
-                len(output) == 1
-                and isinstance(output[0], dict)
-                and output[0].keys() == {"error"}
-            ):
-                raise fastapi.HTTPException(status_code=500, detail=output[0]["error"])
+            output: list = []
+            with contextlib.closing(self._pipeline(*args, **kwargs)) as it:
+                for item in it:
+                    if _is_error(item):
+                        raise fastapi.HTTPException(
+                            status_code=500, detail=item["error"]
+                        )
+                    output.append(item)
+
             return output
 
         return await self.call(_infer)
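
Note: the rewritten `_infer` iterates under `contextlib.closing` (the import is not added in this hunk, so presumably it already exists in the module), which guarantees the pipeline generator is closed even when an error item aborts the loop early, and the `TypeGuard` return type lets type checkers narrow each item to `_Error`. A minimal, self-contained sketch of the same narrowing pattern (the names below are illustrative, not part of the PaddleX API):

    from typing import TypedDict

    from typing_extensions import TypeGuard

    class Error(TypedDict):
        error: str

    def is_error(obj: object) -> TypeGuard[Error]:
        # True only for a dict of exactly the shape {"error": <str>}.
        return (
            isinstance(obj, dict)
            and obj.keys() == {"error"}
            and isinstance(obj["error"], str)
        )

    for item in [{"value": 1}, {"error": "inference failed"}]:
        if is_error(item):
            # `item` is narrowed to Error here, so item["error"] type-checks.
            print("error:", item["error"])
        else:
            print("result:", item)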

+ 16 - 8
paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py

@@ -53,15 +53,22 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
         doc_pp_results: List[Dict[str, Any]] = []
         for i, (img, item) in enumerate(zip(images, result)):
             pruned_res = common.prune_result(item.json["res"])
+            output_img = common.postprocess_image(
+                item["output_img"],
+                log_id,
+                "output_img.png",
+                file_storage=ctx.extra["file_storage"],
+                return_url=ctx.extra["return_img_urls"],
+                max_img_size=ctx.extra["max_output_img_size"],
+            )
             if ctx.config.visualize:
-                output_imgs = item.img
-                imgs = {
+                vis_imgs = {
                     "input_img": img,
-                    "doc_preprocessing_img": output_imgs["preprocessed_img"],
+                    "doc_preprocessing_img": item.img["preprocessed_img"],
                 }
-                imgs = await serving_utils.call_async(
+                vis_imgs = await serving_utils.call_async(
                     common.postprocess_images,
-                    imgs,
+                    vis_imgs,
                     log_id,
                     filename_template=f"{{key}}_{i}.jpg",
                     file_storage=ctx.extra["file_storage"],
@@ -69,12 +76,13 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
                     max_img_size=ctx.extra["max_output_img_size"],
                 )
             else:
-                imgs = {}
+                vis_imgs = {}
             doc_pp_results.append(
                 dict(
+                    outputImage=output_img,
                     prunedResult=pruned_res,
-                    docPreprocessingImage=imgs.get("doc_preprocessing_img"),
-                    inputImage=imgs.get("input_img"),
+                    docPreprocessingImage=vis_imgs.get("doc_preprocessing_img"),
+                    inputImage=vis_imgs.get("input_img"),
                 )
             )
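
Note: each result now always carries `outputImage` (the preprocessed page, produced regardless of the `visualize` setting), while `docPreprocessingImage` and `inputImage` remain visualization-only. A hedged client-side sketch; the endpoint path, the `file` request field, and the `docPreprocessingResults` response key follow the shared OCR serving conventions and are not shown in this diff:

    import base64

    import requests  # any HTTP client works

    with open("page.jpg", "rb") as f:
        payload = {"file": base64.b64encode(f.read()).decode("ascii")}

    # Assumed endpoint path and server address.
    resp = requests.post("http://localhost:8080/document-preprocessing", json=payload)
    resp.raise_for_status()
    for res in resp.json()["result"]["docPreprocessingResults"]:
        # `outputImage` is a URL or a base64 string, depending on the
        # server's return-URL configuration.
        print(sorted(res["prunedResult"]), res["outputImage"][:40])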
 

+ 6 - 9
paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py

@@ -55,15 +55,10 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
         for i, (img, item) in enumerate(zip(images, result)):
             pruned_res = common.prune_result(item.json["res"])
             if ctx.config.visualize:
-                output_imgs = item.img
                 imgs = {
                     "input_img": img,
-                    "formula_rec_img": output_imgs["formula_res_img"],
+                    **item.img,
                 }
-                if "layout_det_res" in output_imgs:
-                    imgs["layout_det_img"] = output_imgs["layout_det_res"]
-                if "preprocessed_img" in output_imgs:
-                    imgs["doc_preprocessing_img"] = output_imgs["preprocessed_img"]
                 imgs = await serving_utils.call_async(
                     common.postprocess_images,
                     imgs,
@@ -78,9 +73,11 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
             formula_rec_results.append(
                 dict(
                     prunedResult=pruned_res,
-                    formulaRecImage=imgs.get("formula_rec_img"),
-                    layoutDetImage=imgs.get("layout_det_img"),
-                    docPreprocessingImage=imgs.get("doc_preprocessing_img"),
+                    outputImages=(
+                        {k: v for k, v in imgs.items() if k != "input_img"}
+                        if imgs
+                        else None
+                    ),
                     inputImage=imgs.get("input_img"),
                 )
             )
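
Note: the per-visualization fields (`formulaRecImage`, `layoutDetImage`, `docPreprocessingImage`) collapse into a single `outputImages` map whose keys come straight from `item.img`, so new pipeline visualizations flow through without further schema changes. A sketch of consuming the new shape (field names per the schema change further down):

    from typing import Dict

    def collect_output_images(formula_rec_result: dict) -> Dict[str, str]:
        # `outputImages` is Optional[Dict[str, str]]: None when visualization
        # is disabled, otherwise {image_key: url_or_base64_string}.
        return formula_rec_result.get("outputImages") or {}

    sample = {"outputImages": {"formula_res_img": "<url or base64>"}}
    for key, img in collect_output_images(sample).items():
        print(key, img[:40])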

+ 51 - 45
paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py

@@ -14,15 +14,14 @@
 
 from typing import Any, Dict, List
 
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI
 
-from .....utils import logging
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
 from ...schemas.layout_parsing import INFER_ENDPOINT, InferRequest, InferResult
 from .._app import create_app, primary_operation
-from ._common import image as image_common
+from ._common import common
 from ._common import ocr as ocr_common
 
 
@@ -49,52 +48,59 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
 
         result = await pipeline.infer(
             images,
-            use_doc_image_ori_cls_model=request.useImgOrientationCls,
-            use_doc_image_unwarp_model=request.useImgUnwarping,
-            use_seal_text_det_model=request.useSealTextDet,
+            use_doc_orientation_classify=request.useDocOrientationClassify,
+            use_doc_unwarping=request.useDocUnwarping,
+            use_general_ocr=request.useGeneralOcr,
+            use_seal_recognition=request.useSealRecognition,
+            use_table_recognition=request.useTableRecognition,
+            use_formula_recognition=request.useFormulaRecognition,
+            text_det_limit_side_len=request.textDetLimitSideLen,
+            text_det_limit_type=request.textDetLimitType,
+            text_det_thresh=request.textDetThresh,
+            text_det_box_thresh=request.textDetBoxThresh,
+            text_det_unclip_ratio=request.textDetUnclipRatio,
+            text_rec_score_thresh=request.textRecScoreThresh,
+            seal_det_limit_side_len=request.sealDetLimitSideLen,
+            seal_det_limit_type=request.sealDetLimitType,
+            seal_det_thresh=request.sealDetThresh,
+            seal_det_box_thresh=request.sealDetBoxThresh,
+            seal_det_unclip_ratio=request.sealDetUnclipRatio,
+            seal_rec_score_thresh=request.sealRecScoreThresh,
+            layout_nms=request.layoutNms,
+            layout_unclip_ratio=request.layoutUnclipRatio,
+            layout_merge_bboxes_mode=request.layoutMergeBboxesMode,
         )
 
         layout_parsing_results: List[Dict[str, Any]] = []
-        for i, item in enumerate(result):
-            layout_elements: List[Dict[str, Any]] = []
-            for j, subitem in enumerate(
-                item["layout_parsing_result"]["parsing_result"]
-            ):
-                dyn_keys = subitem.keys() - {"input_path", "layout_bbox", "layout"}
-                if len(dyn_keys) != 1:
-                    logging.error("Unexpected result: %s", subitem)
-                    raise HTTPException(
-                        status_code=500,
-                        detail="Internal server error",
-                    )
-                label = next(iter(dyn_keys))
-                if label in ("image", "figure", "img", "fig"):
-                    text = subitem[label]["image_text"]
-                    if ctx.config.visualize:
-                        image = await serving_utils.call_async(
-                            image_common.postprocess_image,
-                            subitem[label]["img"],
-                            log_id=log_id,
-                            filename=f"image_{i}_{j}.jpg",
-                            file_storage=ctx.extra["file_storage"],
-                            return_url=ctx.extra["return_img_urls"],
-                            max_img_size=ctx.extra["max_output_img_size"],
-                        )
-                    else:
-                        image = None
-                else:
-                    text = subitem[label]
-                    image = None
-                layout_elements.append(
-                    dict(
-                        bbox=subitem["layout_bbox"],
-                        label=label,
-                        text=text,
-                        layoutType=subitem["layout"],
-                        image=image,
-                    )
+        for i, (img, item) in enumerate(zip(images, result)):
+            pruned_res = common.prune_result(item.json["res"])
+            if ctx.config.visualize:
+                imgs = {
+                    "input_img": img,
+                    **item.img,
+                }
+                imgs = await serving_utils.call_async(
+                    common.postprocess_images,
+                    imgs,
+                    log_id,
+                    filename_template=f"{{key}}_{i}.jpg",
+                    file_storage=ctx.extra["file_storage"],
+                    return_urls=ctx.extra["return_img_urls"],
+                    max_img_size=ctx.extra["max_output_img_size"],
                 )
-            layout_parsing_results.append(dict(layoutElements=layout_elements))
+            else:
+                imgs = {}
+            layout_parsing_results.append(
+                dict(
+                    prunedResult=pruned_res,
+                    outputImages=(
+                        {k: v for k, v in imgs.items() if k != "input_img"}
+                        if imgs
+                        else None
+                    ),
+                    inputImage=imgs.get("input_img"),
+                )
+            )
 
         return ResultResponse[InferResult](
             logId=log_id,
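
Note: the request surface now maps one-to-one onto the pipeline's keyword arguments (camelCase on the wire, snake_case internally), replacing the old hand-rolled per-element parsing and its 500-on-unexpected-shape path. A hedged example payload; `file` comes from the shared OCR base request and, per PaddleX serving conventions, may be a URL or base64-encoded data:

    payload = {
        "file": "https://example.com/doc.jpg",  # or base64-encoded bytes
        "useDocOrientationClassify": True,
        "useDocUnwarping": False,
        "useTableRecognition": True,
        "useFormulaRecognition": False,
        "textDetLimitSideLen": 960,
        "textDetLimitType": "max",
        "textRecScoreThresh": 0.5,
        "layoutUnclipRatio": [1.2, 1.5],  # a float or a [w, h] pair
        "layoutMergeBboxesMode": "union",
    }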

+ 63 - 80
paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-import tempfile
 from typing import Any, Dict, List
 
 from fastapi import FastAPI
@@ -23,35 +21,10 @@ from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
 from ...schemas import pp_chatocrv3_doc as schema
 from .._app import create_app, primary_operation
+from ._common import common
 from ._common import ocr as ocr_common
 
 
-# XXX: Since the pipeline class does not provide serialization and
-# deserialization methods, these are implemented here based on the save-to-path
-# and load-from-path methods.
-def _serialize_vector_info(pipeline: Any, vector_info: dict) -> str:
-    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
-        path = f.name
-    try:
-        pipeline.save_vector(vector_info, path)
-        with open(path, "r", encoding="utf-8") as f:
-            return f.read()
-    finally:
-        os.unlink(path)
-
-
-def _deserialize_vector_info(pipeline: Any, vector_info: str) -> dict:
-    with tempfile.NamedTemporaryFile(
-        "w", encoding="utf-8", suffix=".json", delete=False
-    ) as f:
-        f.write(vector_info)
-        path = f.name
-    try:
-        return pipeline.load_vector(path)
-    finally:
-        os.unlink(path)
-
-
 def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
@@ -81,46 +54,58 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
             use_general_ocr=request.useGeneralOcr,
             use_seal_recognition=request.useSealRecognition,
             use_table_recognition=request.useTableRecognition,
+            text_det_limit_side_len=request.textDetLimitSideLen,
+            text_det_limit_type=request.textDetLimitType,
+            text_det_thresh=request.textDetThresh,
+            text_det_box_thresh=request.textDetBoxThresh,
+            text_det_unclip_ratio=request.textDetUnclipRatio,
+            text_rec_score_thresh=request.textRecScoreThresh,
+            seal_det_limit_side_len=request.sealDetLimitSideLen,
+            seal_det_limit_type=request.sealDetLimitType,
+            seal_det_thresh=request.sealDetThresh,
+            seal_det_box_thresh=request.sealDetBoxThresh,
+            seal_det_unclip_ratio=request.sealDetUnclipRatio,
+            seal_rec_score_thresh=request.sealRecScoreThresh,
         )
 
-        visual_results: List[Dict[str, Any]] = []
-        for i, (img, item) in enumerate(zip(images, result["layout_parsing_result"])):
-            texts: List[dict] = []
-            for poly, text, score in zip(
-                item["ocr_result"]["dt_polys"],
-                item["ocr_result"]["rec_text"],
-                item["ocr_result"]["rec_score"],
-            ):
-                texts.append(dict(poly=poly, text=text, score=score))
-            tables = [
-                dict(bbox=r["layout_bbox"], html=r["html"])
-                for r in item["table_result"]
-            ]
+        layout_parsing_results: List[Dict[str, Any]] = []
+        visual_info: List[dict] = []
+        for i, (img, item) in enumerate(zip(images, result)):
+            pruned_res = common.prune_result(item["layout_parsing_result"].json["res"])
             if ctx.config.visualize:
-                input_img, layout_img, ocr_img = await ocr_common.postprocess_images(
-                    log_id=log_id,
-                    index=i,
-                    app_context=ctx,
-                    input_image=img,
-                    layout_image=item["layout_result"].img,
-                    ocr_image=item["ocr_result"].img,
+                imgs = {
+                    "input_img": img,
+                    **item["layout_parsing_result"].img,
+                }
+                imgs = await serving_utils.call_async(
+                    common.postprocess_images,
+                    imgs,
+                    log_id,
+                    filename_template=f"{{key}}_{i}.jpg",
+                    file_storage=ctx.extra["file_storage"],
+                    return_urls=ctx.extra["return_img_urls"],
+                    max_img_size=ctx.extra["max_output_img_size"],
                 )
             else:
-                input_img, layout_img, ocr_img = None, None, None
-            visual_result = dict(
-                texts=texts,
-                tables=tables,
-                inputImage=input_img,
-                layoutImage=layout_img,
-                ocrImage=ocr_img,
+                imgs = {}
+            layout_parsing_results.append(
+                dict(
+                    prunedResult=pruned_res,
+                    outputImages=(
+                        {k: v for k, v in imgs.items() if k != "input_img"}
+                        if imgs
+                        else None
+                    ),
+                    inputImage=imgs.get("input_img"),
+                )
             )
-            visual_results.append(visual_result)
+            visual_info.append(item["visual_info"])
 
         return ResultResponse[schema.AnalyzeImagesResult](
             logId=log_id,
             result=schema.AnalyzeImagesResult(
-                visualResults=visual_results,
-                visualInfo=result["visual_info"],
+                layoutParsingResults=layout_parsing_results,
+                visualInfo=visual_info,
                 dataInfo=data_info,
             ),
         )
@@ -135,15 +120,16 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
     ) -> ResultResponse[schema.BuildVectorStoreResult]:
         pipeline = ctx.pipeline
 
+        kwargs: Dict[str, Any] = {"flag_save_bytes_vector": True}
+        if request.minCharacters is not None:
+            kwargs["min_characters"] = request.minCharacters
+        if request.llmRequestInterval is not None:
+            kwargs["llm_request_interval"] = request.llmRequestInterval
+
         vector_info = await serving_utils.call_async(
             pipeline.pipeline.build_vector,
             request.visualInfo,
-            min_characters=request.minCharacters,
-            llm_request_interval=request.llmRequestInterval,
-        )
-
-        vector_info = await serving_utils.call_async(
-            _serialize_vector_info, pipeline.pipeline, vector_info
+            **kwargs,
         )
 
         return ResultResponse[schema.BuildVectorStoreResult](
@@ -161,22 +147,8 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
     ) -> ResultResponse[schema.ChatResult]:
         pipeline = ctx.pipeline
 
-        if request.vectorInfo:
-            vector_info = await serving_utils.call_async(
-                _deserialize_vector_info,
-                pipeline.pipeline,
-                request.vectorInfo,
-            )
-        else:
-            vector_info = None
-
-        result = await serving_utils.call_async(
-            pipeline.pipeline.chat,
-            request.keyList,
-            request.visualInfo,
-            use_vector_retrieval=request.useVectorRetrieval,
-            vector_info=vector_info,
-            min_characters=request.minCharacters,
+        kwargs: Dict[str, Any] = dict(
+            vector_info=request.vectorInfo,
             text_task_description=request.textTaskDescription,
             text_output_format=request.textOutputFormat,
             text_rules_str=request.textRulesStr,
@@ -188,6 +160,17 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
             table_few_shot_demo_text_content=request.tableFewShotDemoTextContent,
             table_few_shot_demo_key_value_list=request.tableFewShotDemoKeyValueList,
         )
+        if request.useVectorRetrieval is not None:
+            kwargs["use_vector_retrieval"] = request.useVectorRetrieval
+        if request.minCharacters is not None:
+            kwargs["min_characters"] = request.minCharacters
+
+        result = await serving_utils.call_async(
+            pipeline.pipeline.chat,
+            request.keyList,
+            request.visualInfo,
+            **kwargs,
+        )
 
         return ResultResponse[schema.ChatResult](
             logId=serving_utils.generate_log_id(),
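
Note: the temp-file (de)serialization helpers could be deleted because `build_vector` is now invoked with `flag_save_bytes_vector=True`, which makes the returned vector info JSON-serializable as-is. The handlers also forward optional request fields only when the client actually set them, so the pipeline's own defaults stay in force. A runnable generic sketch of that conditional-forwarding pattern (illustrative names):

    from typing import Any, Dict, Optional

    def build_vector_kwargs(
        min_characters: Optional[int], llm_request_interval: Optional[float]
    ) -> Dict[str, Any]:
        # Include only what the caller provided; absent keys fall back to
        # the callee's defaults rather than overriding them with None.
        kwargs: Dict[str, Any] = {"flag_save_bytes_vector": True}
        if min_characters is not None:
            kwargs["min_characters"] = min_characters
        if llm_request_interval is not None:
            kwargs["llm_request_interval"] = llm_request_interval
        return kwargs

    print(build_vector_kwargs(3500, None))
    # {'flag_save_bytes_vector': True, 'min_characters': 3500}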

+ 182 - 0
paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py

@@ -0,0 +1,182 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List
+
+from fastapi import FastAPI
+
+from ...infra import utils as serving_utils
+from ...infra.config import AppConfig
+from ...infra.models import ResultResponse
+from ...schemas import pp_chatocrv4_doc as schema
+from .._app import create_app, primary_operation
+from ._common import common
+from ._common import ocr as ocr_common
+
+
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+    app, ctx = create_app(
+        pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
+    )
+
+    ocr_common.update_app_context(ctx)
+
+    @primary_operation(
+        app,
+        schema.ANALYZE_IMAGES_ENDPOINT,
+        "analyzeImages",
+    )
+    async def _analyze_images(
+        request: schema.AnalyzeImagesRequest,
+    ) -> ResultResponse[schema.AnalyzeImagesResult]:
+        pipeline = ctx.pipeline
+
+        log_id = serving_utils.generate_log_id()
+
+        images, data_info = await ocr_common.get_images(request, ctx)
+
+        result = await pipeline.call(
+            pipeline.pipeline.visual_predict,
+            images,
+            use_doc_orientation_classify=request.useDocOrientationClassify,
+            use_doc_unwarping=request.useDocUnwarping,
+            use_general_ocr=request.useGeneralOcr,
+            use_seal_recognition=request.useSealRecognition,
+            use_table_recognition=request.useTableRecognition,
+            text_det_limit_side_len=request.textDetLimitSideLen,
+            text_det_limit_type=request.textDetLimitType,
+            text_det_thresh=request.textDetThresh,
+            text_det_box_thresh=request.textDetBoxThresh,
+            text_det_unclip_ratio=request.textDetUnclipRatio,
+            text_rec_score_thresh=request.textRecScoreThresh,
+            seal_det_limit_side_len=request.sealDetLimitSideLen,
+            seal_det_limit_type=request.sealDetLimitType,
+            seal_det_thresh=request.sealDetThresh,
+            seal_det_box_thresh=request.sealDetBoxThresh,
+            seal_det_unclip_ratio=request.sealDetUnclipRatio,
+            seal_rec_score_thresh=request.sealRecScoreThresh,
+        )
+
+        layout_parsing_results: List[Dict[str, Any]] = []
+        visual_info: List[dict] = []
+        for i, (img, item) in enumerate(zip(images, result)):
+            pruned_res = common.prune_result(item["layout_parsing_result"].json["res"])
+            if ctx.config.visualize:
+                imgs = {
+                    "input_img": img,
+                    **item["layout_parsing_result"].img,
+                }
+                imgs = await serving_utils.call_async(
+                    common.postprocess_images,
+                    imgs,
+                    log_id,
+                    filename_template=f"{{key}}_{i}.jpg",
+                    file_storage=ctx.extra["file_storage"],
+                    return_urls=ctx.extra["return_img_urls"],
+                    max_img_size=ctx.extra["max_output_img_size"],
+                )
+            else:
+                imgs = {}
+            layout_parsing_results.append(
+                dict(
+                    prunedResult=pruned_res,
+                    outputImages=(
+                        {k: v for k, v in imgs.items() if k != "input_img"}
+                        if imgs
+                        else None
+                    ),
+                    inputImage=imgs.get("input_img"),
+                )
+            )
+            visual_info.append(item["visual_info"])
+
+        return ResultResponse[schema.AnalyzeImagesResult](
+            logId=log_id,
+            result=schema.AnalyzeImagesResult(
+                layoutParsingResults=layout_parsing_results,
+                visualInfo=visual_info,
+                dataInfo=data_info,
+            ),
+        )
+
+    @primary_operation(
+        app,
+        schema.BUILD_VECTOR_STORE_ENDPOINT,
+        "buildVectorStore",
+    )
+    async def _build_vector_store(
+        request: schema.BuildVectorStoreRequest,
+    ) -> ResultResponse[schema.BuildVectorStoreResult]:
+        pipeline = ctx.pipeline
+
+        kwargs: Dict[str, Any] = {"flag_save_bytes_vector": True}
+        if request.minCharacters is not None:
+            kwargs["min_characters"] = request.minCharacters
+        if request.llmRequestInterval is not None:
+            kwargs["llm_request_interval"] = request.llmRequestInterval
+
+        vector_info = await serving_utils.call_async(
+            pipeline.pipeline.build_vector,
+            request.visualInfo,
+            **kwargs,
+        )
+
+        return ResultResponse[schema.BuildVectorStoreResult](
+            logId=serving_utils.generate_log_id(),
+            result=schema.BuildVectorStoreResult(vectorInfo=vector_info),
+        )
+
+    @primary_operation(
+        app,
+        schema.CHAT_ENDPOINT,
+        "chat",
+    )
+    async def _chat(
+        request: schema.ChatRequest,
+    ) -> ResultResponse[schema.ChatResult]:
+        pipeline = ctx.pipeline
+
+        kwargs: Dict[str, Any] = dict(
+            vector_info=request.vectorInfo,
+            text_task_description=request.textTaskDescription,
+            text_output_format=request.textOutputFormat,
+            text_rules_str=request.textRulesStr,
+            text_few_shot_demo_text_content=request.textFewShotDemoTextContent,
+            text_few_shot_demo_key_value_list=request.textFewShotDemoKeyValueList,
+            table_task_description=request.tableTaskDescription,
+            table_output_format=request.tableOutputFormat,
+            table_rules_str=request.tableRulesStr,
+            table_few_shot_demo_text_content=request.tableFewShotDemoTextContent,
+            table_few_shot_demo_key_value_list=request.tableFewShotDemoKeyValueList,
+        )
+        if request.useVectorRetrieval is not None:
+            kwargs["use_vector_retrieval"] = request.useVectorRetrieval
+        if request.minCharacters is not None:
+            kwargs["min_characters"] = request.minCharacters
+
+        result = await serving_utils.call_async(
+            pipeline.pipeline.chat,
+            request.keyList,
+            request.visualInfo,
+            **kwargs,
+        )
+
+        return ResultResponse[schema.ChatResult](
+            logId=serving_utils.generate_log_id(),
+            result=schema.ChatResult(
+                chatResult=result["chat_res"],
+            ),
+        )
+
+    return app
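
Note: the v4 app mirrors the v3 handlers; the visual step goes through `pipeline.call(pipeline.pipeline.visual_predict, ...)` rather than `pipeline.infer(...)` because the pipeline exposes several entry points. A hedged end-to-end client sketch against the three endpoints (server address, input URL, and reading the `result` envelope key from `ResultResponse` are assumptions/examples):

    import requests

    BASE = "http://localhost:8080"  # assumed server address

    visual = requests.post(
        f"{BASE}/chatocr-visual",
        json={"file": "https://example.com/invoice.jpg"},
    ).json()["result"]

    vector = requests.post(
        f"{BASE}/chatocr-vector",
        json={"visualInfo": visual["visualInfo"]},
    ).json()["result"]

    chat = requests.post(
        f"{BASE}/chatocr-chat",
        json={
            "keyList": ["InvoiceNumber"],
            "visualInfo": visual["visualInfo"],   # List[dict], one per image
            "vectorInfo": vector["vectorInfo"],   # now a dict, not a string
        },
    ).json()["result"]
    print(chat["chatResult"])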

+ 6 - 9
paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py

@@ -65,15 +65,10 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
         for i, (img, item) in enumerate(zip(images, result)):
             pruned_res = common.prune_result(item.json["res"])
             if ctx.config.visualize:
-                output_imgs = item.img
                 imgs = {
                     "input_img": img,
-                    "seal_rec_img": output_imgs["seal_res_region1"],
+                    **item.img,
                 }
-                if "layout_det_res" in output_imgs:
-                    imgs["layout_det_img"] = output_imgs["layout_det_res"]
-                if "preprocessed_img" in output_imgs:
-                    imgs["doc_preprocessing_img"] = output_imgs["preprocessed_img"]
                 imgs = await serving_utils.call_async(
                     common.postprocess_images,
                     imgs,
@@ -88,9 +83,11 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
             seal_rec_results.append(
                 dict(
                     prunedResult=pruned_res,
-                    sealRecImage=imgs.get("seal_rec_img"),
-                    layoutDetImage=imgs.get("layout_det_img"),
-                    docPreprocessingImage=imgs.get("doc_preprocessing_img"),
+                    outputImages=(
+                        {k: v for k, v in imgs.items() if k != "input_img"}
+                        if imgs
+                        else None
+                    ),
                     inputImage=imgs.get("input_img"),
                 )
             )

+ 7 - 10
paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py

@@ -66,15 +66,10 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
         for i, (img, item) in enumerate(zip(images, result)):
             pruned_res = common.prune_result(item.json["res"])
             if ctx.config.visualize:
-                output_imgs = item.img
                 imgs = {
                     "input_img": img,
-                    "ocr_img": output_imgs["ocr_res_img"],
+                    **item.img,
                 }
-                if "layout_det_res" in output_imgs:
-                    imgs["layout_det_img"] = output_imgs["layout_det_res"]
-                if "preprocessed_img" in output_imgs:
-                    imgs["doc_preprocessing_img"] = output_imgs["preprocessed_img"]
                 imgs = await serving_utils.call_async(
                     common.postprocess_images,
                     imgs,
@@ -89,15 +84,17 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
             table_rec_results.append(
                 dict(
                     prunedResult=pruned_res,
-                    ocrImage=imgs.get("ocr_img"),
-                    layoutDetImage=imgs.get("layout_det_img"),
-                    docPreprocessingImage=imgs.get("doc_preprocessing_img"),
+                    outputImages=(
+                        {k: v for k, v in imgs.items() if k != "input_img"}
+                        if imgs
+                        else None
+                    ),
                     inputImage=imgs.get("input_img"),
                 )
             )
 
         return ResultResponse[InferResult](
-            logId=serving_utils.generate_log_id(),
+            logId=log_id,
             result=InferResult(
                 tableRecResults=table_rec_results,
                 dataInfo=data_info,

+ 104 - 0
paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py

@@ -0,0 +1,104 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List
+
+from fastapi import FastAPI
+
+from ...infra import utils as serving_utils
+from ...infra.config import AppConfig
+from ...infra.models import ResultResponse
+from ...schemas.table_recognition_v2 import INFER_ENDPOINT, InferRequest, InferResult
+from .._app import create_app, primary_operation
+from ._common import common
+from ._common import ocr as ocr_common
+
+
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+    app, ctx = create_app(
+        pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
+    )
+
+    ocr_common.update_app_context(ctx)
+
+    @primary_operation(
+        app,
+        INFER_ENDPOINT,
+        "infer",
+    )
+    async def _infer(request: InferRequest) -> ResultResponse[InferResult]:
+        pipeline = ctx.pipeline
+
+        log_id = serving_utils.generate_log_id()
+
+        images, data_info = await ocr_common.get_images(request, ctx)
+
+        result = await pipeline.infer(
+            images,
+            use_doc_orientation_classify=request.useDocOrientationClassify,
+            use_doc_unwarping=request.useDocUnwarping,
+            use_layout_detection=request.useLayoutDetection,
+            use_ocr_model=request.useOcrModel,
+            layout_threshold=request.layoutThreshold,
+            layout_nms=request.layoutNms,
+            layout_unclip_ratio=request.layoutUnclipRatio,
+            layout_merge_bboxes_mode=request.layoutMergeBboxesMode,
+            text_det_limit_side_len=request.textDetLimitSideLen,
+            text_det_limit_type=request.textDetLimitType,
+            text_det_thresh=request.textDetThresh,
+            text_det_box_thresh=request.textDetBoxThresh,
+            text_det_unclip_ratio=request.textDetUnclipRatio,
+            text_rec_score_thresh=request.textRecScoreThresh,
+        )
+
+        table_rec_results: List[Dict[str, Any]] = []
+        for i, (img, item) in enumerate(zip(images, result)):
+            pruned_res = common.prune_result(item.json["res"])
+            if ctx.config.visualize:
+                imgs = {
+                    "input_img": img,
+                    **item.img,
+                }
+                imgs = await serving_utils.call_async(
+                    common.postprocess_images,
+                    imgs,
+                    log_id,
+                    filename_template=f"{{key}}_{i}.jpg",
+                    file_storage=ctx.extra["file_storage"],
+                    return_urls=ctx.extra["return_img_urls"],
+                    max_img_size=ctx.extra["max_output_img_size"],
+                )
+            else:
+                imgs = {}
+            table_rec_results.append(
+                dict(
+                    prunedResult=pruned_res,
+                    outputImages=(
+                        {k: v for k, v in imgs.items() if k != "input_img"}
+                        if imgs
+                        else None
+                    ),
+                    inputImage=imgs.get("input_img"),
+                )
+            )
+
+        return ResultResponse[InferResult](
+            logId=log_id,
+            result=InferResult(
+                tableRecResults=table_rec_results,
+                dataInfo=data_info,
+            ),
+        )
+
+    return app
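
Note: the v2 app follows the same response pattern as the other updated apps (a single `outputImages` map plus `inputImage`), and its schema below registers the same `/table-recognition` path as v1, so which handler serves that path depends on the pipeline the server was launched with. A hedged request sketch (server address assumed):

    import base64

    import requests

    with open("table.jpg", "rb") as f:
        payload = {
            "file": base64.b64encode(f.read()).decode("ascii"),
            "useLayoutDetection": True,
            "useOcrModel": True,
            "layoutThreshold": 0.5,
        }

    resp = requests.post("http://localhost:8080/table-recognition", json=payload)
    for r in resp.json()["result"]["tableRecResults"]:
        print(sorted(r["prunedResult"]), sorted(r["outputImages"] or {}))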

+ 1 - 2
paddlex/inference/serving/basic_serving/_server.py

@@ -18,8 +18,7 @@ import uvicorn
 from fastapi import FastAPI
 
 
-def run_server(app: FastAPI, *, host: str, port: int, debug: bool) -> None:
-    # XXX: Currently, `debug` is not used.
+def run_server(app: FastAPI, *, host: str, port: int) -> None:
     # HACK: Fix duplicate logs
     uvicorn_version = tuple(int(x) for x in uvicorn.__version__.split("."))
     if uvicorn_version < (0, 19, 0):
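
Note: `debug` was accepted but never used, so it is dropped from the signature outright. A runnable stand-in showing the new call shape (a sketch, not the real body; the duplicate-log workaround is elided):

    import uvicorn
    from fastapi import FastAPI

    def run_server(app: FastAPI, *, host: str, port: int) -> None:
        # Same signature as the updated function; the real implementation
        # additionally works around duplicate uvicorn logs (see above).
        uvicorn.run(app, host=host, port=port)

    app = FastAPI()
    # run_server(app, host="0.0.0.0", port=8080)  # blocks; uncomment to serve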

+ 1 - 0
paddlex/inference/serving/schemas/doc_preprocessor.py

@@ -38,6 +38,7 @@ class InferRequest(ocr.BaseInferRequest):
 
 
 class DocPreprocessingResult(BaseModel):
+    outputImage: str
     prunedResult: dict
     docPreprocessingImage: Optional[str] = None
     inputImage: Optional[str] = None

+ 2 - 4
paddlex/inference/serving/schemas/formula_recognition.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Final, List, Optional
+from typing import Dict, Final, List, Optional
 
 from pydantic import BaseModel
 
@@ -38,9 +38,7 @@ class InferRequest(ocr.BaseInferRequest):
 
 class FormulaRecResult(BaseModel):
     prunedResult: dict
-    formulaRecImage: Optional[str] = None
-    layoutDetImage: Optional[str] = None
-    docPreprocessingImage: Optional[str] = None
+    outputImages: Optional[Dict[str, str]] = None
     inputImage: Optional[str] = None
 
 

+ 8 - 13
paddlex/inference/serving/schemas/layout_parsing.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Final, List, Optional, Union
+from typing import Dict, Final, List, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Literal
@@ -38,12 +38,6 @@ class InferRequest(ocr.BaseInferRequest):
     useSealRecognition: Optional[bool] = None
     useTableRecognition: Optional[bool] = None
     useFormulaRecognition: Optional[bool] = None
-    layoutThreshold: Optional[float] = None
-    layoutNms: Optional[bool] = None
-    layoutUnclipRatio: Optional[
-        Union[float, Annotated[List[float], Field(min_length=2, max_length=2)]]
-    ] = None
-    layoutMergeBboxesMode: Optional[Literal["union", "large", "small"]] = None
     textDetLimitSideLen: Optional[int] = None
     textDetLimitType: Optional[Literal["min", "max"]] = None
     textDetThresh: Optional[float] = None
@@ -56,16 +50,17 @@ class InferRequest(ocr.BaseInferRequest):
     sealDetBoxThresh: Optional[float] = None
     sealDetUnclipRatio: Optional[float] = None
     sealRecScoreThresh: Optional[float] = None
+    layoutThreshold: Optional[float] = None
+    layoutNms: Optional[bool] = None
+    layoutUnclipRatio: Optional[
+        Union[float, Annotated[List[float], Field(min_length=2, max_length=2)]]
+    ] = None
+    layoutMergeBboxesMode: Optional[Literal["union", "large", "small"]] = None
 
 
 class LayoutParsingResult(BaseModel):
     prunedResult: dict
-    ocrImage: Optional[str] = None
-    sealRecImage: Optional[str] = None
-    tableRecImage: Optional[str] = None
-    formulaRecImage: Optional[str] = None
-    layoutDetImage: Optional[str] = None
-    docPreprocessingImage: Optional[str] = None
+    outputImages: Optional[Dict[str, str]] = None
     inputImage: Optional[str] = None
 
 

+ 11 - 11
paddlex/inference/serving/schemas/pp_chatocrv3_doc.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Final, List, Optional
+from typing import Dict, Final, List, Optional
 
 from pydantic import BaseModel
 from typing_extensions import Literal
@@ -23,7 +23,7 @@ from .shared import ocr
 __all__ = [
     "ANALYZE_IMAGES_ENDPOINT",
     "AnalyzeImagesRequest",
-    "VisualResult",
+    "LayoutParsingResult",
     "AnalyzeImagesResult",
     "BUILD_VECTOR_STORE_ENDPOINT",
     "BuildVectorStoreRequest",
@@ -57,17 +57,17 @@ class AnalyzeImagesRequest(ocr.BaseInferRequest):
     sealRecScoreThresh: Optional[float] = None
 
 
-class VisualResult(BaseModel):
+class LayoutParsingResult(BaseModel):
     prunedResult: dict
-    ocrImage: Optional[str] = None
-    layoutDetImage: Optional[str] = None
-    docPreprocessingImage: Optional[str] = None
+    outputImages: Optional[Dict[str, str]] = None
     inputImage: Optional[str] = None
 
 
 class AnalyzeImagesResult(BaseModel):
-    visualResults: List[VisualResult]
-    visualInfo: dict
+    layoutParsingResults: List[LayoutParsingResult]
+    # `visualInfo` is made a separate field to facilitate its use in subsequent
+    # steps.
+    visualInfo: List[dict]
     dataInfo: DataInfo
 
 
@@ -75,7 +75,7 @@ BUILD_VECTOR_STORE_ENDPOINT: Final[str] = "/chatocr-vector"
 
 
 class BuildVectorStoreRequest(BaseModel):
-    visualInfo: dict
+    visualInfo: List[dict]
     minCharacters: Optional[int] = None
     llmRequestInterval: Optional[float] = None
 
@@ -89,9 +89,9 @@ CHAT_ENDPOINT: Final[str] = "/chatocr-chat"
 
 class ChatRequest(BaseModel):
     keyList: List[str]
-    visualInfo: dict
+    visualInfo: List[dict]
     useVectorRetrieval: Optional[bool] = None
-    vectorInfo: Optional[str] = None
+    vectorInfo: Optional[dict] = None
     minCharacters: Optional[int] = None
     textTaskDescription: Optional[str] = None
     textOutputFormat: Optional[str] = None
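
Note: two breaking wire changes here: `visualInfo` becomes a list with one entry per input image (matching the per-item `visual_info` collected in the handler), and `vectorInfo` is a plain dict rather than an opaque serialized string. An illustrative new-style chat payload (values are examples only):

    chat_request = {
        "keyList": ["InvoiceNumber", "Total"],
        "visualInfo": [{"...": "..."}],  # List[dict]: one entry per image, from /chatocr-visual
        "vectorInfo": {"...": "..."},    # dict from /chatocr-vector (previously a string)
        "useVectorRetrieval": True,
        "minCharacters": 3500,
    }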

+ 128 - 0
paddlex/inference/serving/schemas/pp_chatocrv4_doc.py

@@ -0,0 +1,128 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Final, List, Optional
+
+from pydantic import BaseModel
+from typing_extensions import Literal
+
+from ..infra.models import DataInfo, PrimaryOperations
+from .shared import ocr
+
+__all__ = [
+    "ANALYZE_IMAGES_ENDPOINT",
+    "AnalyzeImagesRequest",
+    "LayoutParsingResult",
+    "AnalyzeImagesResult",
+    "BUILD_VECTOR_STORE_ENDPOINT",
+    "BuildVectorStoreRequest",
+    "BuildVectorStoreResult",
+    "CHAT_ENDPOINT",
+    "ChatRequest",
+    "ChatResult",
+    "PRIMARY_OPERATIONS",
+]
+
+ANALYZE_IMAGES_ENDPOINT: Final[str] = "/chatocr-visual"
+
+
+class AnalyzeImagesRequest(ocr.BaseInferRequest):
+    useDocOrientationClassify: Optional[bool] = None
+    useDocUnwarping: Optional[bool] = None
+    useGeneralOcr: Optional[bool] = None
+    useSealRecognition: Optional[bool] = None
+    useTableRecognition: Optional[bool] = None
+    textDetLimitSideLen: Optional[int] = None
+    textDetLimitType: Optional[Literal["min", "max"]] = None
+    textDetThresh: Optional[float] = None
+    textDetBoxThresh: Optional[float] = None
+    textDetUnclipRatio: Optional[float] = None
+    textRecScoreThresh: Optional[float] = None
+    sealDetLimitSideLen: Optional[int] = None
+    sealDetLimitType: Optional[Literal["min", "max"]] = None
+    sealDetThresh: Optional[float] = None
+    sealDetBoxThresh: Optional[float] = None
+    sealDetUnclipRatio: Optional[float] = None
+    sealRecScoreThresh: Optional[float] = None
+
+
+class LayoutParsingResult(BaseModel):
+    prunedResult: dict
+    outputImages: Optional[Dict[str, str]] = None
+    inputImage: Optional[str] = None
+
+
+class AnalyzeImagesResult(BaseModel):
+    layoutParsingResults: List[LayoutParsingResult]
+    # `visualInfo` is made a separate field to facilitate its use in subsequent
+    # steps.
+    visualInfo: List[dict]
+    dataInfo: DataInfo
+
+
+BUILD_VECTOR_STORE_ENDPOINT: Final[str] = "/chatocr-vector"
+
+
+class BuildVectorStoreRequest(BaseModel):
+    visualInfo: List[dict]
+    minCharacters: Optional[int] = None
+    llmRequestInterval: Optional[float] = None
+
+
+class BuildVectorStoreResult(BaseModel):
+    vectorInfo: dict
+
+
+CHAT_ENDPOINT: Final[str] = "/chatocr-chat"
+
+
+class ChatRequest(BaseModel):
+    keyList: List[str]
+    visualInfo: List[dict]
+    useVectorRetrieval: Optional[bool] = None
+    vectorInfo: Optional[dict] = None
+    minCharacters: Optional[int] = None
+    textTaskDescription: Optional[str] = None
+    textOutputFormat: Optional[str] = None
+    # Is the "Str" in the name unnecessary? Keep the names consistent with the
+    # parameters of the wrapped function though.
+    textRulesStr: Optional[str] = None
+    # Should this be just "text" instead of "text content", given that there is
+    # no container?
+    textFewShotDemoTextContent: Optional[str] = None
+    textFewShotDemoKeyValueList: Optional[str] = None
+    tableTaskDescription: Optional[str] = None
+    tableOutputFormat: Optional[str] = None
+    tableRulesStr: Optional[str] = None
+    tableFewShotDemoTextContent: Optional[str] = None
+    tableFewShotDemoKeyValueList: Optional[str] = None
+
+
+class ChatResult(BaseModel):
+    chatResult: dict
+
+
+PRIMARY_OPERATIONS: Final[PrimaryOperations] = {
+    "analyzeImages": (
+        ANALYZE_IMAGES_ENDPOINT,
+        AnalyzeImagesRequest,
+        AnalyzeImagesResult,
+    ),
+    "buildVectorStore": (
+        BUILD_VECTOR_STORE_ENDPOINT,
+        BuildVectorStoreRequest,
+        BuildVectorStoreResult,
+    ),
+    "chat": (CHAT_ENDPOINT, ChatRequest, ChatResult),
+}

+ 2 - 4
paddlex/inference/serving/schemas/seal_recognition.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Final, List, Optional, Union
+from typing import Dict, Final, List, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Literal
@@ -51,9 +51,7 @@ class InferRequest(ocr.BaseInferRequest):
 
 class SealRecResult(BaseModel):
     prunedResult: dict
-    sealRecImage: Optional[str] = None
-    layoutDetImage: Optional[str] = None
-    docPreprocessingImage: Optional[str] = None
+    outputImages: Optional[Dict[str, str]] = None
     inputImage: Optional[str] = None
 
 

+ 2 - 4
paddlex/inference/serving/schemas/table_recognition.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Final, List, Optional, Union
+from typing import Dict, Final, List, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Literal
@@ -52,9 +52,7 @@ class InferRequest(ocr.BaseInferRequest):
 
 class TableRecResult(BaseModel):
     prunedResult: dict
-    tableRecImage: Optional[str] = None
-    layoutDetImage: Optional[str] = None
-    docPreprocessingImage: Optional[str] = None
+    outputImages: Optional[Dict[str, str]] = None
     inputImage: Optional[str] = None
 
 

+ 66 - 0
paddlex/inference/serving/schemas/table_recognition_v2.py

@@ -0,0 +1,66 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Final, List, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Literal
+
+from ..infra.models import DataInfo, PrimaryOperations
+from .shared import ocr
+
+__all__ = [
+    "INFER_ENDPOINT",
+    "InferRequest",
+    "TableRecResult",
+    "InferResult",
+    "PRIMARY_OPERATIONS",
+]
+
+INFER_ENDPOINT: Final[str] = "/table-recognition"
+
+
+class InferRequest(ocr.BaseInferRequest):
+    useDocOrientationClassify: Optional[bool] = None
+    useDocUnwarping: Optional[bool] = None
+    useLayoutDetection: Optional[bool] = None
+    useOcrModel: Optional[bool] = None
+    layoutThreshold: Optional[float] = None
+    layoutNms: Optional[bool] = None
+    layoutUnclipRatio: Optional[
+        Union[float, Annotated[List[float], Field(min_length=2, max_length=2)]]
+    ] = None
+    layoutMergeBboxesMode: Optional[Literal["union", "large", "small"]] = None
+    textDetLimitSideLen: Optional[int] = None
+    textDetLimitType: Optional[Literal["min", "max"]] = None
+    textDetThresh: Optional[float] = None
+    textDetBoxThresh: Optional[float] = None
+    textDetUnclipRatio: Optional[float] = None
+    textRecScoreThresh: Optional[float] = None
+
+
+class TableRecResult(BaseModel):
+    prunedResult: dict
+    outputImages: Optional[Dict[str, str]] = None
+    inputImage: Optional[str] = None
+
+
+class InferResult(BaseModel):
+    tableRecResults: List[TableRecResult]
+    dataInfo: DataInfo
+
+
+PRIMARY_OPERATIONS: Final[PrimaryOperations] = {
+    "infer": (INFER_ENDPOINT, InferRequest, InferResult),
+}

+ 5 - 2
paddlex/paddlex_cli.py

@@ -167,7 +167,10 @@ def args_cfg():
     pipeline_name = args.pipeline
     pipeline_args = []
 
-    if not args.install and pipeline_name is not None:
+    if (
+        not (args.install or args.serve or args.paddle2onnx)
+        and pipeline_name is not None
+    ):
 
         if pipeline_name not in PIPELINE_ARGUMENTS:
             support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
@@ -313,7 +316,7 @@ def serve(pipeline, *, device, use_hpip, host, port):
     pipeline_config = load_pipeline_config(pipeline)
     pipeline = create_pipeline(config=pipeline_config, device=device, use_hpip=use_hpip)
     app = create_pipeline_app(pipeline, pipeline_config)
-    run_server(app, host=host, port=port, debug=False)
+    run_server(app, host=host, port=port)
 
 
 # TODO: Move to another module
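
Note: the guard now recognizes all three "non-pipeline" modes, so `paddlex --serve --pipeline <name>` no longer trips the per-pipeline argument parser, and `serve` stops passing the removed `debug` flag. The new condition, isolated as a runnable snippet (attribute names follow the diff; values are illustrative):

    from types import SimpleNamespace

    def needs_pipeline_args(args) -> bool:
        # Skip pipeline-argument parsing for install, serve, and
        # paddle2onnx modes alike.
        return (
            not (args.install or args.serve or args.paddle2onnx)
            and args.pipeline is not None
        )

    serve_args = SimpleNamespace(install=False, serve=True, paddle2onnx=False, pipeline="OCR")
    infer_args = SimpleNamespace(install=False, serve=False, paddle2onnx=False, pipeline="OCR")
    print(needs_pipeline_args(serve_args))  # False
    print(needs_pipeline_args(infer_args))  # True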