pp_shitu_v2.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import faiss
  16. import pickle
  17. from typing import Dict, List, Optional
  18. from fastapi import FastAPI, HTTPException
  19. from pydantic import BaseModel, Field
  20. from typing_extensions import Annotated, TypeAlias
  21. from .....utils import logging
  22. from ....components.retrieval.faiss import IndexData
  23. from ...pp_shitu_v2 import ShiTuV2Pipeline
  24. from ..storage import create_storage
  25. from .. import utils as serving_utils
  26. from ..app import AppConfig, create_app
  27. from ..models import Response, ResultResponse
  class ImageLabelPair(BaseModel):
      # One gallery entry: an image together with the label to index it under.
      # (Kept as comments: pydantic would copy a class docstring into the
      # published JSON schema description.)

      # Image reference; the handlers pass it to `serving_utils.get_raw_bytes`.
      # NOTE(review): presumably a URL or Base64 payload -- confirm.
      image: str
      # Label stored in the index for this image.
      label: str
  class BuildIndexRequest(BaseModel):
      # Request body of `/shitu-index-build`.

      # Images and labels from which the new index is built.
      imageLabelPairs: List[ImageLabelPair]
  class BuildIndexResult(BaseModel):
      # Result payload of `/shitu-index-build`.

      # Key under which the serialized index was stored; pass it back as
      # `indexKey` in later add/remove/infer requests.
      indexKey: str
      # Copy of the built index's `id_map`.
      idMap: Dict[int, str]
  class AddImagesToIndexRequest(BaseModel):
      # Request body of `/shitu-index-add`.

      # Images and labels to append to the existing index.
      imageLabelPairs: List[ImageLabelPair]
      # Storage key of the index to extend.
      indexKey: str
  class AddImagesToIndexResult(BaseModel):
      # Result payload of `/shitu-index-add`.

      # `id_map` of the index after the new images were appended.
      idMap: Dict[int, str]
  class RemoveImagesFromIndexRequest(BaseModel):
      # Request body of `/shitu-index-remove`.

      # IDs of the entries to drop from the index.
      ids: List[int]
      # Storage key of the index to modify.
      indexKey: str
  class RemoveImagesFromIndexResult(BaseModel):
      # Result payload of `/shitu-index-remove`.

      # `id_map` of the index after the requested entries were removed.
      idMap: Dict[int, str]
  class InferRequest(BaseModel):
      # Request body of `/shitu-infer`.

      # Image reference; the handler passes it to `serving_utils.get_raw_bytes`.
      # NOTE(review): presumably a URL or Base64 payload -- confirm.
      image: str
      # Storage key of the index to search. When omitted, the pipeline is
      # called with `index=None`.
      indexKey: Optional[str] = None
  # A bounding box is a list of exactly four floats.
  # NOTE(review): the coordinate order is not visible here -- confirm against
  # the pipeline's `coordinate` output.
  BoundingBox: TypeAlias = Annotated[
      List[float],
      Field(min_length=4, max_length=4),
  ]
  class RecResult(BaseModel):
      # One recognition candidate for a detected object.

      # Label taken from the object's `labels` entry.
      label: str
      # Score taken from the object's `rec_scores` entry.
      score: float
  class DetectedObject(BaseModel):
      # One object reported by `/shitu-infer`.

      # Box of the object, filled from the pipeline's `coordinate` field.
      bbox: BoundingBox
      # Recognition candidates; empty when the pipeline reports no
      # `rec_scores` for the object.
      recResults: List[RecResult]
      # Detection score, filled from the pipeline's `det_score` field.
      score: float
  class InferResult(BaseModel):
      # Result payload of `/shitu-infer`.

      # Every object found in the input image.
      detectedObjects: List[DetectedObject]
      # Result image, encoded with `serving_utils.base64_encode`.
      image: str
  # XXX: Serialization and deserialization of index data are hand-rolled here,
  # which is fragile.
  def _serialize_index_data(index_data: IndexData) -> bytes:
      """Pickle the parts of ``index_data`` needed to rebuild it later.

      Args:
          index_data: Index whose ``index_bytes`` and ``index_info`` are stored.

      Returns:
          Pickled ``(index_bytes, index_info)`` tuple.
      """
      payload = (index_data.index_bytes, index_data.index_info)
      return pickle.dumps(payload)
  def _deserialize_index_data(index_data_bytes: bytes) -> IndexData:
      """Rebuild an ``IndexData`` from bytes made by ``_serialize_index_data``.

      Args:
          index_data_bytes: Pickled tuple whose first item is a serialized faiss
              index and whose second item is the index info.

      Returns:
          ``IndexData`` wrapping the deserialized faiss index and the index info.
      """
      # NOTE(review): `pickle.loads` can execute arbitrary code on crafted
      # input. This assumes the storage backend only ever holds bytes written
      # by this service -- confirm for non-memory storage types.
      stored = pickle.loads(index_data_bytes)
      faiss_index = faiss.deserialize_index(stored[0])
      return IndexData(faiss_index, stored[1])
  def create_pipeline_app(pipeline: ShiTuV2Pipeline, app_config: AppConfig) -> FastAPI:
      """Build the FastAPI serving application for a PP-ShiTuV2 pipeline.

      Four POST routes are registered: ``/shitu-index-build``,
      ``/shitu-index-add``, ``/shitu-index-remove`` and ``/shitu-infer``.
      Indexes are persisted as bytes in ``ctx.extra["index_storage"]``, which is
      created from ``ctx.config.extra["index_storage"]`` when that entry exists
      and from ``{"type": "memory"}`` otherwise.

      Every route is given its own OpenAPI ``operation_id``, because operation
      IDs must be unique within a schema.

      Args:
          pipeline: Pipeline handed to ``create_app``.
          app_config: Application configuration handed to ``create_app``.

      Returns:
          The FastAPI application with the routes above attached.

      Note:
          Each handler logs any exception raised while serving a request and
          answers with HTTP 500 ("Internal server error").
      """
      app, ctx = create_app(
          pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
      )

      if ctx.config.extra and "index_storage" in ctx.config.extra:
          ctx.extra["index_storage"] = create_storage(ctx.config.extra["index_storage"])
      else:
          ctx.extra["index_storage"] = create_storage({"type": "memory"})

      async def _load_images_and_labels(pairs: List[ImageLabelPair], aiohttp_session):
          """Fetch all images of ``pairs`` concurrently and decode them.

          Returns an ``(images, labels)`` pair of lists in request order.
          """
          file_bytes_list = await asyncio.gather(
              *(
                  serving_utils.get_raw_bytes(pair.image, aiohttp_session)
                  for pair in pairs
              )
          )
          images = [
              serving_utils.image_bytes_to_array(item) for item in file_bytes_list
          ]
          labels = [pair.label for pair in pairs]
          return images, labels

      async def _load_index_data(index_key: str) -> IndexData:
          """Read the bytes stored under ``index_key`` and deserialize them."""
          index_storage = ctx.extra["index_storage"]
          index_data_bytes = await serving_utils.call_async(
              index_storage.get, index_key
          )
          return await serving_utils.call_async(
              _deserialize_index_data, index_data_bytes
          )

      async def _save_index_data(index_key: str, index_data: IndexData) -> None:
          """Serialize ``index_data`` and store it under ``index_key``."""
          index_storage = ctx.extra["index_storage"]
          index_data_bytes = await serving_utils.call_async(
              _serialize_index_data, index_data
          )
          await serving_utils.call_async(
              index_storage.set, index_key, index_data_bytes
          )

      # The handlers below are described with comments rather than docstrings:
      # FastAPI publishes a handler's docstring as the operation description.

      @app.post(
          "/shitu-index-build",
          operation_id="buildIndex",
          responses={422: {"model": Response}},
      )
      async def _build_index(
          request: BuildIndexRequest,
      ) -> ResultResponse[BuildIndexResult]:
          # Builds a new index from the submitted pairs and stores it; the
          # request ID doubles as the index key returned to the client.
          pipeline = ctx.pipeline
          aiohttp_session = ctx.aiohttp_session

          request_id = serving_utils.generate_request_id()

          try:
              images, labels = await _load_images_and_labels(
                  request.imageLabelPairs, aiohttp_session
              )

              # TODO: Support specifying `index_type` and `metric_type` in the
              # request
              index_data = await pipeline.call(
                  pipeline.pipeline.build_index,
                  images,
                  labels,
                  index_type="Flat",
                  metric_type="IP",
              )

              index_key = request_id
              await _save_index_data(index_key, index_data)

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=BuildIndexResult(indexKey=index_key, idMap=index_data.id_map),
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      @app.post(
          "/shitu-index-add",
          operation_id="addImagesToIndex",
          responses={422: {"model": Response}},
      )
      async def _add_images_to_index(
          request: AddImagesToIndexRequest,
      ) -> ResultResponse[AddImagesToIndexResult]:
          # Appends the submitted pairs to the stored index and writes the
          # updated index back under the same key.
          pipeline = ctx.pipeline
          aiohttp_session = ctx.aiohttp_session

          try:
              images, labels = await _load_images_and_labels(
                  request.imageLabelPairs, aiohttp_session
              )

              index_data = await _load_index_data(request.indexKey)

              index_data = await pipeline.call(
                  pipeline.pipeline.append_index, images, labels, index_data
              )

              await _save_index_data(request.indexKey, index_data)

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=AddImagesToIndexResult(idMap=index_data.id_map),
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      @app.post(
          "/shitu-index-remove",
          operation_id="removeImagesFromIndex",
          responses={422: {"model": Response}},
      )
      async def _remove_images_from_index(
          request: RemoveImagesFromIndexRequest,
      ) -> ResultResponse[RemoveImagesFromIndexResult]:
          # Removes the given IDs from the stored index and writes the updated
          # index back under the same key.
          pipeline = ctx.pipeline

          try:
              index_data = await _load_index_data(request.indexKey)

              index_data = await pipeline.call(
                  pipeline.pipeline.remove_index, request.ids, index_data
              )

              await _save_index_data(request.indexKey, index_data)

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=RemoveImagesFromIndexResult(idMap=index_data.id_map),
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      @app.post(
          "/shitu-infer",
          operation_id="infer",
          responses={422: {"model": Response}},
      )
      async def _infer(request: InferRequest) -> ResultResponse[InferResult]:
          # Runs prediction on one image, optionally against a stored index.
          pipeline = ctx.pipeline
          aiohttp_session = ctx.aiohttp_session

          try:
              image_bytes = await serving_utils.get_raw_bytes(
                  request.image, aiohttp_session
              )
              image = serving_utils.image_bytes_to_array(image_bytes)

              # Without an index key the pipeline is called with `index=None`.
              if request.indexKey is not None:
                  index_data = await _load_index_data(request.indexKey)
              else:
                  index_data = None

              # Only the first item produced by `predict` is used.
              result = list(
                  await pipeline.call(pipeline.pipeline.predict, image, index=index_data)
              )[0]

              objects: List[DetectedObject] = []
              for obj in result["boxes"]:
                  rec_results: List[RecResult] = []
                  # `rec_scores` may be None, in which case the object is
                  # reported without recognition candidates.
                  if obj["rec_scores"] is not None:
                      rec_results = [
                          RecResult(label=label, score=score)
                          for label, score in zip(obj["labels"], obj["rec_scores"])
                      ]
                  objects.append(
                      DetectedObject(
                          bbox=obj["coordinate"],
                          recResults=rec_results,
                          score=obj["det_score"],
                      )
                  )

              output_image_base64 = serving_utils.base64_encode(
                  serving_utils.image_to_bytes(result.img)
              )

              return ResultResponse(
                  logId=serving_utils.generate_log_id(),
                  errorCode=0,
                  errorMsg="Success",
                  result=InferResult(detectedObjects=objects, image=output_image_base64),
              )

          except Exception as e:
              logging.exception(e)
              raise HTTPException(status_code=500, detail="Internal server error")

      return app