face_recognition.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import faiss
  16. import pickle
  17. from typing import Dict, List, Optional
  18. from fastapi import FastAPI, HTTPException
  19. from pydantic import BaseModel, Field
  20. from typing_extensions import Annotated, TypeAlias
  21. from .....utils import logging
  22. from ....components.retrieval.faiss import IndexData
  23. from ...face_recognition import FaceRecPipeline
  24. from ..storage import create_storage
  25. from .. import utils as serving_utils
  26. from ..app import AppConfig, create_app
  27. from ..models import Response, ResultResponse
# One index entry as supplied by the client: an image and the label to
# associate with it.
class ImageLabelPair(BaseModel):
    # Image source string; the handlers pass it to
    # `serving_utils.get_raw_bytes` to obtain the raw file bytes
    # (presumably a URL or Base64 string — confirm against that helper).
    image: str
    # Label handed to the pipeline together with the image.
    label: str
# Request body of the "/face-recognition-index-build" endpoint.
class BuildIndexRequest(BaseModel):
    # Images and labels from which the new index is built.
    imageLabelPairs: List[ImageLabelPair]
# Result payload of the "/face-recognition-index-build" endpoint.
class BuildIndexResult(BaseModel):
    # Key under which the built index was written to the index storage; the
    # add, remove and infer requests accept it as `indexKey`.
    indexKey: str
    # Filled from the `id_map` attribute of the built index data
    # (presumably vector ID -> label — confirm against the pipeline).
    idMap: Dict[int, str]
# Request body of the "/face-recognition-index-add" endpoint.
class AddImagesToIndexRequest(BaseModel):
    # Images and labels to append to the existing index.
    imageLabelPairs: List[ImageLabelPair]
    # Storage key of the index to extend; the updated index is written back
    # under the same key.
    indexKey: str
# Result payload of the "/face-recognition-index-add" endpoint.
class AddImagesToIndexResult(BaseModel):
    # Filled from the `id_map` attribute of the index data after appending
    # (presumably vector ID -> label — confirm against the pipeline).
    idMap: Dict[int, str]
# Request body of the "/face-recognition-index-remove" endpoint.
class RemoveImagesFromIndexRequest(BaseModel):
    # IDs passed to the pipeline's `remove_index` call
    # (presumably keys of a previously returned `idMap` — confirm).
    ids: List[int]
    # Storage key of the index to modify; the updated index is written back
    # under the same key.
    indexKey: str
# Result payload of the "/face-recognition-index-remove" endpoint.
class RemoveImagesFromIndexResult(BaseModel):
    # Filled from the `id_map` attribute of the index data after removal
    # (presumably vector ID -> label — confirm against the pipeline).
    idMap: Dict[int, str]
# Request body of the "/face-recognition-infer" endpoint.
class InferRequest(BaseModel):
    # Image source string; the handler passes it to
    # `serving_utils.get_raw_bytes` to obtain the raw file bytes
    # (presumably a URL or Base64 string — confirm against that helper).
    image: str
    # Storage key of the index to search. When omitted, the handler calls the
    # pipeline's `predict` with `index=None`.
    indexKey: Optional[str] = None
# A list of exactly four floats (the length is enforced by the `Field`
# constraints).
# NOTE(review): the coordinate order is not visible in this file — presumably
# [x_min, y_min, x_max, y_max]; confirm against the pipeline output.
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]
# One recognition candidate for a detected face.
class RecResult(BaseModel):
    # Taken from the face's "labels" entry in the pipeline result.
    label: str
    # Taken from the face's "rec_scores" entry, paired with `label` by
    # position.
    score: float
# One detected face as reported by the infer endpoint.
class Face(BaseModel):
    # Taken from the face's "coordinate" entry in the pipeline result.
    bbox: BoundingBox
    # Recognition candidates; left empty when the face's "rec_scores" entry
    # is None.
    recResults: List[RecResult]
    # Taken from the face's "det_score" entry in the pipeline result.
    score: float
