face_recognition.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from operator import attrgetter
  16. from typing import Any, Dict, List
  17. from fastapi import FastAPI
  18. from ....pipelines.components import IndexData
  19. from ...infra import utils as serving_utils
  20. from ...infra.config import AppConfig
  21. from ...infra.models import ResultResponse
  22. from ...schemas import face_recognition as schema
  23. from .._app import create_app, primary_operation
  24. from ._common import image_recognition as ir_common
  25. # XXX: Currently the implementations of the face recognition and PP-ShiTuV2
  26. # pipeline apps overlap significantly. We should aim to facilitate code reuse,
  27. # but is it acceptable to assume a strong similarity between these two
  28. # pipelines?
  29. def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
  30. app, ctx = create_app(
  31. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  32. )
  33. ir_common.update_app_context(ctx)
  34. @primary_operation(
  35. app,
  36. schema.BUILD_INDEX_ENDPOINT,
  37. "buildIndex",
  38. )
  39. async def _build_index(
  40. request: schema.BuildIndexRequest,
  41. ) -> ResultResponse[schema.BuildIndexResult]:
  42. pipeline = ctx.pipeline
  43. aiohttp_session = ctx.aiohttp_session
  44. file_bytes_list = await asyncio.gather(
  45. *(
  46. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  47. for img in map(attrgetter("image"), request.imageLabelPairs)
  48. )
  49. )
  50. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  51. labels = [pair.label for pair in request.imageLabelPairs]
  52. # TODO: Support specifying `index_type` and `metric_type` in the
  53. # request
  54. index_data = await pipeline.call(
  55. pipeline.pipeline.build_index,
  56. images,
  57. labels,
  58. index_type="Flat",
  59. metric_type="IP",
  60. )
  61. index_storage = ctx.extra["index_storage"]
  62. index_key = ir_common.generate_index_key()
  63. index_data_bytes = index_data.to_bytes()
  64. await serving_utils.call_async(index_storage.set, index_key, index_data_bytes)
  65. return ResultResponse[schema.BuildIndexResult](
  66. logId=serving_utils.generate_log_id(),
  67. result=schema.BuildIndexResult(
  68. indexKey=index_key, imageCount=len(index_data.id_map)
  69. ),
  70. )
  71. @primary_operation(
  72. app,
  73. schema.ADD_IMAGES_TO_INDEX_ENDPOINT,
  74. "addImagesToIndex",
  75. )
  76. async def _add_images_to_index(
  77. request: schema.AddImagesToIndexRequest,
  78. ) -> ResultResponse[schema.AddImagesToIndexResult]:
  79. pipeline = ctx.pipeline
  80. aiohttp_session = ctx.aiohttp_session
  81. file_bytes_list = await asyncio.gather(
  82. *(
  83. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  84. for img in map(attrgetter("image"), request.imageLabelPairs)
  85. )
  86. )
  87. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  88. labels = [pair.label for pair in request.imageLabelPairs]
  89. index_storage = ctx.extra["index_storage"]
  90. index_data_bytes = await serving_utils.call_async(
  91. index_storage.get, request.indexKey
  92. )
  93. index_data = IndexData.from_bytes(index_data_bytes)
  94. index_data = await pipeline.call(
  95. pipeline.pipeline.append_index, images, labels, index_data
  96. )
  97. index_data_bytes = index_data.to_bytes()
  98. await serving_utils.call_async(
  99. index_storage.set, request.indexKey, index_data_bytes
  100. )
  101. return ResultResponse[schema.AddImagesToIndexResult](
  102. logId=serving_utils.generate_log_id(),
  103. result=schema.AddImagesToIndexResult(imageCount=len(index_data.id_map)),
  104. )
  105. @primary_operation(
  106. app,
  107. schema.REMOVE_IMAGES_FROM_INDEX_ENDPOINT,
  108. "removeImagesFromIndex",
  109. )
  110. async def _remove_images_from_index(
  111. request: schema.RemoveImagesFromIndexRequest,
  112. ) -> ResultResponse[schema.RemoveImagesFromIndexResult]:
  113. pipeline = ctx.pipeline
  114. index_storage = ctx.extra["index_storage"]
  115. index_data_bytes = await serving_utils.call_async(
  116. index_storage.get, request.indexKey
  117. )
  118. index_data = IndexData.from_bytes(index_data_bytes)
  119. index_data = await pipeline.call(
  120. pipeline.pipeline.remove_index, request.ids, index_data
  121. )
  122. index_data_bytes = index_data.to_bytes()
  123. await serving_utils.call_async(
  124. index_storage.set, request.indexKey, index_data_bytes
  125. )
  126. return ResultResponse[schema.RemoveImagesFromIndexResult](
  127. logId=serving_utils.generate_log_id(),
  128. result=schema.RemoveImagesFromIndexResult(
  129. imageCount=len(index_data.id_map)
  130. ),
  131. )
  132. @primary_operation(
  133. app,
  134. schema.INFER_ENDPOINT,
  135. "infer",
  136. )
  137. async def _infer(
  138. request: schema.InferRequest,
  139. ) -> ResultResponse[schema.InferResult]:
  140. pipeline = ctx.pipeline
  141. aiohttp_session = ctx.aiohttp_session
  142. image_bytes = await serving_utils.get_raw_bytes_async(
  143. request.image, aiohttp_session
  144. )
  145. image = serving_utils.image_bytes_to_array(image_bytes)
  146. if request.indexKey is not None:
  147. index_storage = ctx.extra["index_storage"]
  148. index_data_bytes = await serving_utils.call_async(
  149. index_storage.get, request.indexKey
  150. )
  151. index_data = IndexData.from_bytes(index_data_bytes)
  152. else:
  153. index_data = None
  154. result = list(
  155. await pipeline.call(
  156. pipeline.pipeline.predict,
  157. image,
  158. index=index_data,
  159. det_threshold=request.detThreshold,
  160. rec_threshold=request.recThreshold,
  161. hamming_radius=request.hammingRadius,
  162. topk=request.topk,
  163. )
  164. )[0]
  165. objs: List[Dict[str, Any]] = []
  166. for obj in result["boxes"]:
  167. rec_results: List[Dict[str, Any]] = []
  168. if obj["rec_scores"] is not None:
  169. for label, score in zip(obj["labels"], obj["rec_scores"]):
  170. rec_results.append(
  171. dict(
  172. label=label,
  173. score=score,
  174. )
  175. )
  176. objs.append(
  177. dict(
  178. bbox=obj["coordinate"],
  179. recResults=rec_results,
  180. score=obj["det_score"],
  181. )
  182. )
  183. if ctx.config.visualize:
  184. output_image_base64 = serving_utils.base64_encode(
  185. serving_utils.image_to_bytes(result.img["res"])
  186. )
  187. else:
  188. output_image_base64 = None
  189. return ResultResponse[schema.InferResult](
  190. logId=serving_utils.generate_log_id(),
  191. result=schema.InferResult(faces=objs, image=output_image_base64),
  192. )
  193. return app