face_recognition.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from operator import attrgetter
  16. from typing import Any, Dict, List
  17. from fastapi import FastAPI
  18. from ....pipelines.components import IndexData
  19. from ...infra import utils as serving_utils
  20. from ...infra.config import AppConfig
  21. from ...infra.models import ResultResponse
  22. from ...schemas import face_recognition as schema
  23. from .._app import create_app, primary_operation
  24. from ._common import image_recognition as ir_common
  25. # XXX: Currently the implementations of the face recognition and PP-ShiTuV2
  26. # pipeline apps overlap significantly. We should aim to facilitate code reuse,
  27. # but is it acceptable to assume a strong similarity between these two
  28. # pipelines?
  29. def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
  30. app, ctx = create_app(
  31. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  32. )
  33. ir_common.update_app_context(ctx)
  34. @primary_operation(
  35. app,
  36. schema.BUILD_INDEX_ENDPOINT,
  37. "buildIndex",
  38. )
  39. async def _build_index(
  40. request: schema.BuildIndexRequest,
  41. ) -> ResultResponse[schema.BuildIndexResult]:
  42. pipeline = ctx.pipeline
  43. aiohttp_session = ctx.aiohttp_session
  44. file_bytes_list = await asyncio.gather(
  45. *(
  46. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  47. for img in map(attrgetter("image"), request.imageLabelPairs)
  48. )
  49. )
  50. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  51. labels = [pair.label for pair in request.imageLabelPairs]
  52. # TODO: Support specifying `index_type` and `metric_type` in the
  53. # request
  54. index_data = await pipeline.call(
  55. pipeline.pipeline.build_index,
  56. images,
  57. labels,
  58. index_type="Flat",
  59. metric_type="IP",
  60. )
  61. index_storage = ctx.extra["index_storage"]
  62. index_key = ir_common.generate_index_key()
  63. index_data_bytes = index_data.to_bytes()
  64. await serving_utils.call_async(index_storage.set, index_key, index_data_bytes)
  65. return ResultResponse[schema.BuildIndexResult](
  66. logId=serving_utils.generate_log_id(),
  67. result=schema.BuildIndexResult(indexKey=index_key, idMap=index_data.id_map),
  68. )
  69. @primary_operation(
  70. app,
  71. schema.ADD_IMAGES_TO_INDEX_ENDPOINT,
  72. "addImagesToIndex",
  73. )
  74. async def _add_images_to_index(
  75. request: schema.AddImagesToIndexRequest,
  76. ) -> ResultResponse[schema.AddImagesToIndexResult]:
  77. pipeline = ctx.pipeline
  78. aiohttp_session = ctx.aiohttp_session
  79. file_bytes_list = await asyncio.gather(
  80. *(
  81. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  82. for img in map(attrgetter("image"), request.imageLabelPairs)
  83. )
  84. )
  85. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  86. labels = [pair.label for pair in request.imageLabelPairs]
  87. index_storage = ctx.extra["index_storage"]
  88. index_data_bytes = await serving_utils.call_async(
  89. index_storage.get, request.indexKey
  90. )
  91. index_data = IndexData.from_bytes(index_data_bytes)
  92. index_data = await pipeline.call(
  93. pipeline.pipeline.append_index, images, labels, index_data
  94. )
  95. index_data_bytes = index_data.to_bytes()
  96. await serving_utils.call_async(
  97. index_storage.set, request.indexKey, index_data_bytes
  98. )
  99. return ResultResponse[schema.AddImagesToIndexResult](
  100. logId=serving_utils.generate_log_id(),
  101. result=schema.AddImagesToIndexResult(idMap=index_data.id_map),
  102. )
  103. @primary_operation(
  104. app,
  105. schema.REMOVE_IMAGES_FROM_INDEX_ENDPOINT,
  106. "removeImagesFromIndex",
  107. )
  108. async def _remove_images_from_index(
  109. request: schema.RemoveImagesFromIndexRequest,
  110. ) -> ResultResponse[schema.RemoveImagesFromIndexResult]:
  111. pipeline = ctx.pipeline
  112. index_storage = ctx.extra["index_storage"]
  113. index_data_bytes = await serving_utils.call_async(
  114. index_storage.get, request.indexKey
  115. )
  116. index_data = IndexData.from_bytes(index_data_bytes)
  117. index_data = await pipeline.call(
  118. pipeline.pipeline.remove_index, request.ids, index_data
  119. )
  120. index_data_bytes = index_data.to_bytes()
  121. await serving_utils.call_async(
  122. index_storage.set, request.indexKey, index_data_bytes
  123. )
  124. return ResultResponse[schema.RemoveImagesFromIndexResult](
  125. logId=serving_utils.generate_log_id(),
  126. result=schema.RemoveImagesFromIndexResult(idMap=index_data.id_map),
  127. )
  128. @primary_operation(
  129. app,
  130. schema.INFER_ENDPOINT,
  131. "infer",
  132. )
  133. async def _infer(
  134. request: schema.InferRequest,
  135. ) -> ResultResponse[schema.InferResult]:
  136. pipeline = ctx.pipeline
  137. aiohttp_session = ctx.aiohttp_session
  138. image_bytes = await serving_utils.get_raw_bytes_async(
  139. request.image, aiohttp_session
  140. )
  141. image = serving_utils.image_bytes_to_array(image_bytes)
  142. if request.indexKey is not None:
  143. index_storage = ctx.extra["index_storage"]
  144. index_data_bytes = await serving_utils.call_async(
  145. index_storage.get, request.indexKey
  146. )
  147. index_data = IndexData.from_bytes(index_data_bytes)
  148. else:
  149. index_data = None
  150. result = list(
  151. await pipeline.call(
  152. pipeline.pipeline.predict,
  153. image,
  154. index=index_data,
  155. det_threshold=request.detThreshold,
  156. rec_threshold=request.recThreshold,
  157. hamming_radius=request.hammingRadius,
  158. topk=request.topk,
  159. )
  160. )[0]
  161. objs: List[Dict[str, Any]] = []
  162. for obj in result["boxes"]:
  163. rec_results: List[Dict[str, Any]] = []
  164. if obj["rec_scores"] is not None:
  165. for label, score in zip(obj["labels"], obj["rec_scores"]):
  166. rec_results.append(
  167. dict(
  168. label=label,
  169. score=score,
  170. )
  171. )
  172. objs.append(
  173. dict(
  174. bbox=obj["coordinate"],
  175. recResults=rec_results,
  176. score=obj["det_score"],
  177. )
  178. )
  179. if ctx.config.visualize:
  180. output_image_base64 = serving_utils.base64_encode(
  181. serving_utils.image_to_bytes(result.img["res"])
  182. )
  183. else:
  184. output_image_base64 = None
  185. return ResultResponse[schema.InferResult](
  186. logId=serving_utils.generate_log_id(),
  187. result=schema.InferResult(faces=objs, image=output_image_base64),
  188. )
  189. return app