# Result payload of the "/face-recognition-infer" endpoint.
class InferResult(BaseModel):
    # One entry per item in the pipeline result's "boxes" list.
    faces: List[Face]
    # Base64-encoded bytes of the pipeline result's `img` attribute.
    image: str
  60. def _serialize_index_data(index_data: IndexData) -> bytes:
  61. tup = (index_data.index_bytes, index_data.index_info)
  62. return pickle.dumps(tup)
  63. def _deserialize_index_data(index_data_bytes: bytes) -> IndexData:
  64. tup = pickle.loads(index_data_bytes)
  65. index = faiss.deserialize_index(tup[0])
  66. return IndexData(index, tup[1])
  67. def create_pipeline_app(pipeline: FaceRecPipeline, app_config: AppConfig) -> FastAPI:
  68. app, ctx = create_app(
  69. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  70. )
  71. if ctx.config.extra and "index_storage" in ctx.config.extra:
  72. ctx.extra["index_storage"] = create_storage(ctx.config.extra["index_storage"])
  73. else:
  74. ctx.extra["index_storage"] = create_storage({"type": "memory"})
  75. @app.post(
  76. "/face-recognition-index-build",
  77. operation_id="buildIndex",
  78. responses={422: {"model": Response}},
  79. )
  80. async def _build_index(
  81. request: BuildIndexRequest,
  82. ) -> ResultResponse[BuildIndexResult]:
  83. pipeline = ctx.pipeline
  84. aiohttp_session = ctx.aiohttp_session
  85. request_id = serving_utils.generate_request_id()
  86. try:
  87. images = [pair.image for pair in request.imageLabelPairs]
  88. file_bytes_list = await asyncio.gather(
  89. *(serving_utils.get_raw_bytes(img, aiohttp_session) for img in images)
  90. )
  91. images = [
  92. serving_utils.image_bytes_to_array(item) for item in file_bytes_list
  93. ]
  94. labels = [pair.label for pair in request.imageLabelPairs]
  95. # TODO: Support specifying `index_type` and `metric_type` in the
  96. # request
  97. index_data = await pipeline.call(
  98. pipeline.pipeline.build_index,
  99. images,
  100. labels,
  101. index_type="Flat",
  102. metric_type="IP",
  103. )
  104. index_storage = ctx.extra["index_storage"]
  105. index_key = request_id
  106. index_data_bytes = await serving_utils.call_async(
  107. _serialize_index_data, index_data
  108. )
  109. await serving_utils.call_async(
  110. index_storage.set, index_key, index_data_bytes
  111. )
  112. return ResultResponse(
  113. logId=serving_utils.generate_log_id(),
  114. errorCode=0,
  115. errorMsg="Success",
  116. result=BuildIndexResult(indexKey=index_key, idMap=index_data.id_map),
  117. )
  118. except Exception as e:
  119. logging.exception(e)
  120. raise HTTPException(status_code=500, detail="Internal server error")
  121. @app.post(
  122. "/face-recognition-index-add",
  123. operation_id="buildIndex",
  124. responses={422: {"model": Response}},
  125. )
  126. async def _add_images_to_index(
  127. request: AddImagesToIndexRequest,
  128. ) -> ResultResponse[AddImagesToIndexResult]:
  129. pipeline = ctx.pipeline
  130. aiohttp_session = ctx.aiohttp_session
  131. try:
  132. images = [pair.image for pair in request.imageLabelPairs]
  133. file_bytes_list = await asyncio.gather(
  134. *(serving_utils.get_raw_bytes(img, aiohttp_session) for img in images)
  135. )
  136. images = [
  137. serving_utils.image_bytes_to_array(item) for item in file_bytes_list
  138. ]
  139. labels = [pair.label for pair in request.imageLabelPairs]
  140. index_storage = ctx.extra["index_storage"]
  141. index_data_bytes = await serving_utils.call_async(
  142. index_storage.get, request.indexKey
  143. )
  144. index_data = await serving_utils.call_async(
  145. _deserialize_index_data, index_data_bytes
  146. )
  147. index_data = await pipeline.call(
  148. pipeline.pipeline.append_index, images, labels, index_data
  149. )
  150. index_data_bytes = await serving_utils.call_async(
  151. _serialize_index_data, index_data
  152. )
  153. await serving_utils.call_async(
  154. index_storage.set, request.indexKey, index_data_bytes
  155. )
  156. return ResultResponse(
  157. logId=serving_utils.generate_log_id(),
  158. errorCode=0,
  159. errorMsg="Success",
  160. result=AddImagesToIndexResult(idMap=index_data.id_map),
  161. )
  162. except Exception as e:
  163. logging.exception(e)
  164. raise HTTPException(status_code=500, detail="Internal server error")
  165. @app.post(
  166. "/face-recognition-index-remove",
  167. operation_id="buildIndex",
  168. responses={422: {"model": Response}},
  169. )
  170. async def _remove_images_from_index(
  171. request: RemoveImagesFromIndexRequest,
  172. ) -> ResultResponse[RemoveImagesFromIndexResult]:
  173. pipeline = ctx.pipeline
  174. try:
  175. index_storage = ctx.extra["index_storage"]
  176. index_data_bytes = await serving_utils.call_async(
  177. index_storage.get, request.indexKey
  178. )
  179. index_data = await serving_utils.call_async(
  180. _deserialize_index_data, index_data_bytes
  181. )
  182. index_data = await pipeline.call(
  183. pipeline.pipeline.remove_index, request.ids, index_data
  184. )
  185. index_data_bytes = await serving_utils.call_async(
  186. _serialize_index_data, index_data
  187. )
  188. await serving_utils.call_async(
  189. index_storage.set, request.indexKey, index_data_bytes
  190. )
  191. return ResultResponse(
  192. logId=serving_utils.generate_log_id(),
  193. errorCode=0,
  194. errorMsg="Success",
  195. result=RemoveImagesFromIndexResult(idMap=index_data.id_map),
  196. )
  197. except Exception as e:
  198. logging.exception(e)
  199. raise HTTPException(status_code=500, detail="Internal server error")
  200. @app.post(
  201. "/face-recognition-infer",
  202. operation_id="infer",
  203. responses={422: {"model": Response}},
  204. )
  205. async def _infer(request: InferRequest) -> ResultResponse[InferResult]:
  206. pipeline = ctx.pipeline
  207. aiohttp_session = ctx.aiohttp_session
  208. try:
  209. image_bytes = await serving_utils.get_raw_bytes(
  210. request.image, aiohttp_session
  211. )
  212. image = serving_utils.image_bytes_to_array(image_bytes)
  213. if request.indexKey is not None:
  214. index_storage = ctx.extra["index_storage"]
  215. index_data_bytes = await serving_utils.call_async(
  216. index_storage.get, request.indexKey
  217. )
  218. index_data = await serving_utils.call_async(
  219. _deserialize_index_data, index_data_bytes
  220. )
  221. else:
  222. index_data = None
  223. result = list(
  224. await pipeline.call(pipeline.pipeline.predict, image, index=index_data)
  225. )[0]
  226. faces: List[Face] = []
  227. for face in result["boxes"]:
  228. rec_results: List[RecResult] = []
  229. if face["rec_scores"] is not None:
  230. for label, score in zip(face["labels"], face["rec_scores"]):
  231. rec_results.append(
  232. RecResult(
  233. label=label,
  234. score=score,
  235. )
  236. )
  237. faces.append(
  238. Face(
  239. bbox=face["coordinate"],
  240. recResults=rec_results,
  241. score=face["det_score"],
  242. )
  243. )
  244. output_image_base64 = serving_utils.base64_encode(
  245. serving_utils.image_to_bytes(result.img)
  246. )
  247. return ResultResponse(
  248. logId=serving_utils.generate_log_id(),
  249. errorCode=0,
  250. errorMsg="Success",
  251. result=InferResult(faces=faces, image=output_image_base64),
  252. )
  253. except Exception as e:
  254. logging.exception(e)
  255. raise HTTPException(status_code=500, detail="Internal server error")
  256. return app