Selaa lähdekoodia

support kpt pipeline serving (#2982)

学卿 9 kuukautta sitten
vanhempi
commit
47f20f7c23

+ 78 - 0
paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py

@@ -0,0 +1,78 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List
+
+from fastapi import FastAPI
+
+from ...infra import utils as serving_utils
+from ...infra.config import AppConfig
+from ...infra.models import ResultResponse
+from ...schemas.human_keypoint_detection import (
+    INFER_ENDPOINT,
+    InferRequest,
+    InferResult,
+)
+from .._app import create_app, primary_operation
+
+
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+    app, ctx = create_app(
+        pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
+    )
+
+    @primary_operation(
+        app,
+        INFER_ENDPOINT,
+        "infer",
+    )
+    async def _infer(request: InferRequest) -> ResultResponse[InferResult]:
+        pipeline = ctx.pipeline
+        aiohttp_session = ctx.aiohttp_session
+
+        file_bytes = await serving_utils.get_raw_bytes_async(
+            request.image, aiohttp_session
+        )
+        image = serving_utils.image_bytes_to_array(file_bytes)
+
+        result = (
+            await pipeline.infer(
+                image,
+                det_threshold=request.detThreshold,
+            )
+        )[0]
+
+        objs: List[Dict[str, Any]] = []
+        for obj in result["boxes"]:
+            objs.append(
+                dict(
+                    bbox=obj["coordinate"],
+                    kpts=obj["keypoints"].tolist(),
+                    detScore=obj["det_score"],
+                    kptScore=obj["kpt_score"],
+                )
+            )
+        if ctx.config.visualize:
+            output_image_base64 = serving_utils.base64_encode(
+                serving_utils.image_to_bytes(result.img["res"])
+            )
+        else:
+            output_image_base64 = None
+
+        return ResultResponse[InferResult](
+            logId=serving_utils.generate_log_id(),
+            result=InferResult(persons=objs, image=output_image_base64),
+        )
+
+    return app

+ 54 - 0
paddlex/inference/serving/schemas/human_keypoint_detection.py

@@ -0,0 +1,54 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Final, List, Optional, TypeAlias, Annotated
+
+from pydantic import BaseModel, Field
+
+from ..infra.models import PrimaryOperations
+from .shared import object_detection
+
+__all__ = [
+    "INFER_ENDPOINT",
+    "InferRequest",
+    "KeyPoint",
+    "Person",
+    "InferResult",
+    "PRIMARY_OPERATIONS",
+]
+
+KeyPoint: TypeAlias = Annotated[List[float], Field(min_length=3, max_length=3)]
+INFER_ENDPOINT: Final[str] = "/human-keypoint-detection"
+
+
+class InferRequest(BaseModel):
+    image: str
+    detThreshold: Optional[float] = None
+
+
+class Person(BaseModel):
+    bbox: object_detection.BoundingBox
+    kpts: List[KeyPoint]
+    detScore: float
+    kptScore: float
+
+
+class InferResult(BaseModel):
+    persons: List[Person]
+    image: Optional[str] = None
+
+
+PRIMARY_OPERATIONS: Final[PrimaryOperations] = {
+    "infer": (INFER_ENDPOINT, InferRequest, InferResult),
+}